diff --git a/CMakeLists.txt b/CMakeLists.txt index 23bbe829ac16180088bfa37df66e23f19b021ea3..030bd19b3fd2f561a847bbc4613e5d2030812a92 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,7 +25,6 @@ message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: " message(STATUS "C compiler: ${CMAKE_C_COMPILER}, version: " "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}") -find_package(Sphinx) if(NOT CMAKE_CROSSCOMPILING) find_package(CUDA QUIET) endif(NOT CMAKE_CROSSCOMPILING) @@ -226,5 +225,7 @@ if(WITH_PYTHON) endif() if(WITH_DOC) + find_package(Sphinx REQUIRED) + find_python_module(recommonmark REQUIRED) add_subdirectory(doc) endif() diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake index 966c0bafd3742862e66d7ff36de86190507a6936..0332e39d14200da1c1af52675f0ccad2c07de405 100644 --- a/cmake/external/mkldnn.cmake +++ b/cmake/external/mkldnn.cmake @@ -56,6 +56,8 @@ ExternalProject_Add( GIT_TAG "v0.14" PREFIX ${MKLDNN_SOURCES_DIR} UPDATE_COMMAND "" + # Patch MKLDNN to compile with gcc 4.8, the related issue is in intel/mkl-dnn#237. 
+ PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/mkldnn.hpp ${MKLDNN_SOURCES_DIR}/src/extern_mkldnn/include/mkldnn.hpp CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR} CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} CMAKE_ARGS -DMKLROOT=${MKLML_ROOT} diff --git a/doc/v2/build_and_install/build_from_source_cn.rst b/doc/v2/build_and_install/build_from_source_cn.rst index 115b92a33888abf1e1be400e1abbb58b632a2976..f846928954dd3a05e11054ce2ff2ff839fbefd4b 100644 --- a/doc/v2/build_and_install/build_from_source_cn.rst +++ b/doc/v2/build_and_install/build_from_source_cn.rst @@ -19,8 +19,9 @@ ---------------- PaddlePaddle需要使用Docker环境完成编译,这样可以免去单独安装编译依赖的步骤,可选的不同编译环境Docker镜像 -可以在 `这里 `_ 找到。或者 -参考下述可选步骤,从源码中构建用于编译PaddlePaddle的Docker镜像。 +可以在 `这里 `_ 找到,您也可以 +在 `这里 `_ 找到 paddle_manylinux_devel +镜像的编译以及使用方法。或者参考下述可选步骤,从源码中构建用于编译PaddlePaddle的Docker镜像。 如果您选择不使用Docker镜像,则需要在本机安装下面章节列出的 `编译依赖`_ 之后才能开始编译的步骤。 diff --git a/doc/v2/build_and_install/build_from_source_en.rst b/doc/v2/build_and_install/build_from_source_en.rst index 8fef9e7347e8d924026999bfda985381750c6b51..d1b5b88dff81d4c5cee3dd13a7dccbc333ab6a17 100644 --- a/doc/v2/build_and_install/build_from_source_en.rst +++ b/doc/v2/build_and_install/build_from_source_en.rst @@ -22,6 +22,8 @@ How To Build You need to use Docker to build PaddlePaddle to avoid installing dependencies by yourself. We have several pre-built Docker images `here `_ , +you can also find how to build and use paddle_manylinux_devel Docker image from +`here `_ Or you can build your own image from source as the optional step below: .. 
code-block:: bash diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index d373c48b1a75c5f75c7520b56f230bc2c146b174..a4eb6f706edab9479cbce436311eb96da8845646 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -192,6 +192,10 @@ class ExecutionContext { return op_.Attr(name); } + bool HasInput(const std::string& name) const { return op_.HasInputs(name); } + + bool HasOutput(const std::string& name) const { return op_.HasOutputs(name); } + size_t InputSize(const std::string& name) const { return op_.Inputs(name).size(); } diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 20ef7e09f630140c44774147aa727780df6333fa..95e807c0afa45bc4f4feb84d450b2d0584bc3b28 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -58,7 +58,8 @@ ParallelExecutor::ParallelExecutor( const std::unordered_set &bcast_vars, const ProgramDesc &main_program, const std::string &loss_var_name, Scope *scope, const std::vector &local_scopes, bool allow_op_delay, - bool use_default_grad_scale, bool balance_parameter_opt_between_cards) + bool use_default_grad_scale, bool balance_parameter_opt_between_cards, + size_t num_trainers, size_t trainer_id) : member_(new ParallelExecutorPrivate(places)) { member_->global_scope_ = scope; @@ -80,7 +81,13 @@ ParallelExecutor::ParallelExecutor( // Bcast Parameters to all GPUs #ifdef PADDLE_WITH_CUDA - member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_)); + auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME); + ncclUniqueId *nccl_id = nullptr; + if (nccl_id_var != nullptr) { + nccl_id = nccl_id_var->GetMutable(); + } + member_->nccl_ctxs_.reset(new platform::NCCLContextMap( + member_->places_, nccl_id, num_trainers, trainer_id)); #endif if (platform::is_gpu_place(places[0]) && member_->local_scopes_.size() != 1 && local_scopes.empty()) { // Is CUDA diff --git 
a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index b251fc91417a1c00e61e9c3c952460e6268d2819..9e279876cfeef20a1921f8bd1c27046a477b9f56 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -41,7 +41,8 @@ class ParallelExecutor { const std::string& loss_var_name, Scope* scope, const std::vector& local_scopes, bool allow_op_delay, bool use_default_grad_scale, - bool balance_parameter_opt_between_cards); + bool balance_parameter_opt_between_cards, + size_t num_trainers = 1, size_t trainer_id = 0); ~ParallelExecutor(); diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index de7becae4d25d48111fea8d2123bc85aef70230a..47929ef7490e5edb246625cb0b3ba507039df27a 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -1 +1,2 @@ -cc_library(dot SRCS dot.cc) +cc_library(analysis SRCS dot.cc node.cc node.h) +cc_test(test_node SRCS node_tester.cc DEPS analysis) diff --git a/paddle/fluid/inference/analysis/device.h b/paddle/fluid/inference/analysis/device.h new file mode 100644 index 0000000000000000000000000000000000000000..4423af842d28566fea419b8099efc3bda33787f4 --- /dev/null +++ b/paddle/fluid/inference/analysis/device.h @@ -0,0 +1,23 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +namespace paddle { +namespace inference { +namespace analysis { + +enum class Device { CPU, GPU }; + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/dot.h b/paddle/fluid/inference/analysis/dot.h index 3359987874f2d74d7e4646baa38790431c4b28fd..4bf1840fdda8508b52d7274a338c5b1c95baf354 100644 --- a/paddle/fluid/inference/analysis/dot.h +++ b/paddle/fluid/inference/analysis/dot.h @@ -21,6 +21,7 @@ #include #include +#include #include #include diff --git a/paddle/fluid/inference/analysis/dot_tester.cc b/paddle/fluid/inference/analysis/dot_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..56ceb9bd5d6f41a601d66f6124fb7b4099c9337e --- /dev/null +++ b/paddle/fluid/inference/analysis/dot_tester.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/inference/analysis/dot.h" + +#include +#include +#include "paddle/fluid/inference/analysis/data_flow_graph.h" + +namespace paddle { +namespace inference { +namespace analysis { + +class DotTester : public ::testing::Test { + protected: + void SetUp() override { + std::vector attrs({{"title", "hello"}}); + dot.reset(new Dot(attrs)); + dot->AddNode("a", {Dot::Attr{"shape", "box"}, Dot::Attr("color", "blue")}); + dot->AddNode("b", {}); + dot->AddNode("c", {}); + dot->AddEdge("a", "b", {}); + dot->AddEdge("b", "c", {}); + dot->AddEdge("a", "c", {}); + } + + std::unique_ptr dot; +}; + +TEST_F(DotTester, Build) { + auto codes = dot->Build(); + // Output the DOT language code, the generated codes are too long to compare + // the string. + // + // The output is + // + // digraph G { + // title="hello" + // node_1 + // node_2 + // node_0[label="a" shape="box" color="blue"] + // node_0->node_1 + // node_1->node_2 + // node_0->node_2 + // } // end G + LOG(INFO) << '\n' << codes; +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h new file mode 100644 index 0000000000000000000000000000000000000000..b2d06c5d63ff139186710cd963e07b4ba245f9f3 --- /dev/null +++ b/paddle/fluid/inference/analysis/helper.h @@ -0,0 +1,74 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace inference { +namespace analysis { + +template +class iterator_range { + IteratorT begin_, end_; + + public: + template + explicit iterator_range(Container &&c) : begin_(c.begin()), end_(c.end()) {} + + iterator_range(const IteratorT &begin, const IteratorT &end) + : begin_(begin), end_(end) {} + + const IteratorT &begin() const { return begin_; } + const IteratorT &end() const { return end_; } +}; + +/* + * A registry helper class that keeps its records in registration order. + */ +template +class OrderedRegistry { + public: + T *Register(const std::string &name, T *x) { + PADDLE_ENFORCE(!dic_.count(name)); + dic_[name] = data_.size(); + data_.emplace_back(std::unique_ptr(x)); + return data_.back().get(); + } + + T *Lookup(const std::string &name) { + auto it = dic_.find(name); + if (it == dic_.end()) return nullptr; + return data_[it->second].get(); + } + + protected: + std::unordered_map dic_; + std::vector> data_; +}; + +} // namespace analysis +} // namespace inference +} // namespace paddle + +#define PADDLE_DISALLOW_COPY_AND_ASSIGN(type__) \ + \ + type__(const type__ &) = delete; \ + \ + void operator=(const type__ &) = delete; diff --git a/paddle/fluid/inference/analysis/node.cc b/paddle/fluid/inference/analysis/node.cc new file mode 100644 index 0000000000000000000000000000000000000000..fe060526080b1ee01aa98f2ff06fb2191eddf9da --- /dev/null +++ b/paddle/fluid/inference/analysis/node.cc @@ -0,0 +1,67 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/analysis/node.h" +#include "glog/logging.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace inference { +namespace analysis { + +std::vector Value::dot_attrs() const { + return std::vector({Dot::Attr("style", "filled,rounded"), + Dot::Attr("shape", "box"), + Dot::Attr("fillcolor", "red")}); +} + +std::vector Function::dot_attrs() const { + return std::vector({Dot::Attr("style", "filled,rounded"), + Dot::Attr("shape", "diamond"), + Dot::Attr("fillcolor", "yellow")}); +} + +Node *NodeMap::Create(Node::Type type) { + switch (type) { + case Node::Type::kFunction: + nodes_.emplace_back(new Function); + break; + case Node::Type::kValue: + nodes_.emplace_back(new Value); + break; + default: + PADDLE_THROW("Not supported node type."); + } + nodes_.back()->id_ = size() - 1; + return nodes_.back().get(); +} + +Node *NodeMap::GetMutable(size_t id) { + PADDLE_ENFORCE_GT(size(), id); + return nodes_[id].get(); +} + +const Node &NodeMap::Get(size_t id) const { + PADDLE_ENFORCE_GT(size(), id); + return *nodes_[id].get(); +} + +void NodeMap::Delete(size_t id) { + PADDLE_ENFORCE_LT(id, size()); + nodes_[id]->SetDeleted(); +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/node.h b/paddle/fluid/inference/analysis/node.h new file mode 100644 index 0000000000000000000000000000000000000000..59ba977798481684114d1189056be00bbb7777cf --- /dev/null +++ b/paddle/fluid/inference/analysis/node.h @@ -0,0 +1,234 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/* + * This file defines the Node class and its subclasses. A Node is the basic + * analysis element in a computation graph. + * There are basically two kinds of nodes, the function node and value node. + */ +#pragma once + +#include +#include +#include +#include + +#include "paddle/fluid/inference/analysis/device.h" +#include "paddle/fluid/inference/analysis/dot.h" +#include "paddle/fluid/inference/analysis/helper.h" + +namespace paddle { +namespace inference { +namespace analysis { + +class NodeMap; + +/* + * Node Representation. + * + * This is a very important class for analysis. It is the base class of all + * nodes computed by a program that may be used as operands to other nodes. + * Node is the super class of other important classes such as Function and + * Value, some nodes can have a name. + */ +class Node { + public: + // Node type. NOTE new node types should be added here. + enum class Type { kNone = -1, kFunction, kValue, kFunctionBlock }; + + Node() = default; + + struct Attr; + + // Cast to a subclass type, Function for example; the cast is unchecked. + template + Subclass &As() { + return *dynamic_cast(this); + } + + // Formatted representation of this Node. + virtual std::string repr() const { + return name() + "(" + std::to_string(id()) + ")"; + } + + // DOT node representation. One Node type can customize its own node + // representation. 
+ virtual std::vector dot_attrs() const { + return std::vector({Dot::Attr("style", "filled")}); + } + + // Get an additional attribute and convert it to T data type. NOTE this will + // silently create a new attribute if not exists. + Attr &attr(const std::string &name) { return attrs_[name]; } + + int id() const { return id_; } + + bool deleted() const { return deleted_; } + void SetDeleted() { deleted_ = true; } + + void SetName(const std::string &name) { name_ = name; } + const std::string &name() const { return name_; } + + void SetType(Type type) { type_ = type; } + Type type() const { return type_; } + + void *extra_info() const { return extra_info_; } + void SetExtraInfo(void *extra_info) { extra_info_ = extra_info; } + + // Input links. + std::vector inlinks; + // Output links. + std::vector outlinks; + + // A helper class to maintain the status from Pass. + // TODO(superjomn) add a checker here to ensure the T is primary. + struct Attr { + // NOTE T should be a primary type or a struct combined by several primary + // types. + // NOTE the STL containers should not use here. + // Some usages + // Attr attr; + // T data; + // attr.data.assign((char*)data, sizeof(data)); + + bool &Bool() { return As(); } + float &Float() { return As(); } + int32_t &Int32() { return As(); } + int64_t &Int64() { return As(); } + + private: + template + T &As() { + // init storage in the first usage. 
+ if (data_.empty()) { + VLOG(4) << "resize data to " << sizeof(T); + type_hash_ = typeid(T).hash_code(); + data_.resize(sizeof(T)); + } + PADDLE_ENFORCE(type_hash_ == typeid(T).hash_code(), "type not matched"); + PADDLE_ENFORCE_EQ(data_.size(), sizeof(T), "Node attr type recast error"); + return *reinterpret_cast(&data_[0]); + } + + private: + std::string data_; + size_t type_hash_{std::numeric_limits::max()}; + }; + + virtual ~Node() {} + + friend class NodeMap; + + PADDLE_DISALLOW_COPY_AND_ASSIGN(Node); + + protected: + // The id number not the name is a node's unique identifier in the computation + // graph. + int id_{-1}; + std::string name_; + Type type_{Type::kNone}; + // Marks whether this node has been deleted by some pass. + bool deleted_{false}; + + void *extra_info_{nullptr}; + + mutable std::unordered_map attrs_; +}; + +class Function; +/* + * Value represents a value node, it has some attributes including dims, data + * type and so on. + */ +class Value : public Node { + public: + enum class DataType { kInt32, kInt64, kFloat32, kFloat64 }; + using Dims = std::vector; + + void SetDataType(DataType data_type) { data_type_ = data_type; } + DataType data_type() const { return data_type_; } + + void SetDims(const Dims &dims) { dims_ = dims; } + const Dims &dims() const { return dims_; } + + Device device() const { return device_; } + void SetDevice(Device device) { device_ = device; } + + std::vector dot_attrs() const override; + + PADDLE_DISALLOW_COPY_AND_ASSIGN(Value); + + protected: + Value() { SetType(Node::Type::kValue); } + friend class NodeMap; + + private: + DataType data_type_; + Dims dims_; + Device device_; +}; + +/* + * Function represents any kind of executable concepts that takes several Values + * as input, and outputs several Values. + */ +class Function : public Node { + public: + std::vector dot_attrs() const override; + + // Get the operator's type from Desc. + const std::string &func_type() const { return func_type_; } + // Set the operator's type. 
+ void SetFuncType(const std::string &func_type) { func_type_ = func_type; } + + PADDLE_DISALLOW_COPY_AND_ASSIGN(Function); + + protected: + std::string func_type_; + Function() { SetType(Node::Type::kFunction); } + friend class NodeMap; +}; + +/* + * FunctionBlock is a Node that contains a sub-graph multiple Node. + */ +struct FunctionBlock : public Node { + std::string repr() const override { return "block-" + std::to_string(id()); } + std::vector subgraph; +}; + +class NodeMap { + public: + // Create a new node with type. + Node *Create(Node::Type type); + + // Get a node by its id. + Node *GetMutable(size_t id); + + const Node &Get(size_t id) const; + + void Delete(size_t id); + + const std::vector> &nodes() { return nodes_; } + + size_t size() const { return nodes_.size(); } + + private: + std::vector> nodes_; + std::unordered_map map_; +}; + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/node_tester.cc b/paddle/fluid/inference/analysis/node_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..47fea0fdff808c930ca73edb25f5b16fef397e9a --- /dev/null +++ b/paddle/fluid/inference/analysis/node_tester.cc @@ -0,0 +1,34 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#include "paddle/fluid/inference/analysis/node.h" + +#include + +namespace paddle { +namespace inference { +namespace analysis { + +TEST(Node, Attr) { + // Node is an abstract class, use Value instead for they share the same Attr + // logic. + NodeMap nodes; + auto* node = nodes.Create(Node::Type::kValue); + node->attr("v0").Int32() = 2008; + ASSERT_EQ(node->attr("v0").Int32(), 2008); +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/engine.h b/paddle/fluid/inference/engine.h index de0375551e16ec53b90414c7446234fda98bf706..ce2b8161715a3fa2278ce950dbac82c6d0042bef 100644 --- a/paddle/fluid/inference/engine.h +++ b/paddle/fluid/inference/engine.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/framework.pb.h" namespace paddle { @@ -58,8 +59,8 @@ class EngineBase { struct Buffer { void* buffer{nullptr}; // buffer should be allocated only once. - int max_size; // buffer allocated space. - int size; // data size. + size_t max_size; // buffer allocated space. + size_t size; // data size. DeviceType device{DeviceType::UNK}; // tells which device this buffer is on. 
}; diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt index 677b3e04af8e7f5662a15fb32e3b03f45d262733..b52d083f280e5e7713600a7b748dedd37aca0a1e 100644 --- a/paddle/fluid/inference/tensorrt/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt @@ -1,5 +1,4 @@ nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto) nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine) - add_subdirectory(convert) diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 5178c54c08400125d190078dac6c52d021f8488b..4fb4511d99179e4ea14cde66feb13bc9e114581a 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -1,4 +1,4 @@ nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES}) -nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc +nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc io_converter.cc DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine) nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor) diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc index 543784289cfc51048057e467d36fdd1f334eb903..6297051e5a30f1daa512d25d5aa3ab3b2f79f1d1 100644 --- a/paddle/fluid/inference/tensorrt/convert/activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc @@ -21,15 +21,18 @@ namespace tensorrt { class ReluOpConverter : public OpConverter { public: ReluOpConverter() {} - void operator()(const framework::OpDesc& op) override { + void operator()(const framework::proto::OpDesc& op) override { + // Here 
the two nullptr looks strange, that's because the + // framework::OpDesc's constructor is strange. + framework::OpDesc op_desc(op, nullptr, nullptr); LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose " "type is Relu"; const nvinfer1::ITensor* input_tensor = - engine_->GetITensor(op.Input("X")[0]); + engine_->GetITensor(op_desc.Input("X")[0]); nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER( engine_, Activation, *const_cast(input_tensor), nvinfer1::ActivationType::kRELU); - engine_->SetITensor(op.Output("Out")[0], layer->getOutput(0)); + engine_->SetITensor(op_desc.Output("Out")[0], layer->getOutput(0)); } }; diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index 431500b90e144e2a30fe705b72e93452f806ca65..209936c3bafb0d31546856dc36c1b48053a0634b 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -21,7 +21,7 @@ namespace tensorrt { class Conv2dOpConverter : public OpConverter { public: Conv2dOpConverter() {} - void operator()(const framework::OpDesc& op) override { + void operator()(const framework::proto::OpDesc& op) override { LOG(INFO) << "convert a fluid conv2d op to tensorrt conv layer without bias"; } diff --git a/paddle/fluid/inference/tensorrt/convert/io_converter.cc b/paddle/fluid/inference/tensorrt/convert/io_converter.cc index 32e8631fde3f748669d2008b4a060455a37e154e..854f434d93e81237dc85c5df62debcf3b3824b78 100644 --- a/paddle/fluid/inference/tensorrt/convert/io_converter.cc +++ b/paddle/fluid/inference/tensorrt/convert/io_converter.cc @@ -23,26 +23,42 @@ namespace tensorrt { using platform::is_gpu_place; using platform::is_cpu_place; -class DefaultInputConverter : public EngineInputConverter { +class DefaultIOConverter : public EngineIOConverter { public: - DefaultInputConverter() {} + DefaultIOConverter() {} // NOTE out is GPU memory. 
virtual void operator()(const LoDTensor& in, void* out, size_t max_size) override { PADDLE_ENFORCE(out != nullptr); - PADDLE_ENFORCE_LE(in.memory_size(), max_size); + PADDLE_ENFORCE(stream_ != nullptr); const auto& place = in.place(); + size_t size = in.memory_size(); + PADDLE_ENFORCE_LE(size, max_size); if (is_cpu_place(place)) { - PADDLE_ENFORCE(stream_ != nullptr); - PADDLE_ENFORCE_EQ(0, - cudaMemcpyAsync(out, in.data(), in.memory_size(), - cudaMemcpyHostToDevice, *stream_)); - + PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out, in.data(), size, + cudaMemcpyHostToDevice, *stream_)); } else if (is_gpu_place(place)) { - PADDLE_ENFORCE_EQ(0, - cudaMemcpyAsync(out, in.data(), in.memory_size(), - cudaMemcpyHostToHost, *stream_)); - + PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out, in.data(), size, + cudaMemcpyDeviceToDevice, *stream_)); + } else { + PADDLE_THROW("Unknown device for converter"); + } + cudaStreamSynchronize(*stream_); + } + // NOTE in is GPU memory. + virtual void operator()(const void* in, LoDTensor* out, + size_t max_size) override { + PADDLE_ENFORCE(in != nullptr); + PADDLE_ENFORCE(stream_ != nullptr); + const auto& place = out->place(); + size_t size = out->memory_size(); + PADDLE_ENFORCE_LE(size, max_size); + if (is_cpu_place(place)) { + PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out->data(), in, size, + cudaMemcpyDeviceToHost, *stream_)); + } else if (is_gpu_place(place)) { + PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out->data(), in, size, + cudaMemcpyDeviceToDevice, *stream_)); } else { PADDLE_THROW("Unknown device for converter"); } @@ -50,7 +66,8 @@ class DefaultInputConverter : public EngineInputConverter { } }; -REGISTER_TENSORRT_INPUT_CONVERTER(default, DefaultInputConverter); +// fluid LodTensor <-> tensorrt ITensor +REGISTER_TENSORRT_IO_CONVERTER(default, DefaultIOConverter); } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tensorrt/convert/io_converter.h b/paddle/fluid/inference/tensorrt/convert/io_converter.h index 
8972dae92be2c2d261a13c48d98e675f64e51d31..71c48e085d25d2bc6720d93735f661f9e3af7b40 100644 --- a/paddle/fluid/inference/tensorrt/convert/io_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/io_converter.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/inference/utils/singleton.h" @@ -25,43 +26,57 @@ namespace tensorrt { using framework::LoDTensor; /* - * Convert Input from Fluid to an Engine. - * TensorRT's ITensor follows row major, NCHW. Fluid is also row major, so in - * most cases just need to copy the data. + * Convert Input from Fluid to TensorRT Engine. + * Convert Output from TensorRT Engine to Fluid. + * + * Note that TensorRT's ITensor follows row major, NCHW. Fluid is also row + * major, + * so in the default case just need to copy the data. */ -class EngineInputConverter { +class EngineIOConverter { public: - EngineInputConverter() {} + EngineIOConverter() {} virtual void operator()(const LoDTensor& in, void* out, size_t max_size) {} + virtual void operator()(const void* in, LoDTensor* out, size_t max_size) {} void SetStream(cudaStream_t* stream) { stream_ = stream; } - static void Run(const std::string& in_op_type, const LoDTensor& in, void* out, - size_t max_size, cudaStream_t* stream) { + static void ConvertInput(const std::string& op_type, const LoDTensor& in, + void* out, size_t max_size, cudaStream_t* stream) { PADDLE_ENFORCE(stream != nullptr); - auto* converter = Registry::Lookup( - in_op_type, "default" /* default_type */); + auto* converter = Registry::Lookup( + op_type, "default" /* default_type */); PADDLE_ENFORCE_NOT_NULL(converter); converter->SetStream(stream); (*converter)(in, out, max_size); } - virtual ~EngineInputConverter() {} + static void ConvertOutput(const std::string& op_type, const void* in, + LoDTensor* out, size_t max_size, + cudaStream_t* stream) { + PADDLE_ENFORCE(stream != nullptr); + auto* converter = 
Registry::Lookup( + op_type, "default" /* default_type */); + PADDLE_ENFORCE_NOT_NULL(converter); + converter->SetStream(stream); + (*converter)(in, out, max_size); + } + + virtual ~EngineIOConverter() {} protected: cudaStream_t* stream_{nullptr}; }; +#define REGISTER_TENSORRT_IO_CONVERTER(op_type__, Converter__) \ + struct trt_io_##op_type__##_converter { \ + trt_io_##op_type__##_converter() { \ + Registry::Register(#op_type__); \ + } \ + }; \ + trt_io_##op_type__##_converter trt_io_##op_type__##_converter__; + } // namespace tensorrt } // namespace inference } // namespace paddle - -#define REGISTER_TENSORRT_INPUT_CONVERTER(in_op_type__, Converter__) \ - struct trt_input_##in_op_type__##_converter { \ - trt_input_##in_op_type__##_converter() { \ - ::paddle::inference::Registry::Register< \ - Converter__>(#in_op_type__); \ - } \ - }; \ - trt_input_##in_op_type__##_converter trt_input_##in_op_type__##_converter__; diff --git a/paddle/fluid/inference/tensorrt/convert/mul_op.cc b/paddle/fluid/inference/tensorrt/convert/mul_op.cc index f9834ab156c9dcc11f4e89075b7bf5457cf00268..3ca58b139bd3af1947ae7f063060e11d2ea7d577 100644 --- a/paddle/fluid/inference/tensorrt/convert/mul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/mul_op.cc @@ -21,7 +21,7 @@ namespace tensorrt { class MulOpConverter : public OpConverter { public: MulOpConverter() {} - void operator()(const framework::OpDesc& op) override { + void operator()(const framework::proto::OpDesc& op) override { LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias"; } }; diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index 77c788550b2c7df1f483b926661789b2a54d8fff..abc9ebf472498f6653d5bb1113ae2f3ce7e5a923 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -31,10 +31,10 @@ namespace tensorrt { class OpConverter { public: OpConverter() {} - 
virtual void operator()(const framework::OpDesc& op) {} + virtual void operator()(const framework::proto::OpDesc& op) {} - void Run(const framework::OpDesc& op, TensorRTEngine* engine) { - std::string type = op.Type(); + void Run(const framework::proto::OpDesc& op, TensorRTEngine* engine) { + std::string type = op.type(); auto* it = Registry::Lookup(type); PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type); it->SetEngine(engine); @@ -42,14 +42,16 @@ class OpConverter { } // convert fluid op to tensorrt layer - void ConvertOp(const framework::OpDesc& op, TensorRTEngine* engine) { + void ConvertOp(const framework::proto::OpDesc& op, TensorRTEngine* engine) { OpConverter::Run(op, engine); } // convert fluid block to tensorrt network - void ConvertBlock(const framework::BlockDesc& block, TensorRTEngine* engine) { - for (auto op : block.AllOps()) { - OpConverter::Run(*op, engine); + void ConvertBlock(const framework::proto::BlockDesc& block, + TensorRTEngine* engine) { + for (size_t i = 0; i < block.ops_size(); i++) { + const auto& op = block.ops(i); + OpConverter::Run(op, engine); } } diff --git a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc index 23e3435c21725328d3765fae0d158a83ac21478b..ec33f97c8240dfc09a203d68599bffe78a4abb12 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc @@ -16,6 +16,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/inference/tensorrt/convert/io_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" @@ -26,7 +27,7 @@ namespace paddle { namespace inference { namespace tensorrt { -void Compare(float input, float expect) { +void Compare(const std::string op_type, float input, float expect) { framework::Scope scope; platform::CUDAPlace place; platform::CUDADeviceContext ctx(place); @@ -35,6 +36,7 @@ void Compare(float input, float expect) { auto x_var = scope.Var("X"); auto x_tensor = x_var->GetMutable(); x_tensor->Resize({1, 1}); + x_tensor->mutable_data(place); std::vector init; init.push_back(input); framework::TensorFromVector(init, ctx, x_tensor); @@ -45,14 +47,15 @@ void Compare(float input, float expect) { out_tensor->mutable_data(place); framework::OpDesc op_desc; - op_desc.SetType("relu"); + op_desc.SetType(op_type); op_desc.SetInput("X", {"X"}); op_desc.SetOutput("Out", {"Out"}); - auto relu_op = framework::OpRegistry::CreateOp(op_desc); + auto op = framework::OpRegistry::CreateOp(*op_desc.Proto()); // run fluid op - relu_op->Run(scope, place); + op->Run(scope, place); + // get fluid output std::vector out1; framework::TensorToVector(*out_tensor, ctx, &out1); @@ -63,21 +66,28 @@ void Compare(float input, float expect) { engine->InitNetwork(); engine->DeclareInput("X", nvinfer1::DataType::kFLOAT, nvinfer1::DimsCHW{1, 1, 1}); - + // convert op OpConverter op_converter; - op_converter.ConvertOp(op_desc, engine); + op_converter.ConvertOp(*op_desc.Proto(), engine); engine->DeclareOutput("Out"); engine->FreezeNetwork(); - engine->SetInputFromCPU("X", &input, 1 * sizeof(float)); - // run tensorrt op + // convert LoDTensor to ITensor + size_t size = x_tensor->memory_size(); + 
EngineIOConverter::ConvertInput(op_type, *x_tensor, + engine->buffer("X").buffer, size, &stream); + // run tensorrt Outp engine->Execute(1); - - float out2; - engine->GetOutputInCPU("Out", &out2, 1 * sizeof(float)); - - ASSERT_EQ(out1[0], out2); + // convert ITensor to LoDTensor + EngineIOConverter::ConvertOutput(op_type, engine->buffer("Out").buffer, + out_tensor, size, &stream); + // get tensorrt output + std::vector out2; + framework::TensorToVector(*out_tensor, ctx, &out2); + + // compare + ASSERT_EQ(out1[0], out2[0]); ASSERT_EQ(out1[0], expect); delete engine; @@ -85,8 +95,8 @@ void Compare(float input, float expect) { } TEST(OpConverter, ConvertRelu) { - Compare(1, 1); // relu(1) = 1 - Compare(-5, 0); // relu(-5) = 0 + Compare("relu", 1, 1); // relu(1) = 1 + Compare("relu", -5, 0); // relu(-5) = 0 } } // namespace tensorrt diff --git a/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc index afcc516e6b76d58e37ce0e60746704cf3933fac7..8f91309a0a00d5131268f026c319e25ba3cb964a 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc @@ -12,40 +12,63 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/inference/tensorrt/convert/io_converter.h" -#include - namespace paddle { namespace inference { namespace tensorrt { -class EngineInputConverterTester : public ::testing::Test { - public: - void SetUp() override { tensor.Resize({10, 10}); } +void IOConverterTester(const platform::DeviceContext& ctx) { + cudaStream_t stream; + ASSERT_EQ(0, cudaStreamCreate(&stream)); - framework::LoDTensor tensor; -}; + // init fluid in_tensor + framework::LoDTensor in_tensor; + in_tensor.Resize({10, 10}); + auto place = ctx.GetPlace(); + in_tensor.mutable_data(place); + std::vector init; + for (int64_t i = 0; i < 10 * 10; ++i) { + init.push_back(i); + } + framework::TensorFromVector(init, ctx, &in_tensor); -TEST_F(EngineInputConverterTester, DefaultCPU) { + // init tensorrt buffer void* buffer; - tensor.mutable_data(platform::CPUPlace()); - ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0); + size_t size = in_tensor.memory_size(); + ASSERT_EQ(cudaMalloc(&buffer, size), 0); - cudaStream_t stream; - EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(), - &stream); + // convert fluid in_tensor to tensorrt buffer + EngineIOConverter::ConvertInput("test", in_tensor, buffer, size, &stream); + + // convert tensorrt buffer to fluid out_tensor + framework::LoDTensor out_tensor; + out_tensor.Resize({10, 10}); + out_tensor.mutable_data(place); + EngineIOConverter::ConvertOutput("test", buffer, &out_tensor, size, &stream); + + // compare in_tensor and out_tensor + std::vector result; + framework::TensorToVector(out_tensor, ctx, &result); + EXPECT_EQ(init.size(), result.size()); + for (size_t i = 0; i < init.size(); i++) { + EXPECT_EQ(init[i], result[i]); + } + cudaStreamDestroy(stream); } -TEST_F(EngineInputConverterTester, DefaultGPU) { - void* buffer; - tensor.mutable_data(platform::CUDAPlace()); - ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0); +TEST(EngineIOConverterTester, 
DefaultCPU) { + platform::CPUPlace place; + platform::CPUDeviceContext ctx(place); + IOConverterTester(ctx); +} - cudaStream_t stream; - EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(), - &stream); +TEST(EngineIOConverterTester, DefaultGPU) { + platform::CUDAPlace place; + platform::CUDADeviceContext ctx(place); + IOConverterTester(ctx); } } // namespace tensorrt diff --git a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc index aa5fb726f1129eda65a6f39791330b795aad660d..8d66543eb7637c5a8ae670b89ef5996954ba2e7b 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc @@ -29,7 +29,7 @@ TEST(OpConverter, ConvertBlock) { conv2d_op->SetType("conv2d"); OpConverter converter; - converter.ConvertBlock(*block, nullptr /*TensorRTEngine*/); + converter.ConvertBlock(*block->Proto(), nullptr /*TensorRTEngine*/); } } // namespace tensorrt diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc index c4fd1e298b0daea85db2a407d04ad2d7bcdee0f0..60c761c5281e2f535aab0200c93fb738addcdb87 100644 --- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc +++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc @@ -16,7 +16,6 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/inference/tests/test_helper.h" -DEFINE_string(data_set, "cifar10", "Data set to test"); DEFINE_string(dirname, "", "Directory of the inference model."); DEFINE_string(fp16_dirname, "", "Directory of the float16 inference model."); DEFINE_int32(batch_size, 1, "Batch size of input data"); @@ -35,19 +34,19 @@ TEST(inference, image_classification) { // 0. 
Call `paddle::framework::InitDevices()` initialize all the devices // In unittests, this is done in paddle/testing/paddle_gtest_main.cc + const bool is_combined = false; + std::vector> feed_target_shapes = + GetFeedTargetShapes(dirname, is_combined); + paddle::framework::LoDTensor input; // Use normilized image pixels as input data, // which should be in the range [0.0, 1.0]. - if (FLAGS_data_set == "cifar10") { - SetupTensor(&input, {FLAGS_batch_size, 3, 32, 32}, - static_cast(0), static_cast(1)); - } else if (FLAGS_data_set == "imagenet") { - SetupTensor(&input, {FLAGS_batch_size, 3, 224, 224}, - static_cast(0), static_cast(1)); - } else { - LOG(FATAL) << "Only cifar10 or imagenet is supported."; - } - + feed_target_shapes[0][0] = FLAGS_batch_size; + paddle::framework::DDim input_dims = + paddle::framework::make_ddim(feed_target_shapes[0]); + LOG(INFO) << input_dims; + SetupTensor(&input, input_dims, static_cast(0), + static_cast(1)); std::vector cpu_feeds; cpu_feeds.push_back(&input); @@ -60,7 +59,7 @@ TEST(inference, image_classification) { LOG(INFO) << "--- CPU Runs: ---"; LOG(INFO) << "Batch size is " << FLAGS_batch_size; TestInference( - dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat); + dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined); LOG(INFO) << output1.dims(); } @@ -73,7 +72,7 @@ TEST(inference, image_classification) { LOG(INFO) << "--- GPU Runs: ---"; LOG(INFO) << "Batch size is " << FLAGS_batch_size; TestInference( - dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat); + dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat, is_combined); LOG(INFO) << output2.dims(); if (!FLAGS_skip_cpu) { diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index af2a7a5620487a10c1df6152fc4e4bf67b150752..b02e5c99f00eaf03c3753e43575cbc67e834774e 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -89,6 +89,50 @@ void CheckError(const paddle::framework::LoDTensor& 
output1, EXPECT_EQ(count, 0U) << "There are " << count << " different elements."; } +std::unique_ptr InitProgram( + paddle::framework::Executor* executor, paddle::framework::Scope* scope, + const std::string& dirname, const bool is_combined = false) { + std::unique_ptr inference_program; + if (is_combined) { + // All parameters are saved in a single file. + // Hard-coding the file names of program and parameters in unittest. + // The file names should be consistent with that used in Python API + // `fluid.io.save_inference_model`. + std::string prog_filename = "__model_combined__"; + std::string param_filename = "__params_combined__"; + inference_program = + paddle::inference::Load(executor, scope, dirname + "/" + prog_filename, + dirname + "/" + param_filename); + } else { + // Parameters are saved in separate files sited in the specified + // `dirname`. + inference_program = paddle::inference::Load(executor, scope, dirname); + } + return inference_program; +} + +std::vector> GetFeedTargetShapes( + const std::string& dirname, const bool is_combined = false) { + auto place = paddle::platform::CPUPlace(); + auto executor = paddle::framework::Executor(place); + auto* scope = new paddle::framework::Scope(); + + auto inference_program = InitProgram(&executor, scope, dirname, is_combined); + auto& global_block = inference_program->Block(0); + + const std::vector& feed_target_names = + inference_program->GetFeedTargetNames(); + std::vector> feed_target_shapes; + for (size_t i = 0; i < feed_target_names.size(); ++i) { + auto* var = global_block.FindVar(feed_target_names[i]); + std::vector var_shape = var->GetShape(); + feed_target_shapes.push_back(var_shape); + } + + delete scope; + return feed_target_shapes; +} + template void TestInference(const std::string& dirname, const std::vector& cpu_feeds, @@ -124,22 +168,7 @@ void TestInference(const std::string& dirname, paddle::platform::RecordEvent record_event( "init_program", 
paddle::platform::DeviceContextPool::Instance().Get(place)); - - if (is_combined) { - // All parameters are saved in a single file. - // Hard-coding the file names of program and parameters in unittest. - // The file names should be consistent with that used in Python API - // `fluid.io.save_inference_model`. - std::string prog_filename = "__model_combined__"; - std::string param_filename = "__params_combined__"; - inference_program = paddle::inference::Load( - &executor, scope, dirname + "/" + prog_filename, - dirname + "/" + param_filename); - } else { - // Parameters are saved in separate files sited in the specified - // `dirname`. - inference_program = paddle::inference::Load(&executor, scope, dirname); - } + inference_program = InitProgram(&executor, scope, dirname, is_combined); } // Disable the profiler and print the timing information paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault, diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index c14a2b7786f9f7c06d59479d3bbce9c5d542e495..d38a9ce58726a1d045d6905354b0b592166c0110 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -186,6 +186,11 @@ endif() add_subdirectory(detail) if(WITH_DISTRIBUTE) + if(WITH_GPU) + op_library(gen_nccl_id_op DEPS nccl_common) + else() + set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op) + endif() set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") op_library(send_op DEPS ${DISTRIBUTE_DEPS}) @@ -202,8 +207,9 @@ if(WITH_DISTRIBUTE) set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op 
sum_op executor) + cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS send_op listen_and_serv_op executor) else() - set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op) + set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op) endif() op_library(cross_entropy_op DEPS cross_entropy) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 661dfa69fe1580ff3890f12defcd124225be0c06..ae60ab15325ef101feb7270a4f5d840cb2112be0 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -52,7 +52,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, // stub context SendProcessor* s = new SendProcessor(ch); s->Prepare(var_h, time_out); - s->response_call_back_ = NULL; + s->response_call_back_ = nullptr; auto call = s->stub_g_.PrepareUnaryCall( s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_); diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index f6229b71bc01a6de51f50f5fe880ada6e15e74dd..dabce7414d2f0dca74193f1cd10c341793c10ec9 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -57,7 +57,9 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); class BaseProcessor { public: - explicit BaseProcessor(std::shared_ptr ch) { context_ = NULL; } + explicit BaseProcessor(std::shared_ptr ch) { + context_ = nullptr; + } virtual ~BaseProcessor() {} @@ -105,7 +107,7 @@ class SendProcessor : public BaseProcessor { ::grpc::GenericStub stub_g_; ::grpc::ByteBuffer reply_; - RequestSendCallBack response_call_back_ = NULL; + RequestSendCallBack response_call_back_ = nullptr; }; typedef std::function diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index 
7f9cae21ccca8dd51f9fbe98148d01a51ac6eb84..18f1bc53d0f561f412a5bbbe018bc3d427ac9ef9 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -47,6 +47,7 @@ class AsyncGRPCServer final { explicit AsyncGRPCServer(const std::string &address, bool sync_mode) : address_(address), sync_mode_(sync_mode), ready_(0) {} + ~AsyncGRPCServer() {} void WaitServerReady(); void RunSyncUpdate(); diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index fffa9ae7a43ea5cd7b2bda6fbbf6ef9f7d23009d..9478c5702bcbf99fc88207b8c4843dbccf8a5925 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -32,6 +32,7 @@ service SendRecvService { enum VarType { LOD_TENSOR = 0; SELECTED_ROWS = 1; + NCCL_ID = 2; } // NOTICE(gongwb):don't modify this proto if you are not diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 1a8a1af20fa446dbd537944409ef0ca1e3e9116f..07c43554bc6a0d71d688a5a5772d0ab3d2de319a 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -14,6 +14,9 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#ifdef PADDLE_WITH_CUDA +#include +#endif #include #include // NOLINT @@ -129,6 +132,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, } else if (var->IsType()) { request.set_type(::sendrecv::SELECTED_ROWS); GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size); +#ifdef PADDLE_WITH_CUDA + } else if (var->IsType()) { + request.set_type(::sendrecv::NCCL_ID); +#endif } else { PADDLE_THROW("Serialize does not support type: %s", typeid(var->Type()).name()); @@ -149,6 +156,24 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void* buf = buffer.get(); ProtoEncodeHelper e(static_cast(buf), 1024); e.WriteRawBytes(std::string(header.data(), header.size())); +// NCCLID is copied directly to the message, return bytebuffer +// with only one slice if serializing NCCLID. +#ifdef PADDLE_WITH_CUDA + if (var->IsType()) { + e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, + NCCL_UNIQUE_ID_BYTES); + const ncclUniqueId& uid = var->Get(); + e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES)); + + // for serialize NCCL_ID + ::grpc::Slice slices(e.size()); + memcpy(const_cast(slices.begin()), e.data(), e.size()); + ::grpc::ByteBuffer tmp(&slices, 1); + msg->Swap(&tmp); + return; + } +#endif + e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); // steal reference of tensor data ::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index 99602a05d023f30c2eed8df25e7534fdc9ef2ced..462e303096e609c6797ca8cc16266ec3621623fc 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -17,6 +17,9 @@ #include #include #include +#ifdef PADDLE_WITH_CUDA +#include +#endif #include "paddle/fluid/platform/profiler.h" #include 
"paddle/fluid/operators/detail/send_recv.pb.h" @@ -368,7 +371,8 @@ int VariableResponse::Parse(Source* source) { } case sendrecv::VariableMessage::kSerializedFieldNumber: { PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || - meta_.type() == sendrecv::LOD_TENSOR) && + meta_.type() == sendrecv::LOD_TENSOR || + meta_.type() == sendrecv::NCCL_ID) && meta_.varname() != "", "meta info should be got first!"); @@ -378,6 +382,22 @@ int VariableResponse::Parse(Source* source) { return tag; } + if (meta_.type() == sendrecv::NCCL_ID) { +#ifdef PADDLE_WITH_CUDA + auto* var = scope_->FindVar(meta_.varname()); + if (var != nullptr) { + ncclUniqueId* id = var->GetMutable(); + if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal, + num_bytes)) { + return tag; + } + } + break; +#else + PADDLE_THROW("Not compiled with CUDA!"); +#endif + } + framework::DDim dims = GetDims(meta_.dims()); if (meta_.type() == sendrecv::LOD_TENSOR) { PADDLE_ENFORCE(meta_.lod_size() >= 0, diff --git a/paddle/fluid/operators/gen_nccl_id_op.cc b/paddle/fluid/operators/gen_nccl_id_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a5678f63466d368b3dd59380c18f9625cabd368b --- /dev/null +++ b/paddle/fluid/operators/gen_nccl_id_op.cc @@ -0,0 +1,128 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include +#include +#include + +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/operators/detail/grpc_server.h" +#include "paddle/fluid/platform/nccl_helper.h" + +namespace paddle { +namespace operators { + +class GenNCCLIdOp : public framework::OperatorBase { + public: + GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + // put nccl id in CPUPlace + auto& dev_ctx = *pool.Get(platform::CPUPlace()); + int trainer_id = Attr("trainer_id"); + framework::Scope& local_scope = scope.NewScope(); + + if (trainer_id == 0) { + GenerateAndSend(&local_scope, dev_ctx); + } else { + GetIdByServer(&local_scope, dev_ctx); + } + } + + private: + void GenerateAndSend(framework::Scope* scope, + const platform::DeviceContext& dev_ctx) const { + auto var = scope->FindVar(NCCL_ID_VARNAME); + PADDLE_ENFORCE_NOT_NULL(var); + auto id = var->GetMutable(); + PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id)); + + std::vector endpoint_list = + Attr>("endpoint_list"); + detail::RPCClient client; + for (auto& ep : endpoint_list) { + VLOG(3) << "sending nccl id to " << ep; + client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME); + } + client.Wait(); + VLOG(3) << "sending completed..."; + } + + void GetIdByServer(framework::Scope* scope, + const platform::DeviceContext& dev_ctx) const { + std::string endpoint = Attr("endpoint"); + // NOTE: Can not use unique_ptr here because the default + // deleter will call 
GRPC Server's base class's dtor and + // that will cause a wired crash. + detail::AsyncGRPCServer rpc_service(endpoint, true); + framework::ProgramDesc empty_program; + framework::Executor executor(dev_ctx.GetPlace()); + rpc_service.SetScope(scope); + rpc_service.SetDevCtx(&dev_ctx); + rpc_service.SetProgram(&empty_program); + rpc_service.SetExecutor(&executor); + + std::thread server_thread( + std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service)); + rpc_service.SetCond(0); + VLOG(3) << "start getting nccl id from trainer 0..."; + auto recv = rpc_service.Get(); + VLOG(3) << "got nccl id and stop server..."; + rpc_service.ShutDown(); + VLOG(3) << "rpc server stopped"; + server_thread.join(); + } +}; + +class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces."); + AddComment(R"DOC( +GenNCCLId operator + +For trainer 0: generate a new UniqueId and send it to all the other trainers. +For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server. +)DOC"); + AddAttr("endpoint", + "(string), e.g. 127.0.0.1:6175 " + "current listen endpoint"); + AddAttr>( + "endpoint_list", + "['trainer1_ip:port', 'trainer2_ip:port', ...] 
" + "list of trainer endpoints start from trainer 1") + .SetDefault({}); + AddAttr("trainer_id", + "(int default 0) " + "The index of the trainer in distributed training.") + .SetDefault(0); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(gen_nccl_id, ops::GenNCCLIdOp, ops::GenNCCLIdOpMaker); diff --git a/paddle/fluid/operators/math/sequence2batch.h b/paddle/fluid/operators/math/sequence2batch.h index 0abda999a52bcbb94e6503692bd11aff26e849ba..62e6307ae9f4236a38c49daaf09fc05c54268159 100644 --- a/paddle/fluid/operators/math/sequence2batch.h +++ b/paddle/fluid/operators/math/sequence2batch.h @@ -64,18 +64,22 @@ class LoDTensor2BatchFunctor { bool is_reverse = false) const { if (!is_cal_batch_lod) { auto lods = batch->lod(); - PADDLE_ENFORCE_GT(lods.size(), 2UL); - PADDLE_ENFORCE_EQ(lods[1].size(), - static_cast(lod_tensor.dims()[0])); + PADDLE_ENFORCE_GT(lods.size(), 2UL, + "The LoD of LoDTensor should inlcude at least 2-level " + "sequence information."); + PADDLE_ENFORCE_EQ( + lods[1].size(), static_cast(lod_tensor.dims()[0]), + "The LoD information should be consistent with the dims."); CopyMatrixRowsFunctor to_batch; to_batch(context, lod_tensor, lods[1], batch, true); return; } auto lods = lod_tensor.lod(); - auto lod = lods[0]; PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now."); + auto lod = lods[0]; + std::vector seq_info; for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) { int length = lod[seq_id + 1] - lod[seq_id]; @@ -157,9 +161,12 @@ class Batch2LoDTensorFunctor { const framework::LoDTensor& batch, framework::LoDTensor* lod_tensor) const { auto in_lod = batch.lod(); - PADDLE_ENFORCE_GT(in_lod.size(), 2UL); - PADDLE_ENFORCE_EQ(in_lod[1].size(), - static_cast(lod_tensor->dims()[0])); + PADDLE_ENFORCE_GT(in_lod.size(), 2UL, + "The LoD of LoDTensor should inlcude at least 2-level " + "sequence information."); + PADDLE_ENFORCE_EQ( + in_lod[1].size(), 
static_cast(lod_tensor->dims()[0]), + "The LoD information should be consistent with the dims."); CopyMatrixRowsFunctor to_seq; to_seq(context, batch, in_lod[1], lod_tensor, false); } diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h index ccd7063fe69e0f21b4d2a821bb70902b39c9b9de..3dd8c7c11eca241e747bfa129962032d882ce44c 100644 --- a/paddle/fluid/operators/reshape_op.h +++ b/paddle/fluid/operators/reshape_op.h @@ -92,14 +92,16 @@ class ReshapeOp : public framework::OperatorWithKernel { } if (unk_dim_idx != -1) { - output_shape[unk_dim_idx] = -in_size / capacity; - // in_size < 0 and is un-determinate in compile time, skip the check, - // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], - // capacity = -24, in_size = -8, output_shape[0] = 0 - // the following check will fail. if (in_size > 0) { + // in_size < 0 and is un-determinate in compile time, skip the check, + // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], + // capacity = -24, in_size = -8, output_shape[0] = 0 + // the following check will fail. + output_shape[unk_dim_idx] = -in_size / capacity; PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size, "Invalid shape is given."); + } else { + output_shape[unk_dim_idx] = -1; } } else { PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given."); @@ -122,7 +124,10 @@ class ReshapeKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext &ctx) const { auto *out = ctx.Output("Out"); auto *in = ctx.Input("X"); - auto *shape_tensor = ctx.Input("Shape"); + + auto *shape_tensor = ctx.HasInput("Shape") + ? 
ctx.Input("Shape") + : nullptr; framework::DDim out_dims = out->dims(); diff --git a/paddle/fluid/operators/test_send_nccl_id.cc b/paddle/fluid/operators/test_send_nccl_id.cc new file mode 100644 index 0000000000000000000000000000000000000000..bbae1d54aa3524fd45cb8ab13c86df8d54b8e643 --- /dev/null +++ b/paddle/fluid/operators/test_send_nccl_id.cc @@ -0,0 +1,94 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include // NOLINT + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/operators/listen_and_serv_op.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/nccl_helper.h" +#include "paddle/fluid/string/printf.h" + +USE_NO_KERNEL_OP(listen_and_serv); + +namespace f = paddle::framework; +namespace p = paddle::platform; +namespace m = paddle::operators::math; +namespace detail = paddle::operators::detail; +namespace string = paddle::string; + +std::unique_ptr rpc_service; + +void StartServer(std::atomic* initialized) { + f::Scope scope; + p::CPUPlace place; + scope.Var(NCCL_ID_VARNAME); + p::DeviceContextPool& pool = p::DeviceContextPool::Instance(); + auto& dev_ctx = *pool.Get(p::CPUPlace()); + + 
rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", true)); + + f::ProgramDesc empty_program; + f::Executor executor(dev_ctx.GetPlace()); + rpc_service->SetScope(&scope); + rpc_service->SetDevCtx(&dev_ctx); + rpc_service->SetProgram(&empty_program); + rpc_service->SetExecutor(&executor); + + std::thread server_thread( + std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service.get())); + *initialized = true; + rpc_service->SetCond(0); + auto recv = rpc_service->Get(); + LOG(INFO) << "got nccl id and stop server..."; + rpc_service->ShutDown(); + server_thread.join(); +} + +TEST(SendNcclId, Normal) { + std::atomic initialized{false}; + std::thread server_thread(StartServer, &initialized); + while (!initialized) { + } + // wait server to start + // sleep(2); + rpc_service->WaitServerReady(); + + f::Scope scope; + p::CPUPlace place; + p::DeviceContextPool& pool = p::DeviceContextPool::Instance(); + auto& dev_ctx = *pool.Get(p::CPUPlace()); + + auto var = scope.Var(NCCL_ID_VARNAME); + // var->SetType(f::proto::VarType_Type_RAW); + auto id = var->GetMutable(); + p::dynload::ncclGetUniqueId(id); + + int port = rpc_service->GetSelectedPort(); + std::string ep = string::Sprintf("127.0.0.1:%d", port); + detail::RPCClient client; + + client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME); + client.Wait(); + server_thread.join(); + auto* ptr = rpc_service.release(); + delete ptr; +} diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index 0013597fd516d15c7d502370eec77e1a6a5dca88..e30c1a9ebf08365a9856fb32b1ce5790869e2b33 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -14,12 +14,15 @@ #pragma once +#include #include // NOLINT #include #include #include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/enforce.h" +#define NCCL_ID_VARNAME "NCCLID" + namespace paddle { namespace platform { @@ -73,7 +76,9 @@ struct NCCLContextMap { std::unordered_map 
contexts_; std::vector order_; - explicit NCCLContextMap(const std::vector &places) { + explicit NCCLContextMap(const std::vector &places, + ncclUniqueId *nccl_id = nullptr, + size_t num_trainers = 1, size_t trainer_id = 0) { PADDLE_ENFORCE(!places.empty()); order_.reserve(places.size()); for (auto &p : places) { @@ -85,18 +90,34 @@ struct NCCLContextMap { order_.size(), contexts_.size(), "NCCL Context Map does not support contain two or more same device"); - if (places.size() > 1) { - std::unique_ptr comms(new ncclComm_t[order_.size()]); + if (places.size() <= 1) { + return; + } + std::unique_ptr comms(new ncclComm_t[order_.size()]); + // if pass nccl_id here, can assume we are doing multi node training + if (nccl_id == nullptr) { + std::lock_guard guard(NCCLGroupGuard::NCCLMutex()); + PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( + comms.get(), static_cast(order_.size()), order_.data())); + } else { + PADDLE_ENFORCE_GT(num_trainers, 1); + // TODO(wuyi): need to ensure each node have same number of GPUs { - std::lock_guard guard(NCCLGroupGuard::NCCLMutex()); - PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( - comms.get(), static_cast(order_.size()), order_.data())); - } - int i = 0; - for (auto &dev_id : order_) { - contexts_.at(dev_id).comm_ = comms[i++]; + int nranks = num_trainers * order_.size(); + NCCLGroupGuard gurad; + for (auto &gpu_id : order_) { + int rank = trainer_id * order_.size() + gpu_id; + VLOG(3) << "init nccl rank: " << rank << " nranks: " << nranks; + PADDLE_ENFORCE(cudaSetDevice(gpu_id)); + PADDLE_ENFORCE(platform::dynload::ncclCommInitRank( + comms.get() + gpu_id, nranks, *nccl_id, rank)); + } } } + int i = 0; + for (auto &dev_id : order_) { + contexts_.at(dev_id).comm_ = comms[i++]; + } } NCCLContextMap(const NCCLContextMap &other) = delete; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 3e2eed31b446b83843fba943e4f2bc9e3787d7f6..b62291a99d34457dd17bf2bcafc1fc611419f086 100644 --- 
a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -503,12 +503,13 @@ All parameter, weight, gradient are variables in Paddle. const ProgramDesc &main_program, const std::string &loss_var_name, Scope *scope, std::vector &local_scopes, bool allow_op_delay, bool use_default_grad_scale, - bool balance_parameter_opt_between_cards) { + bool balance_parameter_opt_between_cards, size_t num_trainers, + size_t trainer_id) { new (&self) ParallelExecutor( num_threads, use_event, places, params, bcast_vars, main_program, loss_var_name, scope, local_scopes, allow_op_delay, use_default_grad_scale, - balance_parameter_opt_between_cards); + balance_parameter_opt_between_cards, num_trainers, trainer_id); }) .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs) // NOTE: even we return a vec* to Python use reference policy. diff --git a/patches/mkldnn.hpp b/patches/mkldnn.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fe01ad8a10ebd223da75bf857617c4ad36b2634e --- /dev/null +++ b/patches/mkldnn.hpp @@ -0,0 +1,4252 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_HPP +#define MKLDNN_HPP + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +#include +#include +#include +#include +#include +#include + +#include "mkldnn.h" +#endif + +namespace mkldnn { + +/// @addtogroup cpp_api C++ API +/// @{ + +/// @addtogroup cpp_api_utils Utils +/// @{ + +/// A class that provides the destructor for an Intel(R) MKL-DNN C handle +template +class handle_traits {}; + +/// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base +/// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and +/// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class +/// can be passed by value. This class enables wrapping: +/// - Newly constructed handles. +/// @n In this case, the constructed handle uses reference counting provided +/// by @p std::shared_ptr with a proper deleter function specified through +/// the @p handle_traits class. +/// - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for +/// example, through #mkldnn_primitive_get_output()). +/// @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a +/// deleter because it is assumed that the handle wrapper for the original +/// object deletes the handle (this model is similar to @p std::weak_ptr). +template > +class handle { +private: + std::shared_ptr::type> _data; + handle(const handle &&) = delete; + handle &operator=(const handle &&other) = delete; + +protected: + /// Constructs a C handle wrapper. 
+ /// @param t The C handle to wrap. + /// @param weak A flag to specify whether to construct a weak wrapper. + handle(T t = 0, bool weak = false) : _data(0) { reset(t, weak); } + + bool operator==(const T other) const { return other == _data.get(); } + bool operator!=(const T other) const { return !(*this == other); } + +public: + handle(const handle &other) : _data(other._data) {} + handle &operator=(const handle &other) { + _data = other._data; + return *this; + } + /// Resets the value of a C handle. + /// @param t The new value of the C handle. + /// @param weak A flag to specify whether the wrapper should be weak. + void reset(T t, bool weak = false) { + auto dummy_destructor = [](T) { + return decltype(traits::destructor(0))(0); + }; + _data.reset(t, weak ? dummy_destructor : traits::destructor); + } + + /// Returns the value of the underlying C handle. + T get() const { return _data.get(); } + + bool operator==(const handle &other) const { + return other._data.get() == _data.get(); + } + bool operator!=(const handle &other) const { return !(*this == other); } +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> +struct handle_traits { + static constexpr auto destructor = &mkldnn_primitive_desc_destroy; +}; + +template <> +struct handle_traits { + static constexpr auto destructor = &mkldnn_primitive_destroy; +}; +#endif + +/// Base class for all computational primitives. 
+class primitive : public handle { + friend struct error; + friend struct stream; + friend class primitive_at; + using handle::handle; + +public: + /// A proxy to C primitive kind enum + enum class kind { + undefined_primitive = mkldnn_undefined_primitive, + memory = mkldnn_memory, + view = mkldnn_view, + reorder = mkldnn_reorder, + concat = mkldnn_concat, + concat_inplace = mkldnn_concat_inplace, + sum = mkldnn_sum, + convolution = mkldnn_convolution, + deconvolution = mkldnn_deconvolution, + eltwise = mkldnn_eltwise, + relu = mkldnn_relu, + softmax = mkldnn_softmax, + pooling = mkldnn_pooling, + lrn = mkldnn_lrn, + batch_normalization = mkldnn_batch_normalization, + inner_product = mkldnn_inner_product, + convolution_relu = mkldnn_convolution_relu, + rnn = mkldnn_rnn, + }; + + /// A wrapper structure to specify a particular output of a primitive. + struct at { + /// The underlying C API structure. + mkldnn_primitive_at_t data; + /// Constructs a wrapper specifying @p aprimitive output with index @p + /// at. + /// + /// @param aprimitive The target primitive. + /// @param at The output index. + + at(const primitive &aprimitive, size_t at = 0) + : data(mkldnn_primitive_at(aprimitive.get(), at)) {} + /// Returns the specified output. + inline operator primitive() const; + }; + + /// Returns the descriptor of the underlying C API primitive + inline const_mkldnn_primitive_desc_t get_primitive_desc() const; + // TODO: use the C++ API wrapper structure. +}; + +inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) { + return static_cast(akind); +} + +/// Intel(R) MKL-DNN exception class. +/// +/// This class captures the status returned by the failed C API function, error +/// message, and, optionally, handle of the primitive that caused the error. +struct error : public std::exception { + mkldnn_status_t status; + std::string message; + primitive error_primitive; + + /// Constructs an error instance. 
+ /// + /// @param astatus The error status returned by the C API. + /// @param amessage The error message. + /// @param aerror_primitive (optional) A C handle of the primitive that + /// caused the error. + + error(mkldnn_status_t astatus, + std::string amessage, + mkldnn_primitive_t aerror_primitive = 0) + : status(astatus), + message(amessage), + error_primitive(aerror_primitive, true) {} + + /// A convenience function for wrapping calls to the C API. Checks the + /// return status and throws an #error in case of failure. + /// + /// @param status The error status returned by the C API. + /// @param message The error message. + /// @param error_primitive (optional) A C handle of the primitive that + /// caused the error. + + static void wrap_c_api(mkldnn_status_t status, + std::string message, + mkldnn_primitive_t *error_primitive = 0) { + if (status != mkldnn_success) { + if (nullptr != error_primitive) + throw error(status, message, *error_primitive); + else + throw error(status, message, nullptr); + } + } +}; + +inline primitive::at::operator primitive() const { + const_mkldnn_primitive_t output; + error::wrap_c_api( + mkldnn_primitive_get_output(data.primitive, data.output_index, &output), + "could not get an output primitive"); + return primitive(const_cast(output), true); +} + +const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const { + const_mkldnn_primitive_desc_t pd; + error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd), + "could not get primitive descriptor by primitive"); + return pd; +} +/// @} + +/// @addtogroup cpp_api_enums Common data types and enumerations +/// @{ + +enum round_mode { + round_nearest = mkldnn_round_nearest, + round_down = mkldnn_round_down, +}; + +inline mkldnn_round_mode_t convert_to_c(round_mode mode) { + return static_cast(mode); +} + +enum padding_kind { zero = mkldnn_padding_zero }; + +inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) { + return static_cast(kind); +} + +enum prop_kind { 
+ forward_training = mkldnn_forward_training, + forward_scoring = mkldnn_forward_scoring, + forward_inference = mkldnn_forward_inference, + forward = mkldnn_forward, + backward = mkldnn_backward, + backward_data = mkldnn_backward_data, + backward_weights = mkldnn_backward_weights, + backward_bias = mkldnn_backward_bias +}; + +inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) { + return static_cast(kind); +} + +enum algorithm { + algorithm_undef = mkldnn_alg_kind_undef, + convolution_direct = mkldnn_convolution_direct, + convolution_winograd = mkldnn_convolution_winograd, + deconvolution_direct = mkldnn_deconvolution_direct, + deconvolution_winograd = mkldnn_deconvolution_winograd, + eltwise_relu = mkldnn_eltwise_relu, + eltwise_tanh = mkldnn_eltwise_tanh, + eltwise_elu = mkldnn_eltwise_elu, + eltwise_square = mkldnn_eltwise_square, + eltwise_abs = mkldnn_eltwise_abs, + eltwise_sqrt = mkldnn_eltwise_sqrt, + eltwise_linear = mkldnn_eltwise_linear, + eltwise_bounded_relu = mkldnn_eltwise_bounded_relu, + eltwise_soft_relu = mkldnn_eltwise_soft_relu, + eltwise_logistic = mkldnn_eltwise_logistic, + lrn_across_channels = mkldnn_lrn_across_channels, + lrn_within_channel = mkldnn_lrn_within_channel, + pooling_max = mkldnn_pooling_max, + pooling_avg = mkldnn_pooling_avg, + pooling_avg_include_padding = mkldnn_pooling_avg_include_padding, + pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding, + vanilla_rnn = mkldnn_vanilla_rnn, + vanilla_lstm = mkldnn_vanilla_lstm, + vanilla_gru = mkldnn_vanilla_gru, +}; + +inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) { + return static_cast(aalgorithm); +} + +enum batch_normalization_flag { + use_global_stats = mkldnn_use_global_stats, + use_scale_shift = mkldnn_use_scaleshift, + omit_stats = mkldnn_omit_stats, + fuse_bn_relu = mkldnn_fuse_bn_relu +}; + +inline mkldnn_batch_normalization_flag_t convert_to_c( + batch_normalization_flag aflag) { + return static_cast(aflag); +} + +enum rnn_direction { + 
unidirectional_left2right = mkldnn_unidirectional_left2right, + unidirectional_right2left = mkldnn_unidirectional_right2left, + unidirectional = mkldnn_unidirectional, + bidirectional_concat = mkldnn_bidirectional_concat, + bidirectional_sum = mkldnn_bidirectional_sum, +}; + +inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) { + return static_cast(adir); +} + +enum query { + undef = mkldnn_query_undef, + + eengine = mkldnn_query_engine, + primitive_kind = mkldnn_query_primitive_kind, + + num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32, + num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32, + + time_estimate_f64 = mkldnn_query_time_estimate_f64, + memory_consumption_s64 = mkldnn_query_memory_consumption_s64, + + impl_info_str = mkldnn_query_impl_info_str, + + memory_d = mkldnn_query_memory_d, + convolution_d = mkldnn_query_convolution_d, + deconvolution_d = mkldnn_query_deconvolution_d, + eltwise_d = mkldnn_query_eltwise_d, + relu_d = mkldnn_query_relu_d, + softmax_d = mkldnn_query_softmax_d, + pooling_d = mkldnn_query_pooling_d, + lrn_d = mkldnn_query_lrn_d, + batch_normalization_d = mkldnn_query_batch_normalization_d, + inner_product_d = mkldnn_query_inner_product_d, + convolution_relu_d = mkldnn_query_convolution_relu_d, + rnn_d = mkldnn_query_rnn_d, + + input_pd = mkldnn_query_input_pd, + output_pd = mkldnn_query_output_pd, + src_pd = mkldnn_query_src_pd, + diff_src_pd = mkldnn_query_diff_src_pd, + weights_pd = mkldnn_query_weights_pd, + diff_weights_pd = mkldnn_query_diff_weights_pd, + dst_pd = mkldnn_query_dst_pd, + diff_dst_pd = mkldnn_query_diff_dst_pd, + workspace_pd = mkldnn_query_workspace_pd, +}; + +inline mkldnn_query_t convert_to_c(query aquery) { + return static_cast(aquery); +} + +/// @} + +/// @addtogroup cpp_api_attr Attributes +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> +struct handle_traits { + static constexpr auto destructor = &mkldnn_post_ops_destroy; +}; +#endif + +struct post_ops : public handle { + post_ops() { 
+ mkldnn_post_ops_t result; + error::wrap_c_api(mkldnn_post_ops_create(&result), + "could not create post operation sequence"); + reset(result); + } + + int len() const { return mkldnn_post_ops_len(get()); } + + primitive::kind kind(int index) const { + error::wrap_c_api(index < len() ? mkldnn_success : mkldnn_invalid_arguments, + "post_ops index is out of range"); + return static_cast(mkldnn_post_ops_get_kind(get(), index)); + } + + void append_sum(float scale = 1.) { + error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale), + "could not append sum"); + } + + void get_params_sum(int index, float &scale) const { + error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale), + "could not get sum params"); + } + + void append_eltwise(float scale, algorithm alg, float alpha, float beta) { + error::wrap_c_api(mkldnn_post_ops_append_eltwise( + get(), scale, convert_to_c(alg), alpha, beta), + "could not append eltwise"); + } + + void get_params_eltwise(int index, + float &scale, + algorithm &alg, + float &alpha, + float &beta) const { + mkldnn_alg_kind_t c_alg; + error::wrap_c_api(mkldnn_post_ops_get_params_eltwise( + get(), index, &scale, &c_alg, &alpha, &beta), + "could not get eltwise params"); + alg = static_cast(c_alg); + } +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> +struct handle_traits { + static constexpr auto destructor = &mkldnn_primitive_attr_destroy; +}; +#endif + +struct primitive_attr : public handle { + primitive_attr() { + mkldnn_primitive_attr_t result; + error::wrap_c_api(mkldnn_primitive_attr_create(&result), + "could not create a primitive attr"); + reset(result); + } + + round_mode get_int_output_round_mode() const { + mkldnn_round_mode_t result; + error::wrap_c_api( + mkldnn_primitive_attr_get_int_output_round_mode(get(), &result), + "could not get int output round mode"); + return round_mode(result); + } + + void set_int_output_round_mode(round_mode mode) { + 
error::wrap_c_api(mkldnn_primitive_attr_set_int_output_round_mode( + get(), mkldnn::convert_to_c(mode)), + "could not set int output round mode"); + } + + void get_output_scales(int &mask, std::vector &scales) const { + int count, c_mask; + const float *c_scales; + error::wrap_c_api(mkldnn_primitive_attr_get_output_scales( + get(), &count, &c_mask, &c_scales), + "could not get int output scales"); + scales.resize(count); + + mask = c_mask; + for (int c = 0; c < count; ++c) scales[c] = c_scales[c]; + } + + void set_output_scales(int mask, const std::vector &scales) { + error::wrap_c_api(mkldnn_primitive_attr_set_output_scales( + get(), (int)scales.size(), mask, &scales[0]), + "could not set int output scales"); + } + + const post_ops get_post_ops() const { + post_ops result; + const_mkldnn_post_ops_t c_result; + error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result), + "could not get post operation sequence"); + result.reset(const_cast(c_result), true); + return result; + } + + void set_post_ops(post_ops ops) { + error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()), + "could not set post operation sequence"); + } +}; + +/// @} + +/// @addtogroup cpp_api_engine Engine +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> +struct handle_traits { + static constexpr auto destructor = &mkldnn_engine_destroy; +}; +#endif + +/// An execution engine. +struct engine : public handle { + friend class primitive; + // gcc bug??? using handle::handle; + + /// Kinds of engines + enum kind { + /// An unspecified engine + any = mkldnn_any_engine, + /// CPU engine + cpu = mkldnn_cpu, + }; + + /// Returns the number of engines of a certain kind. + /// + /// @param akind The kind of engines to count. + + static size_t get_count(kind akind) { + return mkldnn_engine_get_count(convert_to_c(akind)); + } + + /// Constructs an engine. + /// + /// @param akind The kind of engine to construct. + /// @param index The index of the engine. 
Must be less than the value + /// returned by #get_count() for this particular kind of engine. + + engine(kind akind, size_t index) { + mkldnn_engine_t aengine; + error::wrap_c_api( + mkldnn_engine_create(&aengine, convert_to_c(akind), index), + "could not create an engine"); + reset(aengine); + } + + explicit engine(const mkldnn_engine_t &aengine) : handle(aengine, true) {} + + engine(const handle &pd) { + mkldnn_engine_t engine_q; + error::wrap_c_api( + mkldnn_primitive_desc_query( + pd.get(), mkldnn::convert_to_c(eengine), 0, &engine_q), + "could not get engine from primitive_desc"); + reset(engine_q, true); + } + + template + static engine query(const primitive_desc &pd) { + mkldnn_engine_t engine_q; + error::wrap_c_api( + mkldnn_primitive_desc_query( + pd.get(), mkldnn::convert_to_c(eengine), 0, &engine_q), + "could not get engine from primitive_desc"); + + return engine(engine_q); + } + +private: + static mkldnn_engine_kind_t convert_to_c(kind akind) { + return static_cast(akind); + } +}; + +/// @} + +/// @addtogroup cpp_api_primitives Primitives +/// @{ + +/// @addtogroup cpp_api_memory Memory +/// @{ + +/// Memory primitive that describes the data. +struct memory : public primitive { +private: + std::shared_ptr _handle; + +public: + typedef std::vector::type> dims; + + template + static void validate_dims(std::vector v) { + if (v.size() > TENSOR_MAX_DIMS) + throw error(mkldnn_invalid_arguments, "invalid dimensions"); + } + + /// Data type specification. See #mkldnn_data_type_t for a detailed + /// description. + enum data_type { + data_undef = mkldnn_data_type_undef, + f32 = mkldnn_f32, + s32 = mkldnn_s32, + s16 = mkldnn_s16, + s8 = mkldnn_s8, + u8 = mkldnn_u8, + }; + + /// Memory format specification. See #mkldnn_memory_format_t + /// for a detailed description. 
+ enum format { + format_undef = mkldnn_format_undef, + any = mkldnn_any, + blocked = mkldnn_blocked, + x = mkldnn_x, + nc = mkldnn_nc, + nchw = mkldnn_nchw, + nhwc = mkldnn_nhwc, + chwn = mkldnn_chwn, + nChw8c = mkldnn_nChw8c, + nChw16c = mkldnn_nChw16c, + ncdhw = mkldnn_ncdhw, + ndhwc = mkldnn_ndhwc, + nCdhw16c = mkldnn_nCdhw16c, + oi = mkldnn_oi, + io = mkldnn_io, + oihw = mkldnn_oihw, + ihwo = mkldnn_ihwo, + hwio = mkldnn_hwio, + oidhw = mkldnn_oidhw, + OIdhw16i16o = mkldnn_OIdhw16i16o, + OIdhw16o16i = mkldnn_OIdhw16o16i, + Oidhw16o = mkldnn_Oidhw16o, + Odhwi16o = mkldnn_Odhwi16o, + oIhw8i = mkldnn_oIhw8i, + oIhw16i = mkldnn_oIhw16i, + OIhw8i8o = mkldnn_OIhw8i8o, + OIhw16i16o = mkldnn_OIhw16i16o, + OIhw8o8i = mkldnn_OIhw8o8i, + OIhw16o16i = mkldnn_OIhw16o16i, + IOhw16o16i = mkldnn_IOhw16o16i, + OIhw8i16o2i = mkldnn_OIhw8i16o2i, + OIhw8o16i2o = mkldnn_OIhw8o16i2o, + OIhw4i16o4i = mkldnn_OIhw4i16o4i, + Oihw8o = mkldnn_Oihw8o, + Oihw16o = mkldnn_Oihw16o, + Ohwi8o = mkldnn_Ohwi8o, + Ohwi16o = mkldnn_Ohwi16o, + OhIw16o4i = mkldnn_OhIw16o4i, + goihw = mkldnn_goihw, + hwigo = mkldnn_hwigo, + gOIhw8i8o = mkldnn_gOIhw8i8o, + gOIhw16i16o = mkldnn_gOIhw16i16o, + gOIhw8i16o2i = mkldnn_gOIhw8i16o2i, + gOIhw8o16i2o = mkldnn_gOIhw8o16i2o, + gOIhw4i16o4i = mkldnn_gOIhw4i16o4i, + gOihw8o = mkldnn_gOihw8o, + gOihw16o = mkldnn_gOihw16o, + gOhwi8o = mkldnn_gOhwi8o, + gOhwi16o = mkldnn_gOhwi16o, + Goihw8g = mkldnn_Goihw8g, + Goihw16g = mkldnn_Goihw16g, + gOIhw8o8i = mkldnn_gOIhw8o8i, + gOIhw16o16i = mkldnn_gOIhw16o16i, + gIOhw16o16i = mkldnn_gIOhw16o16i, + gOhIw16o4i = mkldnn_gOhIw16o4i, + goidhw = mkldnn_goidhw, + gOIdhw16i16o = mkldnn_gOIdhw16i16o, + gOIdhw16o16i = mkldnn_gOIdhw16o16i, + gOidhw16o = mkldnn_gOidhw16o, + gOdhwi16o = mkldnn_gOdhwi16o, + ntc = mkldnn_ntc, + tnc = mkldnn_tnc, + ldsnc = mkldnn_ldsnc, + ldigo = mkldnn_ldigo, + ldigo_p = mkldnn_ldigo_p, + ldgoi = mkldnn_ldgoi, + ldgoi_p = mkldnn_ldgoi_p, + ldgo = mkldnn_ldgo, + wino_fmt = mkldnn_wino_fmt, + format_last = 
mkldnn_format_last, + }; + + /// A memory descriptor. + struct desc { + friend struct memory; + /// The underlying C API data structure. + mkldnn_memory_desc_t data; + + /// Constructs a memory descriptor. + /// + /// @param adims Data dimensions + /// @param adata_type Data precision/type. + /// @param aformat Data layout format. + desc(dims adims, data_type adata_type, format aformat) { + validate_dims(adims); + error::wrap_c_api( + mkldnn_memory_desc_init(&data, + (int)adims.size(), + adims.size() == 0 ? nullptr : &adims[0], + convert_to_c(adata_type), + convert_to_c(aformat)), + "could not initialize a memory descriptor"); + } + + /// Constructs a memory descriptor from a C API data structure. + /// + /// @param adata A C API #mkldnn_memory_desc_t structure. + desc(const mkldnn_memory_desc_t &adata) : data(adata) {} + }; + + /// A memory primitive descriptor. + struct primitive_desc : public handle { + friend struct memory; + + // TODO: make private + primitive_desc() {} + + /// Constructs a memory primitive descriptor. + primitive_desc(const desc &adesc, const engine &aengine) { + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_memory_primitive_desc_create( + &result, &adesc.data, aengine.get()), + "could not initialize a memory primitive descriptor"); + reset(result); + } + + /// Returns the memory primitive descriptor. + memory::desc desc() { + auto memory_d = mkldnn_primitive_desc_query_memory_d(get()); + return memory::desc(*memory_d); + } + + /// Returns the number of bytes required to allocate the memory described + /// including the padding area. 
+ size_t get_size() const { + return mkldnn_memory_primitive_desc_get_size(get()); + } + + bool operator==(const primitive_desc &other) const { + return mkldnn_memory_primitive_desc_equal(get(), other.get()); + } + + bool operator!=(const primitive_desc &other) const { + return !operator==(other); + } + + engine get_engine() { return engine::query(*this); } + }; + + /// Constructs a memory primitive from a generic primitive. + /// + /// @param aprimitive The primitive to treat as memory. + memory(const primitive &aprimitive) : primitive(aprimitive) {} + /// Constructs a memory primitive. + /// + /// @param adesc Memory primitive descriptor. + memory(const primitive_desc &adesc) { + mkldnn_primitive_t result; + error::wrap_c_api( + mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr), + "could not create a memory primitive"); + reset(result); + auto _malloc = [](size_t size, int alignment) { + void *ptr; +#ifdef _WIN32 + ptr = _aligned_malloc(size, alignment); + int rc = ((ptr) ? 0 : errno); +#else + int rc = ::posix_memalign(&ptr, alignment, size); +#endif /* _WIN32 */ + return (rc == 0) ? (char *)ptr : nullptr; + }; + auto _free = [](char *p) { +#ifdef _WIN32 + _aligned_free((void *)p); +#else + ::free((void *)p); +#endif /* _WIN32 */ + }; + _handle.reset(_malloc(adesc.get_size(), 4096), _free); + set_data_handle(_handle.get()); + } + + memory(const primitive_desc &adesc, void *ahandle) { + mkldnn_primitive_t result; + error::wrap_c_api( + mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr), + "could not create a memory primitive"); + reset(result); + set_data_handle(ahandle); + } + + /// Returns the descriptor of the memory primitive. 
+  primitive_desc get_primitive_desc() const {  // Returns the primitive descriptor the C API reports for this memory primitive.
+    primitive_desc adesc;
+    const_mkldnn_primitive_desc_t cdesc;
+    error::wrap_c_api(
+        mkldnn_primitive_get_primitive_desc(get(), &cdesc),
+        "could not get primitive descriptor from a memory primitive");
+    /* FIXME: no const_cast should be here */
+    adesc.reset(const_cast<mkldnn_primitive_desc_t>(cdesc), true);  // NOTE(review): `true` presumably marks the handle as non-owning; confirm against handle::reset.
+    return adesc;
+  }
+
+  /// Returns a handle of the data contained in the memory primitive. On
+  /// the CPU engine, this is a pointer to the allocated memory.
+  inline void *get_data_handle() const {
+    void *handle;
+    error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle),
+                      "could not get native handle");
+    return handle;
+  }
+  /// Sets the data handle of the memory primitive to @p handle.
+  inline void set_data_handle(void *handle) const {
+    error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle),
+                      "could not set native handle");
+  }
+
+  // Must go away or be private:
+  static mkldnn_data_type_t convert_to_c(data_type adata_type) {  // C++ data_type -> C API enum.
+    return static_cast<mkldnn_data_type_t>(adata_type);
+  }
+  static mkldnn_memory_format_t convert_to_c(format aformat) {  // C++ format -> C API enum.
+    return static_cast<mkldnn_memory_format_t>(aformat);
+  }
+};
+/// Returns a memory descriptor that is all zeros except for primitive_kind.
+inline memory::desc zero_md() {
+  auto zero = mkldnn_memory_desc_t();  // Value-initialize: a plain local would leave ndims indeterminate, and is_null_memory() tests ndims == 0.
+  zero.primitive_kind = mkldnn_memory;
+  return memory::desc(zero);
+}
+/// Builds a memory object on @p eng from zero_md() with nullptr as the second constructor argument.
+inline memory null_memory(engine eng) {
+  mkldnn::memory::desc zero = zero_md();
+  return memory({zero, eng}, nullptr);
+}
+/// Returns true when @p aprimitive yields a memory descriptor whose ndims is 0.
+inline bool is_null_memory(const const_mkldnn_primitive_t &aprimitive) {
+  const_mkldnn_primitive_desc_t aprimitive_pd = nullptr;  // Start from null so a failed lookup is detectable below.
+  mkldnn_primitive_get_primitive_desc(aprimitive, &aprimitive_pd);  // Status is not inspected; assumes the out-parameter is left untouched on failure (TODO confirm).
+  const mkldnn_memory_desc_t *aprimitive_md =
+      aprimitive_pd != nullptr ? mkldnn_primitive_desc_query_memory_d(aprimitive_pd)
+                               : nullptr;
+  return ((aprimitive_md != nullptr) && (aprimitive_md->ndims == 0));
+}
+/// Mixed comparisons between the C API data type enum and memory::data_type.
+inline bool operator==(mkldnn_data_type_t a, memory::data_type b) {
+  return a == memory::convert_to_c(b);
+}
+inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) {
+  return !(a == b);
+}
+inline bool operator==(memory::data_type a, mkldnn_data_type_t b) {
+  return b == a;
+}
+inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) {
+  return !(a == b);
+}
+/// Mixed comparisons between the C API memory format enum and memory::format.
+inline bool operator==(mkldnn_memory_format_t a, memory::format b) {
+  return a == memory::convert_to_c(b);
+}
+inline bool operator!=(mkldnn_memory_format_t a, memory::format b) {
+  return !(a == b);
+}
+inline bool operator==(memory::format a, mkldnn_memory_format_t b) {
+  return b == a;
+}
+inline bool operator!=(memory::format a, mkldnn_memory_format_t b) {
+  return !(a == b);
+}
+
+/// @}
+
+/// @addtogroup cpp_api_reorder Reorder
+/// @{
+
+struct reorder : public primitive {  // Primitive created from one input and one output memory.
+  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+    primitive_desc(const memory::primitive_desc &input,  // Descriptor from input/output memory primitive descriptors.
+                   const memory::primitive_desc &output) {
+      mkldnn_primitive_desc_t result;
+      error::wrap_c_api(mkldnn_reorder_primitive_desc_create(
+                            &result, input.get(), output.get()),
+                        "could not create a reorder primitive descriptor");
+      reset(result);
+    }
+    /// Same as above, additionally forwarding the attributes in @p aattr.
+    primitive_desc(const memory::primitive_desc &input,
+                   const memory::primitive_desc &output,
+                   const primitive_attr &aattr) {
+      mkldnn_primitive_desc_t result;
+      error::wrap_c_api(mkldnn_reorder_primitive_desc_create_v2(
+                            &result, input.get(), output.get(), aattr.get()),
+                        "could not create a reorder primitive descriptor");
+      reset(result);
+    }
+    /// Returns the result of engine::query() on this descriptor.
+    engine get_engine() { return engine::query(*this); }
+  };
+  /// Creates the primitive from an existing descriptor, one input and one output.
+  reorder(const primitive_desc &aprimitive_desc,
+          const primitive::at &input,
+          const memory &output) {
+    mkldnn_primitive_t result;
+    mkldnn_primitive_at_t inputs[] = {input.data};
+    const_mkldnn_primitive_t outputs[] = {output.get()};
+    error::wrap_c_api(mkldnn_primitive_create(
+                          &result, aprimitive_desc.get(), inputs, outputs),
+                      "could not create a reorder primitive");
+    reset(result);
+  }
+  /// Creates the primitive, deriving the descriptor from @p input and @p output.
+  reorder(const primitive::at &input, const memory &output) {
+    auto input_mpd = memory(input).get_primitive_desc();
+    auto output_mpd = output.get_primitive_desc();
+    // The reorder descriptor is built from the two memory primitive descriptors.
+    auto reorder_d = primitive_desc(input_mpd, output_mpd);
+
+    mkldnn_primitive_t result;
+    mkldnn_primitive_at_t inputs[] = {input.data};
+    const_mkldnn_primitive_t outputs[] = {output.get()};
+    error::wrap_c_api(
+        mkldnn_primitive_create(&result, reorder_d.get(), inputs, outputs),
+        "could not create a reorder primitive");
+    reset(result);
+  }
+};
+
+/// @}
+
+/// @addtogroup cpp_api_view View
+/// @{
+
+struct view : public primitive {  // Primitive created from one input and no outputs (nullptr is passed).
+  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+    primitive_desc(const memory::primitive_desc &input,  // Descriptor from an input descriptor plus dims and offsets.
+                   memory::dims dims,
+                   memory::dims offsets) {
+      mkldnn_primitive_desc_t result;
+      // NOTE(review): &dims[0] / &offsets[0] assume both containers are non-empty.
+      error::wrap_c_api(mkldnn_view_primitive_desc_create(
+                            &result, input.get(), &dims[0], &offsets[0]),
+                        "could not create a view primitive descriptor");
+      reset(result);
+    }
+    /// Returns a clone of the descriptor obtained by querying dst_pd at index 0.
+    memory::primitive_desc dst_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(dst_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a dst primitive descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+    /// Returns the result of engine::query() on this descriptor.
+    engine get_engine() { return engine::query(*this); }
+  };
+  /// Creates the primitive from an existing view descriptor and one input.
+  view(const primitive_desc &view_pd, primitive::at input) {
+    mkldnn_primitive_t result;
+    mkldnn_primitive_at_t inputs[] = {input.data};
+    error::wrap_c_api(
+        mkldnn_primitive_create(&result, view_pd.get(), inputs, nullptr),
+        "could not create a view primitive");
+    reset(result);
+  }
+  /// Creates the primitive, building the descriptor from @p input, @p dims and @p offsets.
+  view(memory input, memory::dims dims, memory::dims offsets) {
+    mkldnn_primitive_t result;
+    primitive_desc view_pd(input.get_primitive_desc(), dims, offsets);
+    mkldnn_primitive_at_t inputs[] = {primitive::at(input).data};
+    error::wrap_c_api(
+        mkldnn_primitive_create(&result, view_pd.get(), inputs, nullptr),
+        "could not create a view primitive");
+    reset(result);
+  }
+};
+
+/// @}
+
+/// @addtogroup cpp_api_concat Concat
+/// @{
+
+struct concat : public primitive {  // Primitive created from a list of inputs and one output memory.
+  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+    std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(  // Collects d.get() of every input descriptor, in order.
+        std::vector<memory::primitive_desc> inputs) {
+      std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
+      c_api_inputs.reserve(inputs.size());
+      auto convert_to_c = [](memory::primitive_desc d) { return d.get(); };
+      std::transform(inputs.begin(),
+                     inputs.end(),
+                     std::back_inserter(c_api_inputs),
+                     convert_to_c);
+      return c_api_inputs;
+    }
+    /// Creates a concat descriptor for @p output, @p concat_dimension and @p inputs.
+    primitive_desc(const memory::desc &output,
+                   int concat_dimension,
+                   std::vector<memory::primitive_desc> inputs) {
+      mkldnn_primitive_desc_t result;
+
+      auto c_api_inputs = cpp_to_c(inputs);
+      // data() rather than &v[0]: indexing an empty vector is undefined behavior.
+      error::wrap_c_api(
+          mkldnn_concat_primitive_desc_create(&result,
+                                              &output.data,
+                                              (int)c_api_inputs.size(),
+                                              concat_dimension,
+                                              c_api_inputs.data()),
+          "could not create a concat primitive descriptor");
+      reset(result);
+    }
+    /// Same as above, but passes nullptr in place of the output descriptor.
+    primitive_desc(int concat_dimension,
+                   std::vector<memory::primitive_desc> inputs) {
+      mkldnn_primitive_desc_t result;
+
+      auto c_api_inputs = cpp_to_c(inputs);
+      // data() rather than &v[0]: indexing an empty vector is undefined behavior.
+      error::wrap_c_api(
+          mkldnn_concat_primitive_desc_create(&result,
+                                              nullptr,
+                                              (int)c_api_inputs.size(),
+                                              concat_dimension,
+                                              c_api_inputs.data()),
+          "could not create a concat primitive descriptor");
+      reset(result);
+    }
+    /// Returns a clone of the descriptor obtained by querying dst_pd at index 0.
+    memory::primitive_desc dst_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(dst_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a dst primitive descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+    /// Returns the result of engine::query() on this descriptor.
+    engine get_engine() { return engine::query(*this); }
+  };
+  /// Creates the primitive from a descriptor, the input list and one output memory.
+  concat(const primitive_desc &concat_pd,
+         std::vector<primitive::at> &inputs,
+         const memory &output) {
+    mkldnn_primitive_t result;
+    // Flatten the C++ inputs into the C structs expected by mkldnn_primitive_create.
+    std::vector<mkldnn_primitive_at_t> p_inputs;
+    for (size_t i = 0; i < inputs.size(); i++)
+      p_inputs.push_back(inputs[i].data);
+    const_mkldnn_primitive_t outputs[] = {output.get()};
+    // data() rather than &v[0]: indexing an empty vector is undefined behavior.
+    error::wrap_c_api(mkldnn_primitive_create(
+                          &result, concat_pd.get(), p_inputs.data(), outputs),
+                      "could not create a concat primitive");
+    reset(result);
+  }
+};
+
+/// @}
+
+/// @addtogroup cpp_api_sum Sum +/// @{ + +struct sum : public primitive { + struct primitive_desc : public handle { + std::vector cpp_to_c( + std::vector inputs) { + std::vector c_api_inputs; + c_api_inputs.reserve(inputs.size()); + auto convert_to_c = [](memory::primitive_desc d) { return d.get(); }; + std::transform(inputs.begin(), + inputs.end(), + std::back_inserter(c_api_inputs), + convert_to_c); + return c_api_inputs; + } + + primitive_desc(const memory::desc &output, + const std::vector &scales, + std::vector inputs) { + mkldnn_primitive_desc_t result; + + auto c_api_inputs = cpp_to_c(inputs); + + error::wrap_c_api( + mkldnn_sum_primitive_desc_create(&result, + &output.data, + (int)c_api_inputs.size(), + &scales[0], + &c_api_inputs[0]), + "could not create a sum primitive descriptor"); + reset(result); + } + + primitive_desc(const std::vector &scales, + std::vector inputs) { + mkldnn_primitive_desc_t result; + + auto c_api_inputs = cpp_to_c(inputs); + + error::wrap_c_api( + mkldnn_sum_primitive_desc_create(&result, + nullptr, + (int)c_api_inputs.size(), + &scales[0], + &c_api_inputs[0]), + "could not create a sum primitive descriptor"); + reset(result); + } + + /** @deprecated: api backwards compatibility for double scales type */ + MKLDNN_DEPRECATED + primitive_desc(const memory::desc &output, + std::vector scale, + std::vector inputs) { + mkldnn_primitive_desc_t result; + + auto c_api_inputs = cpp_to_c(inputs); + auto scale_f = scale_to_float(scale); + + error::wrap_c_api( + mkldnn_sum_primitive_desc_create(&result, + &output.data, + (int)c_api_inputs.size(), + &scale_f[0], + &c_api_inputs[0]), + "could not create a sum primitive descriptor"); + reset(result); + } + + /** @deprecated: api backwards compatibility for double scales type */ + MKLDNN_DEPRECATED + primitive_desc(std::vector scale, + std::vector inputs) { + mkldnn_primitive_desc_t result; + + auto c_api_inputs = cpp_to_c(inputs); + auto scale_f = scale_to_float(scale); + + error::wrap_c_api( + 
mkldnn_sum_primitive_desc_create(&result, + nullptr, + (int)c_api_inputs.size(), + &scale_f[0], + &c_api_inputs[0]), + "could not create a sum primitive descriptor"); + reset(result); + } + + memory::primitive_desc dst_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + sum(const primitive_desc &sum_pd, + std::vector &inputs, + const memory &output) { + mkldnn_primitive_t result; + + std::vector p_inputs; + for (size_t i = 0; i < inputs.size(); i++) + p_inputs.push_back(inputs[i].data); + const_mkldnn_primitive_t outputs[] = {output.get()}; + + error::wrap_c_api( + mkldnn_primitive_create(&result, sum_pd.get(), &p_inputs[0], outputs), + "could not create a sum primitive"); + reset(result); + } + +private: + static std::vector scale_to_float(const std::vector &vd) { + std::vector vf(vd.size()); + std::transform( + vd.begin(), vd.end(), vf.begin(), [=](double x) { return (float)x; }); + return vf; + } +}; + +/// @} + +/// @addtogroup cpp_api_convolution Convolution +/// @{ + +struct convolution_forward : public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(prop_kind aprop_kind, + algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_forward_desc_init( + &data, + 
mkldnn::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), + &src_desc.data, + &weights_desc.data, + &bias_desc.data, + &dst_desc.data, + &strides[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution forward descriptor"); + } + desc(prop_kind aprop_kind, + algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_forward_desc_init( + &data, + mkldnn::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), + &src_desc.data, + &weights_desc.data, + nullptr, + &dst_desc.data, + &strides[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution forward descriptor"); + } + desc(prop_kind aprop_kind, + algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_forward_desc_init( + &data, + mkldnn::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), + &src_desc.data, + &weights_desc.data, + &bias_desc.data, + &dst_desc.data, + &strides[0], + &dilates[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated convolution forward descriptor"); + } + desc(prop_kind aprop_kind, + algorithm aalgorithm, + const memory::desc &src_desc, + const 
memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_forward_desc_init( + &data, + mkldnn::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), + &src_desc.data, + &weights_desc.data, + nullptr, + &dst_desc.data, + &strides[0], + &dilates[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated convolution forward descriptor"); + } + }; + struct primitive_desc : public handle { + primitive_desc(const desc &adesc, const engine &aengine) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create( + &result, &adesc.data, aengine.get(), nullptr), + "could not create a convolution forward primitive descriptor"); + reset(result); + } + + primitive_desc(const desc &adesc, + const primitive_attr &aattr, + const engine &aengine) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create_v2( + &result, &adesc.data, aattr.get(), aengine.get(), nullptr), + "could not create a convolution forward primitive descriptor"); + reset(result); + } + + memory::primitive_desc src_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(src_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a src primititve descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc weights_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + 
mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc bias_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a bias primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc dst_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + convolution_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &weights, + const primitive::at &bias, + const memory &dst) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, weights.data, bias.data}; + const_mkldnn_primitive_t outputs[] = {dst.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a convolution forward bias primitive"); + reset(result); + } + + convolution_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &weights, + const memory &dst) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, weights.data}; + const_mkldnn_primitive_t outputs[] = {dst.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), 
inputs, outputs), + "could not create a convolution forward primitive"); + reset(result); + } +}; + +struct convolution_backward_data : public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_convolution_backward_data_desc_init( + &data, + convert_to_c(aalgorithm), + &diff_src_desc.data, + &weights_desc.data, + &diff_dst_desc.data, + &strides[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward data descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_backward_data_desc_init( + &data, + convert_to_c(aalgorithm), + &diff_src_desc.data, + &weights_desc.data, + &diff_dst_desc.data, + &strides[0], + &dilates[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward data descriptor"); + } + }; + struct primitive_desc : public handle { + primitive_desc( + const desc &adesc, + const engine &aengine, + const convolution_forward::primitive_desc &hint_fwd_primitive_desc) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create(&result, + &adesc.data, + 
aengine.get(), + hint_fwd_primitive_desc.get()), + "could not create a convolution backward data primitive descriptor"); + reset(result); + } + memory::primitive_desc diff_src_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_src_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff_src primititve descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc weights_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_dst_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff_dst primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + convolution_backward_data(const primitive_desc &aprimitive_desc, + const primitive::at &diff_dst, + const primitive::at &weights, + const memory &diff_src) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {diff_dst.data, weights.data}; + const_mkldnn_primitive_t outputs[] = {diff_src.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a convolution backward data primitive"); + reset(result); + } +}; + +struct convolution_backward_weights : 
public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_convolution_backward_weights_desc_init( + &data, + convert_to_c(aalgorithm), + &src_desc.data, + &diff_weights_desc.data, + &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_convolution_backward_weights_desc_init( + &data, + convert_to_c(aalgorithm), + &src_desc.data, + &diff_weights_desc.data, + nullptr, + &diff_dst_desc.data, + &strides[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + 
memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_backward_weights_desc_init( + &data, + convert_to_c(aalgorithm), + &src_desc.data, + &diff_weights_desc.data, + &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], + &dilates[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_backward_weights_desc_init( + &data, + convert_to_c(aalgorithm), + &src_desc.data, + &diff_weights_desc.data, + nullptr, + &diff_dst_desc.data, + &strides[0], + &dilates[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + }; + + struct primitive_desc : public handle { + primitive_desc( + const desc &adesc, + const engine &aengine, + const convolution_forward::primitive_desc &hint_fwd_primitive_desc) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create(&result, + &adesc.data, + aengine.get(), + hint_fwd_primitive_desc.get()), + "could not create a convolution backward weights primitive " + "descriptor"); + reset(result); + } + memory::primitive_desc src_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(src_pd), 0); + 
error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a src primititve descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_weights_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_weights_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff_weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_bias_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_weights_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff_bias primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_dst_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff_dst primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + convolution_backward_weights(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &diff_dst, + const memory &diff_weights, + const memory &diff_bias) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data}; + const_mkldnn_primitive_t outputs[] = {diff_weights.get(), diff_bias.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a convolution 
backward weights primitive"); + reset(result); + } + convolution_backward_weights(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &diff_dst, + const memory &diff_weights) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data}; + const_mkldnn_primitive_t outputs[] = {diff_weights.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a convolution backward weights primitive"); + reset(result); + } +}; + +struct convolution_relu_forward : public primitive { + struct desc { + mkldnn_convolution_relu_desc_t data; + desc(const convolution_forward::desc conv_desc, + const float negative_slope) { + error::wrap_c_api( + mkldnn_convolution_relu_desc_init( + &data, &conv_desc.data, negative_slope), + "could not create a convolution_relu_forward descriptor"); + } + }; + + struct primitive_desc : public handle { + primitive_desc(const desc &adesc, const engine &aengine) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create( + &result, &adesc.data, aengine.get(), nullptr), + "could not create a convolution relu forward descriptor"); + reset(result); + } + + engine get_engine() { return engine::query(*this); } + }; + + convolution_relu_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &weights, + const primitive::at &bias, + const memory &dst) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, weights.data, bias.data}; + const_mkldnn_primitive_t outputs[] = {dst.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a convolution relu forward primitive"); + reset(result); + } + + convolution_relu_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &weights, + const memory &dst) { + mkldnn_primitive_t result; + 
mkldnn_primitive_at_t inputs[] = {src.data, weights.data}; + const_mkldnn_primitive_t outputs[] = {dst.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a convolution relu forward primitive"); + reset(result); + } +}; + +/// @} +// +/// @addtogroup cpp_api_deconvolution Deconvolution +/// @{ + +struct deconvolution_forward : public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(prop_kind aprop_kind, + algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_forward_desc_init( + &data, + mkldnn::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), + &src_desc.data, + &weights_desc.data, + &bias_desc.data, + &dst_desc.data, + &strides[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution forward descriptor"); + } + desc(prop_kind aprop_kind, + algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_forward_desc_init( + &data, + mkldnn::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), + &src_desc.data, + &weights_desc.data, + nullptr, + &dst_desc.data, + &strides[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution forward descriptor"); + 
} + }; + struct primitive_desc : public handle { + primitive_desc(const desc &adesc, const engine &aengine) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create( + &result, &adesc.data, aengine.get(), nullptr), + "could not create a deconvolution forward primitive descriptor"); + reset(result); + } + + primitive_desc(const desc &adesc, + const primitive_attr &aattr, + const engine &aengine) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create_v2( + &result, &adesc.data, aattr.get(), aengine.get(), nullptr), + "could not create a deconvolution forward primitive descriptor"); + reset(result); + } + + memory::primitive_desc src_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(src_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a src primititve descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc weights_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc bias_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a bias primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc dst_primitive_desc() const { + memory::primitive_desc adesc; + 
mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + deconvolution_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &weights, + const primitive::at &bias, + const memory &dst) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, weights.data, bias.data}; + const_mkldnn_primitive_t outputs[] = {dst.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a deconvolution forward bias primitive"); + reset(result); + } + + deconvolution_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &weights, + const memory &dst) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, weights.data}; + const_mkldnn_primitive_t outputs[] = {dst.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a deconvolution forward primitive"); + reset(result); + } +}; + +struct deconvolution_backward_data : public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_deconvolution_backward_data_desc_init( + &data, + convert_to_c(aalgorithm), + &diff_src_desc.data, + &weights_desc.data, + 
&diff_dst_desc.data, + &strides[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward data descriptor"); + } + }; + struct primitive_desc : public handle { + primitive_desc( + const desc &adesc, + const engine &aengine, + const deconvolution_forward::primitive_desc &hint_fwd_primitive_desc) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create(&result, + &adesc.data, + aengine.get(), + hint_fwd_primitive_desc.get()), + "could not create a deconvolution backward data primitive " + "descriptor"); + reset(result); + } + memory::primitive_desc diff_src_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_src_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff_src primititve descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc weights_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_dst_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff_dst primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + deconvolution_backward_data(const primitive_desc 
&aprimitive_desc, + const primitive::at &diff_dst, + const primitive::at &weights, + const memory &diff_src) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {diff_dst.data, weights.data}; + const_mkldnn_primitive_t outputs[] = {diff_src.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a deconvolution backward data primitive"); + reset(result); + } +}; + +struct deconvolution_backward_weights : public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_deconvolution_backward_weights_desc_init( + &data, + convert_to_c(aalgorithm), + &src_desc.data, + &diff_weights_desc.data, + &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_deconvolution_backward_weights_desc_init( + &data, + convert_to_c(aalgorithm), + &src_desc.data, + &diff_weights_desc.data, + nullptr, + &diff_dst_desc.data, + &strides[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + 
"could not create a deconvolution backward weights descriptor"); + } + }; + + struct primitive_desc : public handle { + primitive_desc( + const desc &adesc, + const engine &aengine, + const deconvolution_forward::primitive_desc &hint_fwd_primitive_desc) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create(&result, + &adesc.data, + aengine.get(), + hint_fwd_primitive_desc.get()), + "could not create a deconvolution backward weights primitive " + "descriptor"); + reset(result); + } + memory::primitive_desc src_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(src_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a src primititve descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_weights_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_weights_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff_weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_bias_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_weights_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff_bias primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_dst_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), 
mkldnn::convert_to_c(diff_dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff_dst primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + deconvolution_backward_weights(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &diff_dst, + const memory &diff_weights, + const memory &diff_bias) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data}; + const_mkldnn_primitive_t outputs[] = {diff_weights.get(), diff_bias.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a deconvolution backward weights primitive"); + reset(result); + } + deconvolution_backward_weights(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &diff_dst, + const memory &diff_weights) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data}; + const_mkldnn_primitive_t outputs[] = {diff_weights.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a deconvolution backward weights primitive"); + reset(result); + } +}; + +/// @} + +/// @addtogroup cpp_api_lrn LRN +/// @{ + +struct lrn_forward : public primitive { + struct desc { + mkldnn_lrn_desc_t data; + desc(prop_kind aprop_kind, + algorithm aalgorithm, + const memory::desc &src_desc, + int local_size, + float alpha, + float beta, + float k) { + error::wrap_c_api( + mkldnn_lrn_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), + &src_desc.data, + local_size, + alpha, + beta, + k), + "could not create a lrn forward descriptor"); + } + desc(prop_kind aprop_kind, + algorithm aalgorithm, + const memory::desc &src_desc, + int local_size, + float alpha, + float beta) { + error::wrap_c_api( + 
mkldnn_lrn_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), + &src_desc.data, + local_size, + alpha, + beta, + float(1.0)), + "could not create a lrn forward descriptor"); + } + }; + + struct primitive_desc : public handle { + primitive_desc(const desc &adesc, const engine &aengine) { + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_primitive_desc_create( + &result, &adesc.data, aengine.get(), nullptr), + "could not create a lrn forward primitive descriptor"); + reset(result); + } + + memory::primitive_desc src_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(src_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a src primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc workspace_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t ldesc; + const_mkldnn_primitive_desc_t const_ldesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(workspace_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc), + "could not clone a workspace primitive descriptor"); + adesc.reset(ldesc); + return adesc; + } + + memory::primitive_desc dst_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + lrn_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const memory &workspace, + const memory &dst) { + mkldnn_primitive_t result; + 
mkldnn_primitive_at_t inputs[] = {src.data}; + const_mkldnn_primitive_t outputs[] = {dst.get(), workspace.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a lrn forward primitive"); + reset(result); + } + + lrn_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const memory &dst) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data}; + const_mkldnn_primitive_t outputs[] = {dst.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a lrn forward primitive"); + reset(result); + } +}; + +struct lrn_backward : public primitive { + struct desc { + mkldnn_lrn_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &data_desc, + const memory::desc &diff_data_desc, + int local_size, + float alpha, + float beta, + float k) { + error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data, + convert_to_c(aalgorithm), + &diff_data_desc.data, + &data_desc.data, + local_size, + alpha, + beta, + k), + "could not create a lrn backward descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &data_desc, + const memory::desc &diff_data_desc, + int local_size, + float alpha, + float beta) { + error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data, + convert_to_c(aalgorithm), + &diff_data_desc.data, + &data_desc.data, + local_size, + alpha, + beta, + float(1.0)), + "could not create a lrn backward descriptor"); + } + }; + + struct primitive_desc : public handle { + primitive_desc(const desc &adesc, + const engine &aengine, + const lrn_forward::primitive_desc &hint_fwd_primitive_desc) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create(&result, + &adesc.data, + aengine.get(), + hint_fwd_primitive_desc.get()), + "could not create a backward lrn primitive descriptor"); + reset(result); + } + + memory::primitive_desc diff_src_primitive_desc() const { + 
memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_src_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff_src primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc workspace_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t ldesc; + const_mkldnn_primitive_desc_t const_ldesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(workspace_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc), + "could not clone a workspace primitive descriptor"); + adesc.reset(ldesc); + return adesc; + } + + memory::primitive_desc diff_dst_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff_dst primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + lrn_backward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &diff_dst, + const primitive::at &workspace, + const memory &diff_src) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data, workspace.data}; + const_mkldnn_primitive_t outputs[] = {diff_src.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a lrn backward primitive"); + reset(result); + } + + lrn_backward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &diff_dst, + const memory &diff_src) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, 
diff_dst.data}; + const_mkldnn_primitive_t outputs[] = {diff_src.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a lrn backward primitive"); + reset(result); + } +}; + +/// @} + +/// @addtogroup cpp_api_pooling Pooling +/// @{ + +struct pooling_forward : public primitive { + struct desc { + mkldnn_pooling_desc_t data; + desc(prop_kind aprop_kind, + algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims kernel, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(kernel); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_pooling_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), + &src_desc.data, + &dst_desc.data, + &strides[0], + &kernel[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not init a forward pooling descriptor"); + } + }; + + struct primitive_desc : public handle { + primitive_desc(const desc &adesc, const engine &aengine) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create( + &result, &adesc.data, aengine.get(), nullptr), + "could not create a forward pooling primitive descriptor"); + reset(result); + } + + memory::primitive_desc workspace_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(workspace_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a workspace primititve descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc dst_primitive_desc() const { + memory::primitive_desc adesc; + 
mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + pooling_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const memory &dst) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data}; + const_mkldnn_primitive_t outputs[] = {dst.get(), nullptr}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a pooling forward primitive"); + reset(result); + } + + pooling_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const memory &dst, + const memory &workspace) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data}; + const_mkldnn_primitive_t outputs[] = {dst.get(), workspace.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a pooling forward primitive"); + reset(result); + } +}; + +struct pooling_backward : public primitive { + struct desc { + mkldnn_pooling_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, + const memory::dims &strides, + const memory::dims &kernel, + const memory::dims &padding_l, + const memory::dims &padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(kernel); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_pooling_backward_desc_init( + &data, + convert_to_c(aalgorithm), + &diff_src_desc.data, + &diff_dst_desc.data, + &strides[0], + &kernel[0], + &padding_l[0], + &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could 
not init a backward pooling descriptor"); + } + }; + + struct primitive_desc : public handle { + primitive_desc( + const desc &adesc, + const engine &aengine, + const pooling_forward::primitive_desc &hint_fwd_primitive_desc) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create(&result, + &adesc.data, + aengine.get(), + hint_fwd_primitive_desc.get()), + "could not create a backward pooling primitive descriptor"); + reset(result); + } + + memory::primitive_desc diff_src_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_src_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff src primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + pooling_backward(const primitive_desc &aprimitive_desc, + const primitive::at &diff_dst, + const memory &diff_src) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {diff_dst.data}; + const_mkldnn_primitive_t outputs[] = {diff_src.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a pooling backward primitive"); + reset(result); + } + + pooling_backward(const primitive_desc &aprimitive_desc, + const primitive::at &diff_dst, + const primitive::at &workspace, + const memory &diff_src) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {diff_dst.data, workspace.data}; + const_mkldnn_primitive_t outputs[] = {diff_src.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a pooling backward primitive"); + reset(result); + } +}; + +/// @} + +/// @addtogroup cpp_api_eltwise Eltwise +/// @{ + +struct eltwise_forward : public primitive { + struct desc { + 
mkldnn_eltwise_desc_t data; + template + desc(prop_kind aprop_kind, + algorithm alg_kind, + const memory::desc &src_desc, + T alpha = 0, + T beta = 0) { + error::wrap_c_api( + mkldnn_eltwise_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + mkldnn::convert_to_c(alg_kind), + &src_desc.data, + static_cast(alpha), + static_cast(beta)), + "could not create a eltwise forward descriptor"); + } + + /** @deprecated: api backward compatibility for relu */ + template + MKLDNN_DEPRECATED desc(prop_kind aprop_kind, + const memory::desc &src_desc, + T negative_slope) + : desc(aprop_kind, eltwise_relu, src_desc, negative_slope) {} + }; + + struct primitive_desc : public handle { + primitive_desc(const desc &adesc, const engine &aengine) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create( + &result, &adesc.data, aengine.get(), nullptr), + "could not create a eltwise forward primitive descriptor"); + reset(result); + } + + memory::primitive_desc dst_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + eltwise_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const memory &dst) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data}; + const_mkldnn_primitive_t outputs[] = {dst.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a eltwise forward primitive"); + reset(result); + } +}; + +typedef eltwise_forward relu_forward; + +struct eltwise_backward : public primitive { + struct desc { + mkldnn_eltwise_desc_t data; + + template + 
desc(algorithm alg_kind, + const memory::desc &diff_data_desc, + const memory::desc &data_desc, + T alpha = 0, + T beta = 0) { + error::wrap_c_api( + mkldnn_eltwise_backward_desc_init(&data, + mkldnn::convert_to_c(alg_kind), + &diff_data_desc.data, + &data_desc.data, + static_cast(alpha), + static_cast(beta)), + "could not create a eltwise backward descriptor"); + } + + /** @deprecated: api backward compatibility for relu */ + template + MKLDNN_DEPRECATED desc(const memory::desc &diff_data_desc, + const memory::desc &data_desc, + T negative_slope) + : desc(eltwise_relu, diff_data_desc, data_desc, negative_slope) {} + }; + + struct primitive_desc : public handle { + primitive_desc( + const desc &adesc, + const engine &aengine, + const eltwise_forward::primitive_desc &hint_fwd_primitive_desc) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create(&result, + &adesc.data, + aengine.get(), + hint_fwd_primitive_desc.get()), + "could not create a eltwise backward primitive descriptor"); + reset(result); + } + + memory::primitive_desc diff_src_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_src_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff src primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + eltwise_backward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &diff_dst, + const memory &diff_src) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data}; + const_mkldnn_primitive_t outputs[] = {diff_src.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a eltwise backward primitive"); + reset(result); + } 
+}; + +typedef eltwise_backward relu_backward; + +/// @} + +/// @addtogroup cpp_api_softmax Softmax +/// @{ + +struct softmax_forward : public primitive { + struct desc { + mkldnn_softmax_desc_t data; + desc(prop_kind aprop_kind, + const memory::desc &data_desc, + int softmax_axis) { + error::wrap_c_api( + mkldnn_softmax_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + &data_desc.data, + softmax_axis), + "could not create a softmax forward descriptor"); + } + }; + + struct primitive_desc : public handle { + primitive_desc(const desc &adesc, const engine &aengine) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create( + &result, &adesc.data, aengine.get(), nullptr), + "could not create a softmax forward primitive descriptor"); + reset(result); + } + + engine get_engine() { return engine::query(*this); } + }; + + softmax_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const memory &dst) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data}; + const_mkldnn_primitive_t outputs[] = {dst.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a softmax forward primitive"); + reset(result); + } +}; + +/// @} + +/// @addtogroup cpp_api_batch_norm Batch normalization +/// @{ + +struct batch_normalization_forward : public primitive { + struct desc { + mkldnn_batch_normalization_desc_t data; + template + desc(prop_kind aprop_kind, + const memory::desc &src_desc, + T epsilon, + unsigned flags) { + error::wrap_c_api( + mkldnn_batch_normalization_forward_desc_init( + &data, + mkldnn::convert_to_c(aprop_kind), + &src_desc.data, + static_cast(epsilon), + flags), + "could not create a batch normalization forward descriptor"); + } + }; + + struct primitive_desc : public handle { + primitive_desc(const desc &adesc, const engine &aengine) { + mkldnn_primitive_desc_t result; + 
error::wrap_c_api(mkldnn_primitive_desc_create( + &result, &adesc.data, aengine.get(), nullptr), + "could not create a batch normalization forward " + "primitive descriptor"); + reset(result); + } + + primitive_desc(const desc &adesc, + const primitive_attr &aattr, + const engine &aengine) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create_v2( + &result, &adesc.data, aattr.get(), aengine.get(), nullptr), + "could not create a batch normalization forward " + "primitive descriptor"); + reset(result); + } + + memory::primitive_desc weights_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t bndesc; + const_mkldnn_primitive_desc_t const_bndesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc), + "could not clone a weights primitive descriptor"); + adesc.reset(bndesc); + return adesc; + } + + memory::primitive_desc mean_primitive_desc() const { + memory::primitive_desc aprimitive_desc; + mkldnn_primitive_desc_t bndesc; + mkldnn_batch_normalization_desc_t *p; + error::wrap_c_api( + mkldnn_primitive_desc_query( + get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p), + "could not get a batch-normalization descriptor"); + const_mkldnn_primitive_desc_t const_bndesc = + (p->flags & use_global_stats) + ? 
mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(src_pd), 1) + : mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(dst_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc), + "could not clone a mean primitive descriptor"); + aprimitive_desc.reset(bndesc); + return aprimitive_desc; + } + + memory::primitive_desc variance_primitive_desc() const { + memory::primitive_desc aprimitive_desc; + mkldnn_primitive_desc_t bndesc; + mkldnn_batch_normalization_desc_t *p; + error::wrap_c_api( + mkldnn_primitive_desc_query( + get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p), + "could not get a batch-normalization descriptor"); + const_mkldnn_primitive_desc_t const_bndesc = + (p->flags & use_global_stats) + ? mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(src_pd), 2) + : mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(dst_pd), 2); + error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc), + "could not clone a variance primitive descriptor"); + aprimitive_desc.reset(bndesc); + return aprimitive_desc; + } + + memory::primitive_desc workspace_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(workspace_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a workspace primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc dst_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + 
+ batch_normalization_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &mean, + const primitive::at &variance, + const primitive::at &weights, + const memory &dst) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = { + src.data, mean.data, variance.data, weights.data}; + const_mkldnn_primitive_t outputs[] = {dst.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a batch normalization forward primitive"); + reset(result); + } + + batch_normalization_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &mean, + const primitive::at &variance, + const memory &dst) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, mean.data, variance.data}; + const_mkldnn_primitive_t outputs[] = {dst.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a batch normalization forward primitive"); + reset(result); + } + + /// @warning batch_normalization_forward has 2 constructors with very + /// similar signatures: + /// - (pd, src, weights, dst, mean, variance) // 2 in, 3 out + /// - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out + /// The only way to distinguish between those is to explicitly + /// cast all input parameters to their type, i.e. to + /// const primitive:at &. 
+ batch_normalization_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &weights, + const memory &dst, + const memory &mean, + const memory &variance) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, weights.data}; + const_mkldnn_primitive_t outputs[] = { + dst.get(), mean.get(), variance.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a batch normalization forward primitive"); + reset(result); + } + + batch_normalization_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &weights, + const memory &dst, + const memory &mean, + const memory &variance, + const memory &workspace) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, weights.data}; + const_mkldnn_primitive_t outputs[] = { + dst.get(), mean.get(), variance.get(), workspace.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a batch normalization forward primitive"); + reset(result); + } + + batch_normalization_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const memory &dst, + const memory &mean, + const memory &variance) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data}; + const_mkldnn_primitive_t outputs[] = { + dst.get(), mean.get(), variance.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a batch normalization forward primitive"); + reset(result); + } + + /// @warning batch_normalization_forward has 2 constructors with very + /// similar signatures: + /// - (pd, src, weights, dst, mean, variance) // 2 in, 3 out + /// - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out + /// The only way to distinguish between those is to explicitly + /// cast all input parameters 
to their type, i.e. to + /// const primitive:at &. + /// @note to make users' experience a little bit better this constructor + /// checks if whether parameters match corresponding primitive + /// descriptor, and if they are not -- call the other (proper) + /// constructor. Yeah, this is still very ugly... + batch_normalization_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const memory &dst, + const memory &mean, + const memory &variance, + const memory &workspace) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[2] = {src.data}; + const_mkldnn_primitive_t outputs[4] = { + dst.get(), mean.get(), variance.get(), workspace.get()}; + + if (1) { // check whether this is the `wrong` constructor + const int n_inputs_expected = mkldnn_primitive_desc_query_s32( + aprimitive_desc.get(), mkldnn_query_num_of_inputs_s32, 0); + const int n_outputs_expected = mkldnn_primitive_desc_query_s32( + aprimitive_desc.get(), mkldnn_query_num_of_outputs_s32, 0); + if (n_inputs_expected == 2 && n_outputs_expected == 3) { + // shift parameters, get rid of workspace, and add weights... 
+ auto _weights = dst; + inputs[1] = {_weights.get(), 0}; + + auto _dst = mean, _mean = variance, _variance = workspace; + outputs[0] = _dst.get(); + outputs[1] = _mean.get(); + outputs[2] = _variance.get(); + outputs[3] = nullptr; + } + } + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a batch normalization forward primitive"); + reset(result); + } + + batch_normalization_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &weights, + const memory &dst) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, weights.data}; + const_mkldnn_primitive_t outputs[] = {dst.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a batch normalization forward primitive"); + reset(result); + } + + batch_normalization_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const memory &dst) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data}; + const_mkldnn_primitive_t outputs[] = {dst.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a batch normalization forward primitive"); + reset(result); + } +}; + +struct batch_normalization_backward : public primitive { + struct desc { + mkldnn_batch_normalization_desc_t data; + template + desc(prop_kind aprop_kind, + const memory::desc &diff_data_desc, + const memory::desc &data_desc, + T epsilon, + unsigned flags) { + error::wrap_c_api( + mkldnn_batch_normalization_backward_desc_init( + &data, + mkldnn::convert_to_c(aprop_kind), + &diff_data_desc.data, + &data_desc.data, + static_cast(epsilon), + flags), + "could not create a batch normalization backward descriptor"); + } + }; + + struct primitive_desc : public handle { + primitive_desc(const desc &adesc, + const engine &aengine, + const 
batch_normalization_forward::primitive_desc + &hint_fwd_primitive_desc) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create(&result, + &adesc.data, + aengine.get(), + hint_fwd_primitive_desc.get()), + "could not create a batch normalization backward primitive " + "descriptor"); + reset(result); + } + + memory::primitive_desc weights_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t bndesc; + const_mkldnn_primitive_desc_t const_bndesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc), + "could not clone a weights primitive descriptor"); + adesc.reset(bndesc); + return adesc; + } + + memory::primitive_desc diff_weights_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t bndesc; + const_mkldnn_primitive_desc_t const_bndesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_weights_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc), + "could not clone a diff_weights primitive descriptor"); + adesc.reset(bndesc); + return adesc; + } + + memory::primitive_desc mean_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t bndesc; + const_mkldnn_primitive_desc_t const_bndesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(src_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc), + "could not clone a mean primitive descriptor"); + adesc.reset(bndesc); + return adesc; + } + + memory::primitive_desc variance_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t bndesc; + const_mkldnn_primitive_desc_t const_bndesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(src_pd), 2); + error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc), + "could not clone a variance primitive descriptor"); + 
adesc.reset(bndesc); + return adesc; + } + + memory::primitive_desc workspace_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(workspace_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a workspace primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc dst_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + // Prop_kind == backward + batch_normalization_backward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &mean, + const primitive::at &variance, + const primitive::at &diff_dst, + const primitive::at &weights, + const memory &diff_src, + const memory &diff_weights) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = { + src.data, mean.data, variance.data, diff_dst.data, weights.data}; + const_mkldnn_primitive_t outputs[] = {diff_src.get(), diff_weights.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a batch normalization backward primitive"); + reset(result); + } + + // Prop_kind == backward (+ws) + batch_normalization_backward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &mean, + const primitive::at &variance, + const primitive::at &diff_dst, + const primitive::at &weights, + const primitive::at &workspace, + const memory &diff_src, + const memory &diff_weights) { + 
mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, + mean.data, + variance.data, + diff_dst.data, + weights.data, + workspace.data}; + const_mkldnn_primitive_t outputs[] = {diff_src.get(), diff_weights.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a batch normalization backward primitive"); + reset(result); + } + + // Prop_kind == backward_data (+ws or +weights) + /// @warning This constructor works for backward_data propagation + /// - w/ weights but w/o workspace, or + /// - w/ workspace but w/o weights + batch_normalization_backward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &mean, + const primitive::at &variance, + const primitive::at &diff_dst, + const primitive::at &weights_or_workspace, + const memory &diff_src) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, + mean.data, + variance.data, + diff_dst.data, + weights_or_workspace.data}; + const_mkldnn_primitive_t outputs[] = {diff_src.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a batch normalization backward primitive"); + reset(result); + } + + // Prop_kind == backward_data + batch_normalization_backward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at &mean, + const primitive::at &variance, + const primitive::at &diff_dst, + const memory &diff_src) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = { + src.data, mean.data, variance.data, diff_dst.data}; + const_mkldnn_primitive_t outputs[] = {diff_src.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a batch normalization backward primitive"); + reset(result); + } +}; + +/// @} + +/// @addtogroup cpp_api_inner_product Inner Product +/// @{ + +struct 
inner_product_forward : public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(prop_kind aprop_kind, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc) { + error::wrap_c_api(mkldnn_inner_product_forward_desc_init( + &data, + mkldnn::convert_to_c(aprop_kind), + &src_desc.data, + &weights_desc.data, + &bias_desc.data, + &dst_desc.data), + "could not create a inner product forward descriptor"); + } + + desc(prop_kind aprop_kind, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc) { + error::wrap_c_api(mkldnn_inner_product_forward_desc_init( + &data, + mkldnn::convert_to_c(aprop_kind), + &src_desc.data, + &weights_desc.data, + nullptr, + &dst_desc.data), + "could not create a inner product forward descriptor"); + } + }; + + struct primitive_desc : public handle { + primitive_desc(const desc &adesc, const engine &aengine) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create( + &result, &adesc.data, aengine.get(), nullptr), + "could not create a inner product forward primitive descriptor"); + reset(result); + } + + primitive_desc(const desc &adesc, + const primitive_attr &aattr, + const engine &aengine) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create_v2( + &result, &adesc.data, aattr.get(), aengine.get(), nullptr), + "could not create a inner product " + "forward primitive descriptor"); + reset(result); + } + + memory::primitive_desc src_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(src_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a src primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc 
weights_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc bias_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a bias primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc dst_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + inner_product_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at weights, + const primitive::at &bias, + const memory &dst) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, weights.data, bias.data}; + const_mkldnn_primitive_t outputs[] = {dst.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a inner product forward primitive"); + reset(result); + } + + inner_product_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at weights, + const memory &dst) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, 
weights.data}; + const_mkldnn_primitive_t outputs[] = {dst.get()}; + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a inner product forward primitive"); + reset(result); + } +}; + +struct inner_product_backward_data : public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_data_desc_init(&data, + &diff_src_desc.data, + &weights_desc.data, + &diff_dst_desc.data), + "could not create a inner product backward data descriptor"); + } + }; + + struct primitive_desc : public handle { + primitive_desc( + const desc &adesc, + const engine &aengine, + const inner_product_forward::primitive_desc &hint_fwd_primitive_desc) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create(&result, + &adesc.data, + aengine.get(), + hint_fwd_primitive_desc.get()), + "could not create a inner product backward data primitive " + "descriptor"); + reset(result); + } + + memory::primitive_desc diff_dst_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff dst primititve descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc weights_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc 
diff_src_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_src_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff src primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + inner_product_backward_data(const primitive_desc &aprimitive_desc, + const primitive::at &diff_dst, + const primitive::at weights, + const memory &diff_src) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {diff_dst.data, weights.data}; + const_mkldnn_primitive_t outputs[] = {diff_src.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a inner product backward data primitive"); + reset(result); + } +}; + +struct inner_product_backward_weights : public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_weights_desc_init( + &data, + &src_desc.data, + &diff_weights_desc.data, + &diff_bias_desc.data, + &diff_dst_desc.data), + "could not create a inner product backward weights descriptor"); + } + desc(const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_weights_desc_init( + &data, + &src_desc.data, + &diff_weights_desc.data, + nullptr, + &diff_dst_desc.data), + "could not create a inner product backward weights descriptor"); + } + }; + + struct primitive_desc : public handle { + primitive_desc( + const desc &adesc, + const engine &aengine, + const inner_product_forward::primitive_desc 
&hint_fwd_primitive_desc) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create(&result, + &adesc.data, + aengine.get(), + hint_fwd_primitive_desc.get()), + "could not create a inner product backward weights primitive " + "descriptor"); + reset(result); + } + + memory::primitive_desc diff_dst_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff dst primititve descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_weights_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_weights_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_bias_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_weights_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a diff bias primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc src_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(src_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a src primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() 
{ return engine::query(*this); } + }; + + inner_product_backward_weights(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at diff_dst, + const memory &diff_weights) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data}; + const_mkldnn_primitive_t outputs[] = {diff_weights.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a inner product backward weights primitive"); + reset(result); + } + + inner_product_backward_weights(const primitive_desc &aprimitive_desc, + const primitive::at &src, + const primitive::at diff_dst, + const memory &diff_weights, + const memory &diff_bias) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data}; + const_mkldnn_primitive_t outputs[] = {diff_weights.get(), diff_bias.get()}; + error::wrap_c_api( + mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create a inner product backward weights primitive"); + reset(result); + } +}; + +/// @} + +/// @addtogroup cpp_api_rnn RNN +/// @{ + +struct rnn_cell { + struct desc { + mkldnn_rnn_cell_desc_t c_rnn_cell_; + + desc(algorithm kind, algorithm activation_f) { + error::wrap_c_api( + mkldnn_rnn_cell_desc_init(&c_rnn_cell_, + mkldnn::convert_to_c(kind), + mkldnn::convert_to_c(activation_f), + 0U, + 0, + 0), + "could not init an rnn cell descriptor"); + } + desc(algorithm kind) : desc(kind, algorithm::algorithm_undef) {} + + operator const mkldnn_rnn_cell_desc_t *() const { return &c_rnn_cell_; } + + algorithm get_cell_kind() const { return algorithm(c_rnn_cell_.cell_kind); } + algorithm get_activation() const { + return algorithm(c_rnn_cell_.activation_kind); + } + + float get_alpha() const { return c_rnn_cell_.alpha; } + void set_alpha(float alpha) { + c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu; + c_rnn_cell_.alpha = alpha; + } + + float get_clipping() const { return 
c_rnn_cell_.clipping; } + void set_clipping(float clipping) { + c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping; + c_rnn_cell_.clipping = clipping; + } + + int get_gates_count() const { + return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_); + } + int get_state_count() const { + return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_); + } + }; +}; + +struct rnn_forward : public primitive { + struct desc { + mkldnn_rnn_desc_t data; + desc(prop_kind aprop_kind, + rnn_cell::desc cell, + const rnn_direction direction, + const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc) { + error::wrap_c_api( + mkldnn_rnn_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + cell, + mkldnn::convert_to_c(direction), + &src_layer_desc.data, + &src_iter_desc.data, + &weights_layer_desc.data, + &weights_iter_desc.data, + &bias_desc.data, + &dst_layer_desc.data, + &dst_iter_desc.data), + "could not create an RNN forward descriptor"); + } + }; + struct primitive_desc : public handle { + primitive_desc(const desc &adesc, const engine &aengine) { + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_primitive_desc_create( + &result, &adesc.data, aengine.get(), nullptr), + "could not create an RNN forward primitive descriptor"); + reset(result); + } + + memory::primitive_desc src_layer_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(src_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone an src layer primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc src_iter_primitive_desc() const { + memory::primitive_desc adesc; + 
mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(src_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a src iter primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc weights_layer_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc weights_src_iter_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc bias_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 2); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a bias primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc workspace_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t ldesc; + const_mkldnn_primitive_desc_t const_ldesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(workspace_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc), + "could not clone a workspace primitive descriptor"); + adesc.reset(ldesc); + return adesc; 
+ } + + memory::primitive_desc dst_layer_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(dst_pd), 0); + error::wrap_c_api( + mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst last layer primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc dst_iter_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(dst_pd), 1); + error::wrap_c_api( + mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst last iteration primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + rnn_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src_layer, + const primitive::at &src_iter, + const primitive::at &weights_layer, + const primitive::at &weights_iter, + const primitive::at &bias, + const memory &dst_layer, + const memory &dst_iter, + const memory &workspace) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[5]; + const_mkldnn_primitive_t outputs[3]; + int idx = 0; + inputs[idx++] = src_layer.data; + if (!is_null_memory(src_iter.data.primitive)) inputs[idx++] = src_iter.data; + inputs[idx++] = weights_layer.data; + inputs[idx++] = weights_iter.data; + if (!is_null_memory(bias.data.primitive)) inputs[idx++] = bias.data; + + idx = 0; + outputs[idx++] = dst_layer.get(); + if (!is_null_memory(dst_iter.get())) outputs[idx++] = dst_iter.get(); + if (!is_null_memory(workspace.get())) outputs[idx++] = workspace.get(); + + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create an RNN forward primitive"); + reset(result); + } +}; + +struct rnn_backward : 
public primitive { + struct desc { + mkldnn_rnn_desc_t data; + desc(prop_kind aprop_kind, + rnn_cell::desc cell, + const rnn_direction direction, + const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &diff_src_layer_desc, + const memory::desc &diff_src_iter_desc, + const memory::desc &diff_weights_layer_desc, + const memory::desc &diff_weights_iter_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_layer_desc, + const memory::desc &diff_dst_iter_desc) { + error::wrap_c_api( + mkldnn_rnn_backward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + cell, + mkldnn::convert_to_c(direction), + &src_layer_desc.data, + &src_iter_desc.data, + &weights_layer_desc.data, + &weights_iter_desc.data, + &bias_desc.data, + &dst_layer_desc.data, + &dst_iter_desc.data, + &diff_src_layer_desc.data, + &diff_src_iter_desc.data, + &diff_weights_layer_desc.data, + &diff_weights_iter_desc.data, + &diff_bias_desc.data, + &diff_dst_layer_desc.data, + &diff_dst_iter_desc.data), + "could not create an RNN backward descriptor"); + } + }; + struct primitive_desc : public handle { + primitive_desc(const desc &adesc, const engine &aengine) { + mkldnn_primitive_desc_t result; + error::wrap_c_api( + mkldnn_primitive_desc_create( + &result, &adesc.data, aengine.get(), nullptr), + "could not create an RNN backward primitive descriptor"); + reset(result); + } + + memory::primitive_desc src_layer_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(src_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone an src layer primitive descriptor"); + 
adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc src_iter_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(src_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a src iter primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc weights_layer_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc weights_iter_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc bias_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(weights_pd), 2); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a bias primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc dst_layer_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(dst_pd), 0); + error::wrap_c_api( + 
mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst last layer primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc dst_iter_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(dst_pd), 1); + error::wrap_c_api( + mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst last iteration primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_src_layer_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_src_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone an src_layer primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_src_iter_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_src_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a src iter primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_weights_layer_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_weights_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_weights_iter_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t 
cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_weights_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_bias_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_weights_pd), 2); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a bias primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_dst_layer_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_dst_pd), 0); + error::wrap_c_api( + mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst last layer primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_dst_iter_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(diff_dst_pd), 1); + error::wrap_c_api( + mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst last iteration primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc workspace_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t ldesc; + const_mkldnn_primitive_desc_t const_ldesc = + mkldnn_primitive_desc_query_pd( + get(), mkldnn::convert_to_c(workspace_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc), + "could not clone a workspace primitive descriptor"); + 
adesc.reset(ldesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + // With last iteration (with and without input src_iter). Each argument takes the next free slot of inputs/outputs; arguments for which is_null_memory() holds are skipped. + rnn_backward(const primitive_desc &aprimitive_desc, + const primitive::at &src_layer, + const primitive::at &src_iter, + const primitive::at &weights_layer, + const primitive::at &weights_iter, + const primitive::at &bias, + const primitive::at &dst_layer, + const primitive::at &dst_iter, + const memory &diff_src_layer, + const memory &diff_src_iter, + const memory &diff_weights_layer, + const memory &diff_weights_iter, + const memory &diff_bias, + const primitive::at &diff_dst_layer, + const primitive::at &diff_dst_iter, + const primitive::at &workspace) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[10]; + const_mkldnn_primitive_t outputs[5]; + int idx = 0; + inputs[idx++] = src_layer.data; + if (!is_null_memory(src_iter.data.primitive)) inputs[idx++] = src_iter.data; + inputs[idx++] = weights_layer.data; + inputs[idx++] = weights_iter.data; + if (!is_null_memory(bias.data.primitive)) inputs[idx++] = bias.data; + inputs[idx++] = dst_layer.data; + if (!is_null_memory(dst_iter.data.primitive)) inputs[idx++] = dst_iter.data; + inputs[idx++] = diff_dst_layer.data; + if (!is_null_memory(diff_dst_iter.data.primitive)) + inputs[idx++] = diff_dst_iter.data; + inputs[idx++] = workspace.data; + + idx = 0; + outputs[idx++] = diff_src_layer.get(); + if (!is_null_memory(diff_src_iter.get())) + outputs[idx++] = diff_src_iter.get(); + outputs[idx++] = diff_weights_layer.get(); + outputs[idx++] = diff_weights_iter.get(); + if (!is_null_memory(diff_bias.get())) outputs[idx++] = diff_bias.get(); + error::wrap_c_api(mkldnn_primitive_create( + &result, aprimitive_desc.get(), inputs, outputs), + "could not create an RNN backward primitive"); + reset(result); + } +}; + +/// @} +/// @} Primitives + +/// @addtogroup cpp_api_stream Stream +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> +struct handle_traits { + static 
constexpr auto destructor = &mkldnn_stream_destroy; +}; +#endif + +struct stream : public handle { + using handle::handle; + + enum kind { + any = mkldnn_stream_kind_t::mkldnn_any_stream, + eager = mkldnn_stream_kind_t::mkldnn_eager, + lazy = mkldnn_stream_kind_t::mkldnn_lazy + }; + + static mkldnn_stream_kind_t convert_to_c(kind akind) { + return static_cast(akind); + } + /// Constructs a stream. + stream(kind akind) { + mkldnn_stream_t astream; + error::wrap_c_api(mkldnn_stream_create(&astream, convert_to_c(akind)), + "could not create a stream"); + reset(astream); + } + + /// Submits a vector of primitives to a stream for computations. + /// + /// @param primitives The vector of primitives to submit. + /// @returns The stream. + stream &submit(std::vector primitives) { + // TODO: find a proper way to convert vector to + // vector + if (primitives.size() == 0) return *this; + std::vector c_api_primitives; + c_api_primitives.reserve(primitives.size()); + auto convert_to_c = [](primitive p) { return p.get(); }; + std::transform(primitives.begin(), + primitives.end(), + std::back_inserter(c_api_primitives), + convert_to_c); + + mkldnn_primitive_t c_api_error_primitive; + error::wrap_c_api(mkldnn_stream_submit(get(), + c_api_primitives.size(), + &c_api_primitives[0], + &c_api_error_primitive), + "could not submit primitives to a stream", + &c_api_error_primitive); + + return *this; + } + + /// Waits for all computations submitted to the stream to complete. + /// + /// @param block Specifies whether the operation should wait indefinitely or + /// return + /// immediately. + /// @returns @c true if all computations completed. + /// @returns @c false if not all computations completed. 
+ bool wait(bool block = true) { + mkldnn_primitive_t c_api_error_primitive; + mkldnn_status_t status = + mkldnn_stream_wait(get(), block, &c_api_error_primitive); + if (status != mkldnn_success && status != mkldnn_try_again) + error::wrap_c_api( + status, "could not wait on a stream", &c_api_error_primitive); + return (status == mkldnn_success); + } + + stream &rerun() { + mkldnn_primitive_t c_api_error_primitive; + error::wrap_c_api(mkldnn_stream_rerun(get(), &c_api_error_primitive), + "could not rerun a stream", + &c_api_error_primitive); + return *this; + } +}; + +/// @} + +/// @} C++ API + +} // namespace mkldnn + +#endif diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 7af6ed1463ab737e871da487f2a687301652ef2d..32b1b65bd97ef1e512a5880843509611b606f52d 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -480,6 +480,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, program.current_block_idx = current_block_idx program.sync_with_cpp() + # FIXME(zcd): prevent loss.grad optimized by mem_opt. 
+ loss.block.var(_append_grad_suffix_(loss.name)).persistable = True if parameter_list is not None: parameters = parameter_list diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 28e54f5492e7b04a1406e319cecf977d4a55725e..38c765938fe9d7b2103bfdd926874c485d0ff4dc 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -489,7 +489,7 @@ class Operator(object): 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv', 'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine', 'ncclInit', 'channel_create', 'channel_close', - 'channel_send', 'channel_recv', 'select' + 'channel_send', 'channel_recv', 'select', 'gen_nccl_id' } if type not in no_kernel_op_set: self.desc.infer_var_type(self.block.desc) diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 5b43f860e7075745bbf6e76c2f9d0e9a87a86db0..7358c4b60e87893b9c04e3da2221dfb69d3ba0c7 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -31,7 +31,9 @@ class ParallelExecutor(object): allow_op_delay=False, share_vars_from=None, use_default_grad_scale=True, - balance_parameter_opt_between_cards=False): + balance_parameter_opt_between_cards=False, + num_trainers=1, + trainer_id=0): """ ParallelExecutor can run program in parallel. @@ -55,6 +57,11 @@ class ParallelExecutor(object): balance_parameter_opt_between_cards(bool, default True): Whether updating different gradients on different cards. Currently, it is not recommended. + num_trainers(int, default 1): If greater than 1, NCCL will be + initialized with multiple ranks of nodes; each node should have the + same number of GPUs. Distributed training will be enabled then. + trainer_id(int, default 0): Must be used together with num_trainers. + trainer_id is the "rank" of the current node, starting from 0. Returns: A ParallelExecutor object. 
@@ -134,8 +141,9 @@ class ParallelExecutor(object): local_scopes, allow_op_delay, use_default_grad_scale, - balance_parameter_opt_between_cards) - + balance_parameter_opt_between_cards, + num_trainers, + trainer_id) self.scope = scope def run(self, fetch_list, feed=None, feed_dict=None): diff --git a/python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt b/python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt index 9ab00325a2eef3bbc79757ad1a3e6f8511c49552..c2a15bdb3b17b65fe861dd429f548074c13e2f09 100644 --- a/python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt +++ b/python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt @@ -6,4 +6,5 @@ foreach(src ${TEST_OPS}) py_test(${src} SRCS ${src}.py) endforeach() +add_subdirectory(fit_a_line) add_subdirectory(recognize_digits) diff --git a/python/paddle/fluid/tests/book/high-level-api/fit_a_line/CMakeLists.txt b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..673c965b662a022739f8d489c331f4de9455a926 --- /dev/null +++ b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/CMakeLists.txt @@ -0,0 +1,7 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +# default test +foreach(src ${TEST_OPS}) + py_test(${src} SRCS ${src}.py) +endforeach() diff --git a/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py new file mode 100644 index 0000000000000000000000000000000000000000..8c9bbb52d769282460c571ebc51d5eff18de3114 --- /dev/null +++ b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py @@ -0,0 +1,137 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid as fluid +import contextlib +import numpy +import unittest + +# train reader +BATCH_SIZE = 20 + +train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.uci_housing.train(), buf_size=500), + batch_size=BATCH_SIZE) + +test_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.uci_housing.test(), buf_size=500), + batch_size=BATCH_SIZE) + + +def inference_program(): + x = fluid.layers.data(name='x', shape=[13], dtype='float32') + y_predict = fluid.layers.fc(input=x, size=1, act=None) + return y_predict + + +def linear(): + y = fluid.layers.data(name='y', shape=[1], dtype='float32') + y_predict = inference_program() + + loss = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_loss = fluid.layers.mean(loss) + + return avg_loss + + +def train(use_cuda, save_dirname): + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + + trainer = fluid.Trainer( + train_func=linear, + infer_func=inference_program, + place=place, + optimizer=fluid.optimizer.SGD(learning_rate=0.001)) + + def event_handler(event): + if isinstance(event, fluid.EndEpochEvent): + test_metrics = trainer.test( + reader=test_reader, feed_order=['x', 'y']) + print test_metrics + ''' + + ... + ['25.768919467926025'] + ['15.343549569447836'] + ... 
+ + ''' + if float(test_metrics[0]) < 20.0: + if save_dirname is not None: + # NOT clear yet + # fluid.io.save_inference_model(save_dirname, ['x'], [y_predict]) + # trainer.save_params(save_dirname) + # https://github.com/PaddlePaddle/Paddle/pull/10445 + trainer.save_inference_model(save_dirname) + return + + trainer.train( + reader=train_reader, + num_epochs=100, + event_handler=event_handler, + feed_order=['x', 'y']) + + +# infer +def infer(use_cuda, save_dirname=None): + if save_dirname is None: + return + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + inferencer = fluid.Inferencer(param_path=save_dirname, place=place) + + batch_size = 10 + tensor_x = numpy.random.uniform(0, 10, [batch_size, 13]).astype("float32") + + results = inferencer.infer({'x': tensor_x}) + print("infer results: ", results[0]) + + +def main(use_cuda): + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return + + # Directory for saving the trained model + save_dirname = "fit_a_line.inference.model" + + train(use_cuda, save_dirname) + infer(use_cuda, save_dirname) + + +class TestFitALine(unittest.TestCase): + def test_cpu(self): + with self.program_scope_guard(): + with fluid.unique_name.guard(): + main(use_cuda=False) + + def test_cuda(self): + with self.program_scope_guard(): + with fluid.unique_name.guard(): + main(use_cuda=True) + + @contextlib.contextmanager + def program_scope_guard(self): + prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(prog, startup_prog): + yield + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/book/notest_understand_sentiment.py b/python/paddle/fluid/tests/book/notest_understand_sentiment.py index 241778e303036d068dc0a40e4574a02eb97ad134..792ed7368d646cd9dff9255eb402b6a9b84f69a6 100644 --- a/python/paddle/fluid/tests/book/notest_understand_sentiment.py +++ 
b/python/paddle/fluid/tests/book/notest_understand_sentiment.py @@ -170,7 +170,7 @@ def train(word_dict, assert save_dirname is None adagrad = fluid.optimizer.Adagrad(learning_rate=0.002) - optimize_ops, params_grads = adagrad.minimize(cost) + adagrad.minimize(cost) train_data = paddle.batch( paddle.reader.shuffle( diff --git a/python/paddle/fluid/tests/book/test_fit_a_line.py b/python/paddle/fluid/tests/book/test_fit_a_line.py index ecb34699af0dc14782601702ab8afedbca7e1bfd..b1a6b524d33cae97c8982ffb8f780b1b07761a09 100644 --- a/python/paddle/fluid/tests/book/test_fit_a_line.py +++ b/python/paddle/fluid/tests/book/test_fit_a_line.py @@ -33,7 +33,7 @@ def train(use_cuda, save_dirname, is_local): avg_cost = fluid.layers.mean(cost) sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) - optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost) + sgd_optimizer.minimize(avg_cost) BATCH_SIZE = 20 diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py index dbcdb5766e7d20efdb12da0ea4c6f005d903849b..0f3a4c9242a81a3c1fb90268245715a8e59a207a 100644 --- a/python/paddle/fluid/tests/book/test_image_classification.py +++ b/python/paddle/fluid/tests/book/test_image_classification.py @@ -125,7 +125,7 @@ def train(net_type, use_cuda, save_dirname, is_local): test_program = fluid.default_main_program().clone(for_test=True) optimizer = fluid.optimizer.Adam(learning_rate=0.001) - optimize_ops, params_grads = optimizer.minimize(avg_cost) + optimizer.minimize(avg_cost) BATCH_SIZE = 128 PASS_NUM = 1 diff --git a/python/paddle/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/fluid/tests/book/test_label_semantic_roles.py index 0faba33032d5dfc0b751a5191e7b2ae0c1f172bf..09793760e5504c04ad4b0bfac5c5d7b7047cf85d 100644 --- a/python/paddle/fluid/tests/book/test_label_semantic_roles.py +++ b/python/paddle/fluid/tests/book/test_label_semantic_roles.py @@ -175,7 +175,7 @@ def train(use_cuda, 
save_dirname=None, is_local=True): decay_steps=100000, decay_rate=0.5, staircase=True)) - optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost) + sgd_optimizer.minimize(avg_cost) # TODO(qiao) # add dependency track and move this config before optimizer diff --git a/python/paddle/fluid/tests/book/test_machine_translation.py b/python/paddle/fluid/tests/book/test_machine_translation.py index 46c6b9c29a265741a99655d5ac29244798f6fec2..e8a75f473f62df528b7f39bf5f9085076e005c25 100644 --- a/python/paddle/fluid/tests/book/test_machine_translation.py +++ b/python/paddle/fluid/tests/book/test_machine_translation.py @@ -185,7 +185,7 @@ def train_main(use_cuda, is_sparse, is_local=True): learning_rate=1e-4, regularization=fluid.regularizer.L2DecayRegularizer( regularization_coeff=0.1)) - optimize_ops, params_grads = optimizer.minimize(avg_cost) + optimizer.minimize(avg_cost) train_data = paddle.batch( paddle.reader.shuffle( diff --git a/python/paddle/fluid/tests/book/test_recognize_digits.py b/python/paddle/fluid/tests/book/test_recognize_digits.py index c115aa4d7d6b514f9207543730e5e76cb0d2040c..578b1162fbd7e3a1b1c0cc934406818f2e07e019 100644 --- a/python/paddle/fluid/tests/book/test_recognize_digits.py +++ b/python/paddle/fluid/tests/book/test_recognize_digits.py @@ -95,7 +95,7 @@ def train(nn_type, test_program = fluid.default_main_program().clone(for_test=True) optimizer = fluid.optimizer.Adam(learning_rate=0.001) - optimize_ops, params_grads = optimizer.minimize(avg_loss) + optimizer.minimize(avg_loss) place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() diff --git a/python/paddle/fluid/tests/book/test_recommender_system.py b/python/paddle/fluid/tests/book/test_recommender_system.py index d022dedbff805d597b68b5a47f7931f2dd946615..7be924f762ddeb045dda890dbfdcd96a65449553 100644 --- a/python/paddle/fluid/tests/book/test_recommender_system.py +++ b/python/paddle/fluid/tests/book/test_recommender_system.py @@ -160,7 +160,7 @@ def train(use_cuda, save_dirname, 
is_local=True): test_program = fluid.default_main_program().clone(for_test=True) sgd_optimizer = SGDOptimizer(learning_rate=0.2) - optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost) + sgd_optimizer.minimize(avg_cost) place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() diff --git a/python/paddle/fluid/tests/book/test_word2vec.py b/python/paddle/fluid/tests/book/test_word2vec.py index 6dec0f6857e86b4b9c1c67af934aa9bfdb1c3df7..30e1a5040cc92b02bbbf90dac97001812ec90134 100644 --- a/python/paddle/fluid/tests/book/test_word2vec.py +++ b/python/paddle/fluid/tests/book/test_word2vec.py @@ -101,7 +101,7 @@ def train(use_cuda, is_sparse, is_parallel, save_dirname, is_local=True): avg_cost = fluid.layers.mean(pd()) sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) - optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost) + sgd_optimizer.minimize(avg_cost) train_reader = paddle.batch( paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE) diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py new file mode 100644 index 0000000000000000000000000000000000000000..10f8c4f3f0167632bb4a3d454ab026ba73a8f305 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -0,0 +1,113 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle.fluid.layers as layers +from paddle.fluid.transpiler.distribute_transpiler import delete_ops +import numpy + + +class TestDistTranspiler(unittest.TestCase): + def setUp(self): + self.trainer_id = 0 + self.trainers = 2 + self.pservers = 2 + self.pserver_eps = "127.0.0.1:6174,127.0.0.1:6175" + self.current_pserver_ep = "127.0.0.1:6174" + + def net_conf(self): + x = fluid.layers.data(name='x', shape=[1000], dtype='float32') + + y_predict = fluid.layers.fc(input=x, + size=1000, + act=None, + param_attr=fluid.ParamAttr(name='fc_w')) + + y = fluid.layers.data(name='y', shape=[1], dtype='float32') + + cost = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_cost = fluid.layers.mean(cost) + sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1) + + optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost) + return optimize_ops, params_grads + + def test_transpiler(self): + trainer = self.get_trainer() + pserver, startup = self.get_pserver(self.current_pserver_ep) + + self.assertEqual([op.type for op in trainer.global_block().ops], + self.get_expect_trainer_ops()) + + self.assertEqual(len(pserver.blocks), 3) + # block0: listen_and_serv + self.assertEqual([op.type for op in pserver.blocks[0].ops], + ["listen_and_serv"]) + # block2: optimize pass + self.assertEqual([op.type for op in pserver.blocks[1].ops], + ["sum", "scale", "sgd"]) + + # confirm startup program + + self.assertEqual([op.type for op in startup.global_block().ops], [ + "fill_constant", "fill_constant", "uniform_random", "uniform_random" + ]) + + # the variable #fc_w will be split into two blocks + fc_w_var = startup.global_block().var("fc_w.block1") + self.assertEqual(fc_w_var.shape, (500, 1000)) + + def get_main_program(self): + main = fluid.Program() + + with fluid.program_guard(main): + self.net_conf() + + return main + + def get_expect_trainer_ops(self): + trainer = fluid.Program() + + with 
fluid.program_guard(trainer): + optimize_ops, params_grads = self.net_conf() + + delete_ops(trainer.global_block(), optimize_ops) + return [op.type for op in trainer.global_block().ops + ] + ["split_byref", "send", "concat"] + + def get_trainer(self): + return self._transpiler_instance().get_trainer_program() + + def get_pserver(self, ep): + t = self._transpiler_instance() + pserver = t.get_pserver_program(ep) + startup = t.get_startup_program(ep, pserver) + return pserver, startup + + def _transpiler_instance(self): + main = self.get_main_program() + t = fluid.DistributeTranspiler() + t.transpile( + self.trainer_id, + program=main, + pservers=self.pserver_eps, + trainers=self.trainers) + return t + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_network_with_dtype.py b/python/paddle/fluid/tests/unittests/test_network_with_dtype.py index af487919a9986f0c45651e8825b8cc38231c1904..fe8aceb3ae42f73590bffe2a372c771654a372a9 100644 --- a/python/paddle/fluid/tests/unittests/test_network_with_dtype.py +++ b/python/paddle/fluid/tests/unittests/test_network_with_dtype.py @@ -27,12 +27,15 @@ class TestNetWithDtype(unittest.TestCase): def set_network(self): self.dtype = "float64" self.init_dtype() - self.x = fluid.layers.data(name='x', shape=[13], dtype=self.dtype) - self.y = fluid.layers.data(name='y', shape=[1], dtype=self.dtype) - y_predict = fluid.layers.fc(input=self.x, size=1, act=None) + main = fluid.Program() + with fluid.program_guard(main): + self.x = fluid.layers.data(name='x', shape=[13], dtype=self.dtype) + self.y = fluid.layers.data(name='y', shape=[1], dtype=self.dtype) + y_predict = fluid.layers.fc(input=self.x, size=1, act=None) - cost = fluid.layers.square_error_cost(input=y_predict, label=self.y) - avg_cost = fluid.layers.mean(cost) + cost = fluid.layers.square_error_cost(input=y_predict, label=self.y) + avg_cost = fluid.layers.mean(cost) + self.program = main self.fetch_list = [avg_cost] sgd_optimizer = 
fluid.optimizer.SGD(learning_rate=0.001) @@ -45,7 +48,7 @@ class TestNetWithDtype(unittest.TestCase): exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) for data in train_reader(): - exe.run(fluid.default_main_program(), + exe.run(self.program, feed=feeder.feed(data), fetch_list=self.fetch_list) # the main program is runable, the datatype is fully supported @@ -68,7 +71,7 @@ class TestNetWithDtype(unittest.TestCase): # TODO(dzhwinter): make sure the fp16 is runable -# class TestFloat16(SimpleNet): +# class TestFloat16(TestNetWithDtype): # def init_dtype(self): # self.dtype = "float16" diff --git a/python/paddle/fluid/transpiler/__init__.py b/python/paddle/fluid/transpiler/__init__.py index 6d3c1b947f4acb1335b25e6eb0099d5d532c895a..413c36c5c41bbe0169f1c050ccdac040202d66df 100644 --- a/python/paddle/fluid/transpiler/__init__.py +++ b/python/paddle/fluid/transpiler/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from distribute_transpiler import DistributeTranspiler from inference_transpiler import InferenceTranspiler from memory_optimization_transpiler import memory_optimize, release_memory diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index b45cb987d896bd189531e97eb62bddbbee16069d..a323f8d03613e7c4149812daab4ccb57fb940a7e 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -17,7 +17,7 @@ from __future__ import print_function import math import distributed_splitter as splitter -from .. import core +from .. 
import core, framework from ..framework import Program, default_main_program, \ default_startup_program, \ Variable, Parameter, grad_var_name @@ -417,7 +417,7 @@ class DistributeTranspiler: def __append_optimize_op__(op, block, grad_to_block_id): if self._is_opt_op(op): self._append_pserver_ops(block, op, endpoint, grad_to_block_id, - default_main_program()) + self.origin_program) else: self._append_pserver_non_opt_ops(block, op) diff --git a/tools/manylinux1/README.md b/tools/manylinux1/README.md index 898e00bd37c7b7bcbcb4a56476ff10c87381e47a..0e5905040175047f5b79939d97a3efcf38992944 100644 --- a/tools/manylinux1/README.md +++ b/tools/manylinux1/README.md @@ -28,3 +28,38 @@ git clone https://github.com/paddlepaddle/paddle cd paddle/tools/manylinux1 REPO=[yourrepo] ./build_all.sh ``` + +## Build PaddlePaddle for the different Python ABIs + +Choose one of the following Python ABIs and set the correct environment variables. + +- cp27-cp27m + + ```bash + export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs2/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs4/lib:} + export PATH=/opt/python/cp27-cp27m/bin/:${PATH} + export PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27m/bin/python + -DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27m/include/python2.7 + -DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs2/lib/libpython2.7.so" + ``` + +- cp27-cp27mu + + ```bash + export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs4/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs2/lib:} + export PATH=/opt/python/cp27-cp27mu/bin/:${PATH} + export PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27mu/bin/python + -DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27mu/include/python2.7 + -DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs4/lib/libpython2.7.so" + ``` + +And then add the `PYTHON_FLAGS` as your cmake flags: + +```bash +cmake .. \ + ${PYTHON_FLAGS} \ + -DWITH_GPU=OFF \ + ... 
+``` + +You can find more details about cmake flags [here](http://www.paddlepaddle.org/docs/develop/documentation/fluid/en/build_and_install/build_from_source_en.html#appendix-build-options)