// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /******************************************************************************* * Copyright 2016-2018 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ #ifndef MKLDNN_HPP #define MKLDNN_HPP #ifndef DOXYGEN_SHOULD_SKIP_THIS #include #include #include #include #include #include #include "mkldnn.h" #endif namespace mkldnn { /// @addtogroup cpp_api C++ API /// @{ /// @addtogroup cpp_api_utils Utils /// @{ /// A class that provides the destructor for an Intel(R) MKL-DNN C handle template class handle_traits {}; /// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base /// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and /// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class /// can be passed by value. This class enables wrapping: /// - Newly constructed handles. /// @n In this case, the constructed handle uses reference counting provided /// by @p std::shared_ptr with a proper deleter function specified through /// the @p handle_traits class. /// - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for /// example, through #mkldnn_primitive_get_output()). /// @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a /// deleter because it is assumed that the handle wrapper for the original /// object deletes the handle (this model is similar to @p std::weak_ptr). template > class handle { private: std::shared_ptr::type> _data; handle(const handle &&) = delete; handle &operator=(const handle &&other) = delete; protected: /// Constructs a C handle wrapper. /// @param t The C handle to wrap. /// @param weak A flag to specify whether to construct a weak wrapper. handle(T t = 0, bool weak = false) : _data(0) { reset(t, weak); } bool operator==(const T other) const { return other == _data.get(); } bool operator!=(const T other) const { return !(*this == other); } public: handle(const handle &other) : _data(other._data) {} handle &operator=(const handle &other) { _data = other._data; return *this; } /// Resets the value of a C handle. /// @param t The new value of the C handle. /// @param weak A flag to specify whether the wrapper should be weak. void reset(T t, bool weak = false) { auto dummy_destructor = [](T) { return decltype(traits::destructor(0))(0); }; _data.reset(t, weak ? dummy_destructor : traits::destructor); } /// Returns the value of the underlying C handle. T get() const { return _data.get(); } bool operator==(const handle &other) const { return other._data.get() == _data.get(); } bool operator!=(const handle &other) const { return !(*this == other); } }; #ifndef DOXYGEN_SHOULD_SKIP_THIS template <> struct handle_traits { static constexpr auto destructor = &mkldnn_primitive_desc_destroy; }; template <> struct handle_traits { static constexpr auto destructor = &mkldnn_primitive_destroy; }; #endif /// Base class for all computational primitives. class primitive : public handle { friend struct error; friend struct stream; friend class primitive_at; using handle::handle; public: /// A proxy to C primitive kind enum enum class kind { undefined_primitive = mkldnn_undefined_primitive, memory = mkldnn_memory, view = mkldnn_view, reorder = mkldnn_reorder, concat = mkldnn_concat, concat_inplace = mkldnn_concat_inplace, sum = mkldnn_sum, convolution = mkldnn_convolution, deconvolution = mkldnn_deconvolution, eltwise = mkldnn_eltwise, relu = mkldnn_relu, softmax = mkldnn_softmax, pooling = mkldnn_pooling, lrn = mkldnn_lrn, batch_normalization = mkldnn_batch_normalization, inner_product = mkldnn_inner_product, convolution_relu = mkldnn_convolution_relu, rnn = mkldnn_rnn, }; /// A wrapper structure to specify a particular output of a primitive. struct at { /// The underlying C API structure. mkldnn_primitive_at_t data; /// Constructs a wrapper specifying @p aprimitive output with index @p /// at. /// /// @param aprimitive The target primitive. /// @param at The output index. at(const primitive &aprimitive, size_t at = 0) : data(mkldnn_primitive_at(aprimitive.get(), at)) {} /// Returns the specified output. inline operator primitive() const; }; /// Returns the descriptor of the underlying C API primitive inline const_mkldnn_primitive_desc_t get_primitive_desc() const; // TODO: use the C++ API wrapper structure. }; inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) { return static_cast(akind); } /// Intel(R) MKL-DNN exception class. /// /// This class captures the status returned by the failed C API function, error /// message, and, optionally, handle of the primitive that caused the error. struct error : public std::exception { mkldnn_status_t status; std::string message; primitive error_primitive; /// Constructs an error instance. /// /// @param astatus The error status returned by the C API. /// @param amessage The error message. /// @param aerror_primitive (optional) A C handle of the primitive that /// caused the error. error(mkldnn_status_t astatus, std::string amessage, mkldnn_primitive_t aerror_primitive = 0) : status(astatus), message(amessage), error_primitive(aerror_primitive, true) {} /// A convenience function for wrapping calls to the C API. Checks the /// return status and throws an #error in case of failure. /// /// @param status The error status returned by the C API. /// @param message The error message. /// @param error_primitive (optional) A C handle of the primitive that /// caused the error. static void wrap_c_api(mkldnn_status_t status, std::string message, mkldnn_primitive_t *error_primitive = 0) { if (status != mkldnn_success) { if (nullptr != error_primitive) throw error(status, message, *error_primitive); else throw error(status, message, nullptr); } } }; inline primitive::at::operator primitive() const { const_mkldnn_primitive_t output; error::wrap_c_api( mkldnn_primitive_get_output(data.primitive, data.output_index, &output), "could not get an output primitive"); return primitive(const_cast(output), true); } const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const { const_mkldnn_primitive_desc_t pd; error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd), "could not get primitive descriptor by primitive"); return pd; } /// @} /// @addtogroup cpp_api_enums Common data types and enumerations /// @{ enum round_mode { round_nearest = mkldnn_round_nearest, round_down = mkldnn_round_down, }; inline mkldnn_round_mode_t convert_to_c(round_mode mode) { return static_cast(mode); } enum padding_kind { zero = mkldnn_padding_zero }; inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) { return static_cast(kind); } enum prop_kind { forward_training = mkldnn_forward_training, forward_scoring = mkldnn_forward_scoring, forward_inference = mkldnn_forward_inference, forward = mkldnn_forward, backward = mkldnn_backward, backward_data = mkldnn_backward_data, backward_weights = mkldnn_backward_weights, backward_bias = mkldnn_backward_bias }; inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) { return static_cast(kind); } enum algorithm { algorithm_undef = mkldnn_alg_kind_undef, convolution_direct = mkldnn_convolution_direct, convolution_winograd = mkldnn_convolution_winograd, deconvolution_direct = mkldnn_deconvolution_direct, deconvolution_winograd = mkldnn_deconvolution_winograd, eltwise_relu = mkldnn_eltwise_relu, eltwise_tanh = mkldnn_eltwise_tanh, eltwise_elu = mkldnn_eltwise_elu, eltwise_square = mkldnn_eltwise_square, eltwise_abs = mkldnn_eltwise_abs, eltwise_sqrt = mkldnn_eltwise_sqrt, eltwise_linear = mkldnn_eltwise_linear, eltwise_bounded_relu = mkldnn_eltwise_bounded_relu, eltwise_soft_relu = mkldnn_eltwise_soft_relu, eltwise_logistic = mkldnn_eltwise_logistic, lrn_across_channels = mkldnn_lrn_across_channels, lrn_within_channel = mkldnn_lrn_within_channel, pooling_max = mkldnn_pooling_max, pooling_avg = mkldnn_pooling_avg, pooling_avg_include_padding = mkldnn_pooling_avg_include_padding, pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding, vanilla_rnn = mkldnn_vanilla_rnn, vanilla_lstm = mkldnn_vanilla_lstm, vanilla_gru = mkldnn_vanilla_gru, }; inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) { return static_cast(aalgorithm); } enum batch_normalization_flag { use_global_stats = mkldnn_use_global_stats, use_scale_shift = mkldnn_use_scaleshift, omit_stats = mkldnn_omit_stats, fuse_bn_relu = mkldnn_fuse_bn_relu }; inline mkldnn_batch_normalization_flag_t convert_to_c( batch_normalization_flag aflag) { return static_cast(aflag); } enum rnn_direction { unidirectional_left2right = mkldnn_unidirectional_left2right, unidirectional_right2left = mkldnn_unidirectional_right2left, unidirectional = mkldnn_unidirectional, bidirectional_concat = mkldnn_bidirectional_concat, bidirectional_sum = mkldnn_bidirectional_sum, }; inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) { return static_cast(adir); } enum query { undef = mkldnn_query_undef, eengine = mkldnn_query_engine, primitive_kind = mkldnn_query_primitive_kind, num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32, num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32, time_estimate_f64 = mkldnn_query_time_estimate_f64, memory_consumption_s64 = mkldnn_query_memory_consumption_s64, impl_info_str = mkldnn_query_impl_info_str, memory_d = mkldnn_query_memory_d, convolution_d = mkldnn_query_convolution_d, deconvolution_d = mkldnn_query_deconvolution_d, eltwise_d = mkldnn_query_eltwise_d, relu_d = mkldnn_query_relu_d, softmax_d = mkldnn_query_softmax_d, pooling_d = mkldnn_query_pooling_d, lrn_d = mkldnn_query_lrn_d, batch_normalization_d = mkldnn_query_batch_normalization_d, inner_product_d = mkldnn_query_inner_product_d, convolution_relu_d = mkldnn_query_convolution_relu_d, rnn_d = mkldnn_query_rnn_d, input_pd = mkldnn_query_input_pd, output_pd = mkldnn_query_output_pd, src_pd = mkldnn_query_src_pd, diff_src_pd = mkldnn_query_diff_src_pd, weights_pd = mkldnn_query_weights_pd, diff_weights_pd = mkldnn_query_diff_weights_pd, dst_pd = mkldnn_query_dst_pd, diff_dst_pd = mkldnn_query_diff_dst_pd, workspace_pd = mkldnn_query_workspace_pd, }; inline mkldnn_query_t convert_to_c(query aquery) { return static_cast(aquery); } /// @} /// @addtogroup cpp_api_attr Attributes /// @{ #ifndef DOXYGEN_SHOULD_SKIP_THIS template <> struct handle_traits { static constexpr auto destructor = &mkldnn_post_ops_destroy; }; #endif struct post_ops : public handle { post_ops() { mkldnn_post_ops_t result; error::wrap_c_api(mkldnn_post_ops_create(&result), "could not create post operation sequence"); reset(result); } int len() const { return mkldnn_post_ops_len(get()); } primitive::kind kind(int index) const { error::wrap_c_api(index < len() ? mkldnn_success : mkldnn_invalid_arguments, "post_ops index is out of range"); return static_cast(mkldnn_post_ops_get_kind(get(), index)); } void append_sum(float scale = 1.) { error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale), "could not append sum"); } void get_params_sum(int index, float &scale) const { error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale), "could not get sum params"); } void append_eltwise(float scale, algorithm alg, float alpha, float beta) { error::wrap_c_api(mkldnn_post_ops_append_eltwise( get(), scale, convert_to_c(alg), alpha, beta), "could not append eltwise"); } void get_params_eltwise(int index, float &scale, algorithm &alg, float &alpha, float &beta) const { mkldnn_alg_kind_t c_alg; error::wrap_c_api(mkldnn_post_ops_get_params_eltwise( get(), index, &scale, &c_alg, &alpha, &beta), "could not get eltwise params"); alg = static_cast(c_alg); } }; #ifndef DOXYGEN_SHOULD_SKIP_THIS template <> struct handle_traits { static constexpr auto destructor = &mkldnn_primitive_attr_destroy; }; #endif struct primitive_attr : public handle { primitive_attr() { mkldnn_primitive_attr_t result; error::wrap_c_api(mkldnn_primitive_attr_create(&result), "could not create a primitive attr"); reset(result); } round_mode get_int_output_round_mode() const { mkldnn_round_mode_t result; error::wrap_c_api( mkldnn_primitive_attr_get_int_output_round_mode(get(), &result), "could not get int output round mode"); return round_mode(result); } void set_int_output_round_mode(round_mode mode) { error::wrap_c_api(mkldnn_primitive_attr_set_int_output_round_mode( get(), mkldnn::convert_to_c(mode)), "could not set int output round mode"); } void get_output_scales(int &mask, std::vector &scales) const { int count, c_mask; const float *c_scales; error::wrap_c_api(mkldnn_primitive_attr_get_output_scales( get(), &count, &c_mask, &c_scales), "could not get int output scales"); scales.resize(count); mask = c_mask; for (int c = 0; c < count; ++c) scales[c] = c_scales[c]; } void set_output_scales(int mask, const std::vector &scales) { error::wrap_c_api(mkldnn_primitive_attr_set_output_scales( get(), (int)scales.size(), mask, &scales[0]), "could not set int output scales"); } const post_ops get_post_ops() const { post_ops result; const_mkldnn_post_ops_t c_result; error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result), "could not get post operation sequence"); result.reset(const_cast(c_result), true); return result; } void set_post_ops(post_ops ops) { error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()), "could not set post operation sequence"); } }; /// @} /// @addtogroup cpp_api_engine Engine /// @{ #ifndef DOXYGEN_SHOULD_SKIP_THIS template <> struct handle_traits { static constexpr auto destructor = &mkldnn_engine_destroy; }; #endif /// An execution engine. struct engine : public handle { friend class primitive; // gcc bug??? using handle::handle; /// Kinds of engines enum kind { /// An unspecified engine any = mkldnn_any_engine, /// CPU engine cpu = mkldnn_cpu, }; /// Returns the number of engines of a certain kind. /// /// @param akind The kind of engines to count. static size_t get_count(kind akind) { return mkldnn_engine_get_count(convert_to_c(akind)); } /// Constructs an engine. /// /// @param akind The kind of engine to construct. /// @param index The index of the engine. Must be less than the value /// returned by #get_count() for this particular kind of engine. engine(kind akind, size_t index) { mkldnn_engine_t aengine; error::wrap_c_api( mkldnn_engine_create(&aengine, convert_to_c(akind), index), "could not create an engine"); reset(aengine); } explicit engine(const mkldnn_engine_t &aengine) : handle(aengine, true) {} engine(const handle &pd) { mkldnn_engine_t engine_q; error::wrap_c_api( mkldnn_primitive_desc_query( pd.get(), mkldnn::convert_to_c(eengine), 0, &engine_q), "could not get engine from primitive_desc"); reset(engine_q, true); } template static engine query(const primitive_desc &pd) { mkldnn_engine_t engine_q; error::wrap_c_api( mkldnn_primitive_desc_query( pd.get(), mkldnn::convert_to_c(eengine), 0, &engine_q), "could not get engine from primitive_desc"); return engine(engine_q); } private: static mkldnn_engine_kind_t convert_to_c(kind akind) { return static_cast(akind); } }; /// @} /// @addtogroup cpp_api_primitives Primitives /// @{ /// @addtogroup cpp_api_memory Memory /// @{ /// Memory primitive that describes the data. struct memory : public primitive { private: std::shared_ptr _handle; public: typedef std::vector::type> dims; template static void validate_dims(std::vector v) { if (v.size() > TENSOR_MAX_DIMS) throw error(mkldnn_invalid_arguments, "invalid dimensions"); } /// Data type specification. See #mkldnn_data_type_t for a detailed /// description. enum data_type { data_undef = mkldnn_data_type_undef, f32 = mkldnn_f32, s32 = mkldnn_s32, s16 = mkldnn_s16, s8 = mkldnn_s8, u8 = mkldnn_u8, }; /// Memory format specification. See #mkldnn_memory_format_t /// for a detailed description. enum format { format_undef = mkldnn_format_undef, any = mkldnn_any, blocked = mkldnn_blocked, x = mkldnn_x, nc = mkldnn_nc, nchw = mkldnn_nchw, nhwc = mkldnn_nhwc, chwn = mkldnn_chwn, nChw8c = mkldnn_nChw8c, nChw16c = mkldnn_nChw16c, ncdhw = mkldnn_ncdhw, ndhwc = mkldnn_ndhwc, nCdhw16c = mkldnn_nCdhw16c, oi = mkldnn_oi, io = mkldnn_io, oihw = mkldnn_oihw, ihwo = mkldnn_ihwo, hwio = mkldnn_hwio, oidhw = mkldnn_oidhw, OIdhw16i16o = mkldnn_OIdhw16i16o, OIdhw16o16i = mkldnn_OIdhw16o16i, Oidhw16o = mkldnn_Oidhw16o, Odhwi16o = mkldnn_Odhwi16o, oIhw8i = mkldnn_oIhw8i, oIhw16i = mkldnn_oIhw16i, OIhw8i8o = mkldnn_OIhw8i8o, OIhw16i16o = mkldnn_OIhw16i16o, OIhw8o8i = mkldnn_OIhw8o8i, OIhw16o16i = mkldnn_OIhw16o16i, IOhw16o16i = mkldnn_IOhw16o16i, OIhw8i16o2i = mkldnn_OIhw8i16o2i, OIhw8o16i2o = mkldnn_OIhw8o16i2o, OIhw4i16o4i = mkldnn_OIhw4i16o4i, Oihw8o = mkldnn_Oihw8o, Oihw16o = mkldnn_Oihw16o, Ohwi8o = mkldnn_Ohwi8o, Ohwi16o = mkldnn_Ohwi16o, OhIw16o4i = mkldnn_OhIw16o4i, goihw = mkldnn_goihw, hwigo = mkldnn_hwigo, gOIhw8i8o = mkldnn_gOIhw8i8o, gOIhw16i16o = mkldnn_gOIhw16i16o, gOIhw8i16o2i = mkldnn_gOIhw8i16o2i, gOIhw8o16i2o = mkldnn_gOIhw8o16i2o, gOIhw4i16o4i = mkldnn_gOIhw4i16o4i, gOihw8o = mkldnn_gOihw8o, gOihw16o = mkldnn_gOihw16o, gOhwi8o = mkldnn_gOhwi8o, gOhwi16o = mkldnn_gOhwi16o, Goihw8g = mkldnn_Goihw8g, Goihw16g = mkldnn_Goihw16g, gOIhw8o8i = mkldnn_gOIhw8o8i, gOIhw16o16i = mkldnn_gOIhw16o16i, gIOhw16o16i = mkldnn_gIOhw16o16i, gOhIw16o4i = mkldnn_gOhIw16o4i, goidhw = mkldnn_goidhw, gOIdhw16i16o = mkldnn_gOIdhw16i16o, gOIdhw16o16i = mkldnn_gOIdhw16o16i, gOidhw16o = mkldnn_gOidhw16o, gOdhwi16o = mkldnn_gOdhwi16o, ntc = mkldnn_ntc, tnc = mkldnn_tnc, ldsnc = mkldnn_ldsnc, ldigo = mkldnn_ldigo, ldigo_p = mkldnn_ldigo_p, ldgoi = mkldnn_ldgoi, ldgoi_p = mkldnn_ldgoi_p, ldgo = mkldnn_ldgo, wino_fmt = mkldnn_wino_fmt, format_last = mkldnn_format_last, }; /// A memory descriptor. struct desc { friend struct memory; /// The underlying C API data structure. mkldnn_memory_desc_t data; /// Constructs a memory descriptor. /// /// @param adims Data dimensions /// @param adata_type Data precision/type. /// @param aformat Data layout format. desc(dims adims, data_type adata_type, format aformat) { validate_dims(adims); error::wrap_c_api( mkldnn_memory_desc_init(&data, (int)adims.size(), adims.size() == 0 ? nullptr : &adims[0], convert_to_c(adata_type), convert_to_c(aformat)), "could not initialize a memory descriptor"); } /// Constructs a memory descriptor from a C API data structure. /// /// @param adata A C API #mkldnn_memory_desc_t structure. desc(const mkldnn_memory_desc_t &adata) : data(adata) {} }; /// A memory primitive descriptor. struct primitive_desc : public handle { friend struct memory; // TODO: make private primitive_desc() {} /// Constructs a memory primitive descriptor. primitive_desc(const desc &adesc, const engine &aengine) { mkldnn_primitive_desc_t result; error::wrap_c_api(mkldnn_memory_primitive_desc_create( &result, &adesc.data, aengine.get()), "could not initialize a memory primitive descriptor"); reset(result); } /// Returns the memory primitive descriptor. memory::desc desc() { auto memory_d = mkldnn_primitive_desc_query_memory_d(get()); return memory::desc(*memory_d); } /// Returns the number of bytes required to allocate the memory described /// including the padding area. size_t get_size() const { return mkldnn_memory_primitive_desc_get_size(get()); } bool operator==(const primitive_desc &other) const { return mkldnn_memory_primitive_desc_equal(get(), other.get()); } bool operator!=(const primitive_desc &other) const { return !operator==(other); } engine get_engine() { return engine::query(*this); } }; /// Constructs a memory primitive from a generic primitive. /// /// @param aprimitive The primitive to treat as memory. memory(const primitive &aprimitive) : primitive(aprimitive) {} /// Constructs a memory primitive. /// /// @param adesc Memory primitive descriptor. memory(const primitive_desc &adesc) { mkldnn_primitive_t result; error::wrap_c_api( mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr), "could not create a memory primitive"); reset(result); auto _malloc = [](size_t size, int alignment) { void *ptr; #ifdef _WIN32 ptr = _aligned_malloc(size, alignment); int rc = ((ptr) ? 0 : errno); #else int rc = ::posix_memalign(&ptr, alignment, size); #endif /* _WIN32 */ return (rc == 0) ? (char *)ptr : nullptr; }; auto _free = [](char *p) { #ifdef _WIN32 _aligned_free((void *)p); #else ::free((void *)p); #endif /* _WIN32 */ }; _handle.reset(_malloc(adesc.get_size(), 4096), _free); set_data_handle(_handle.get()); } memory(const primitive_desc &adesc, void *ahandle) { mkldnn_primitive_t result; error::wrap_c_api( mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr), "could not create a memory primitive"); reset(result); set_data_handle(ahandle); } /// Returns the descriptor of the memory primitive. primitive_desc get_primitive_desc() const { primitive_desc adesc; const_mkldnn_primitive_desc_t cdesc; error::wrap_c_api( mkldnn_primitive_get_primitive_desc(get(), &cdesc), "could not get primitive descriptor from a memory primitive"); /* FIXME: no const_cast should be here */ adesc.reset(const_cast(cdesc), true); return adesc; } /// Returns a handle of the data contained in the memory primitive. On /// the CPU engine, this is a pointer to the allocated memory. inline void *get_data_handle() const { void *handle; error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle), "could not get native handle"); return handle; } inline void set_data_handle(void *handle) const { error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle), "could not set native handle"); } // Must go away or be private: static mkldnn_data_type_t convert_to_c(data_type adata_type) { return static_cast(adata_type); } static mkldnn_memory_format_t convert_to_c(format aformat) { return static_cast(aformat); } }; inline memory::desc zero_md() { mkldnn_memory_desc_t zero; zero.primitive_kind = mkldnn_memory; return memory::desc(zero); } inline memory null_memory(engine eng) { mkldnn::memory::desc zero = zero_md(); return memory({zero, eng}, nullptr); } inline bool is_null_memory(const const_mkldnn_primitive_t &aprimitive) { const_mkldnn_primitive_desc_t aprimitive_pd; mkldnn_primitive_get_primitive_desc(aprimitive, &aprimitive_pd); const mkldnn_memory_desc_t *aprimitive_md = mkldnn_primitive_desc_query_memory_d(aprimitive_pd); return ((aprimitive_md != nullptr) && (aprimitive_md->ndims == 0)); } inline bool operator==(mkldnn_data_type_t a, memory::data_type b) { return a == memory::convert_to_c(b); } inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) { return !(a == b); } inline bool operator==(memory::data_type a, mkldnn_data_type_t b) { return b == a; } inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) { return !(a == b); } inline bool operator==(mkldnn_memory_format_t a, memory::format b) { return a == memory::convert_to_c(b); } inline bool operator!=(mkldnn_memory_format_t a, memory::format b) { return !(a == b); } inline bool operator==(memory::format a, mkldnn_memory_format_t b) { return b == a; } inline bool operator!=(memory::format a, mkldnn_memory_format_t b) { return !(a == b); } /// @} /// @addtogroup cpp_api_reorder Reorder /// @{ struct reorder : public primitive { struct primitive_desc : public handle { primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output) { mkldnn_primitive_desc_t result; error::wrap_c_api(mkldnn_reorder_primitive_desc_create( &result, input.get(), output.get()), "could not create a reorder primitive descriptor"); reset(result); } primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output, const primitive_attr &aattr) { mkldnn_primitive_desc_t result; error::wrap_c_api(mkldnn_reorder_primitive_desc_create_v2( &result, input.get(), output.get(), aattr.get()), "could not create a reorder primitive descriptor"); reset(result); } engine get_engine() { return engine::query(*this); } }; reorder(const primitive_desc &aprimitive_desc, const primitive::at &input, const memory &output) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {input.data}; const_mkldnn_primitive_t outputs[] = {output.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a reorder primitive"); reset(result); } reorder(const primitive::at &input, const memory &output) { auto input_mpd = memory(input).get_primitive_desc(); auto output_mpd = output.get_primitive_desc(); auto reorder_d = primitive_desc(input_mpd, output_mpd); mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {input.data}; const_mkldnn_primitive_t outputs[] = {output.get()}; error::wrap_c_api( mkldnn_primitive_create(&result, reorder_d.get(), inputs, outputs), "could not create a reorder primitive"); reset(result); } }; /// @} /// @addtogroup cpp_api_view View /// @{ struct view : public primitive { struct primitive_desc : public handle { primitive_desc(const memory::primitive_desc &input, memory::dims dims, memory::dims offsets) { mkldnn_primitive_desc_t result; error::wrap_c_api(mkldnn_view_primitive_desc_create( &result, input.get(), &dims[0], &offsets[0]), "could not create a view primitive descriptor"); reset(result); } memory::primitive_desc dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; view(const primitive_desc &view_pd, primitive::at input) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {input.data}; error::wrap_c_api( mkldnn_primitive_create(&result, view_pd.get(), inputs, nullptr), "could not create a view primitive"); reset(result); } view(memory input, memory::dims dims, memory::dims offsets) { mkldnn_primitive_t result; primitive_desc view_pd(input.get_primitive_desc(), dims, offsets); mkldnn_primitive_at_t inputs[] = {primitive::at(input).data}; error::wrap_c_api( mkldnn_primitive_create(&result, view_pd.get(), inputs, nullptr), "could not create a view primitive"); reset(result); } }; /// @} /// @addtogroup cpp_api_concat Concat /// @{ struct concat : public primitive { struct primitive_desc : public handle { std::vector cpp_to_c( std::vector inputs) { std::vector c_api_inputs; c_api_inputs.reserve(inputs.size()); auto convert_to_c = [](memory::primitive_desc d) { return d.get(); }; std::transform(inputs.begin(), inputs.end(), std::back_inserter(c_api_inputs), convert_to_c); return c_api_inputs; } primitive_desc(const memory::desc &output, int concat_dimension, std::vector inputs) { mkldnn_primitive_desc_t result; auto c_api_inputs = cpp_to_c(inputs); error::wrap_c_api( mkldnn_concat_primitive_desc_create(&result, &output.data, (int)c_api_inputs.size(), concat_dimension, &c_api_inputs[0]), "could not create a concat primitive descriptor"); reset(result); } primitive_desc(int concat_dimension, std::vector inputs) { mkldnn_primitive_desc_t result; auto c_api_inputs = cpp_to_c(inputs); error::wrap_c_api( mkldnn_concat_primitive_desc_create(&result, nullptr, (int)c_api_inputs.size(), concat_dimension, &c_api_inputs[0]), "could not create a concat primitive descriptor"); reset(result); } memory::primitive_desc dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; concat(const primitive_desc &concat_pd, std::vector &inputs, const memory &output) { mkldnn_primitive_t result; std::vector p_inputs; for (size_t i = 0; i < inputs.size(); i++) p_inputs.push_back(inputs[i].data); const_mkldnn_primitive_t outputs[] = {output.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, concat_pd.get(), &p_inputs[0], outputs), "could not create a concat primitive"); reset(result); } }; /// @} /// @addtogroup cpp_api_sum Sum /// @{ struct sum : public primitive { struct primitive_desc : public handle { std::vector cpp_to_c( std::vector inputs) { std::vector c_api_inputs; c_api_inputs.reserve(inputs.size()); auto convert_to_c = [](memory::primitive_desc d) { return d.get(); }; std::transform(inputs.begin(), inputs.end(), std::back_inserter(c_api_inputs), convert_to_c); return c_api_inputs; } primitive_desc(const memory::desc &output, const std::vector &scales, std::vector inputs) { mkldnn_primitive_desc_t result; auto c_api_inputs = cpp_to_c(inputs); error::wrap_c_api( mkldnn_sum_primitive_desc_create(&result, &output.data, (int)c_api_inputs.size(), &scales[0], &c_api_inputs[0]), "could not create a sum primitive descriptor"); reset(result); } primitive_desc(const std::vector &scales, std::vector inputs) { mkldnn_primitive_desc_t result; auto c_api_inputs = cpp_to_c(inputs); error::wrap_c_api( mkldnn_sum_primitive_desc_create(&result, nullptr, (int)c_api_inputs.size(), &scales[0], &c_api_inputs[0]), "could not create a sum primitive descriptor"); reset(result); } /** @deprecated: api backwards compatibility for double scales type */ MKLDNN_DEPRECATED primitive_desc(const memory::desc &output, std::vector scale, std::vector inputs) { mkldnn_primitive_desc_t result; auto c_api_inputs = cpp_to_c(inputs); auto scale_f = scale_to_float(scale); error::wrap_c_api( mkldnn_sum_primitive_desc_create(&result, &output.data, (int)c_api_inputs.size(), &scale_f[0], &c_api_inputs[0]), "could not create a sum primitive descriptor"); reset(result); } /** @deprecated: api backwards compatibility for double scales type */ MKLDNN_DEPRECATED primitive_desc(std::vector scale, std::vector inputs) { mkldnn_primitive_desc_t result; auto c_api_inputs = cpp_to_c(inputs); auto scale_f = scale_to_float(scale); error::wrap_c_api( mkldnn_sum_primitive_desc_create(&result, nullptr, (int)c_api_inputs.size(), &scale_f[0], &c_api_inputs[0]), "could not create a sum primitive descriptor"); reset(result); } memory::primitive_desc dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; sum(const primitive_desc &sum_pd, std::vector &inputs, const memory &output) { mkldnn_primitive_t result; std::vector p_inputs; for (size_t i = 0; i < inputs.size(); i++) p_inputs.push_back(inputs[i].data); const_mkldnn_primitive_t outputs[] = {output.get()}; error::wrap_c_api( mkldnn_primitive_create(&result, sum_pd.get(), &p_inputs[0], outputs), "could not create a sum primitive"); reset(result); } private: static std::vector scale_to_float(const std::vector &vd) { std::vector vf(vd.size()); std::transform( vd.begin(), vd.end(), vf.begin(), [=](double x) { return (float)x; }); return vf; } }; /// @} /// @addtogroup cpp_api_convolution Convolution /// @{ struct convolution_forward : public primitive { struct desc { mkldnn_convolution_desc_t data; desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_convolution_forward_desc_init( &data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, &weights_desc.data, &bias_desc.data, &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a convolution forward descriptor"); } desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_convolution_forward_desc_init( &data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, &weights_desc.data, nullptr, &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a convolution forward descriptor"); } desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(dilates); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api( mkldnn_dilated_convolution_forward_desc_init( &data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, &weights_desc.data, &bias_desc.data, &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a dilated convolution forward descriptor"); } desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(dilates); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api( mkldnn_dilated_convolution_forward_desc_init( &data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, &weights_desc.data, nullptr, &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a dilated convolution forward descriptor"); } }; struct primitive_desc : public handle { primitive_desc(const desc &adesc, const engine &aengine) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create( &result, &adesc.data, aengine.get(), nullptr), "could not create a convolution forward primitive descriptor"); reset(result); } primitive_desc(const desc &adesc, const primitive_attr &aattr, const engine &aengine) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create_v2( &result, &adesc.data, aattr.get(), aengine.get(), nullptr), "could not create a convolution forward primitive descriptor"); reset(result); } memory::primitive_desc src_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(src_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a src primititve descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc weights_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a weights primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc bias_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 1); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a bias primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, weights.data, bias.data}; const_mkldnn_primitive_t outputs[] = {dst.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a convolution forward bias primitive"); reset(result); } convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, weights.data}; const_mkldnn_primitive_t outputs[] = {dst.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a convolution forward primitive"); reset(result); } }; struct convolution_backward_data : public primitive { struct desc { mkldnn_convolution_desc_t data; desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api( mkldnn_convolution_backward_data_desc_init( &data, convert_to_c(aalgorithm), &diff_src_desc.data, &weights_desc.data, &diff_dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a convolution backward data descriptor"); } desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(dilates); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api( mkldnn_dilated_convolution_backward_data_desc_init( &data, convert_to_c(aalgorithm), &diff_src_desc.data, &weights_desc.data, &diff_dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a convolution backward data descriptor"); } }; struct primitive_desc : public handle { primitive_desc( const desc &adesc, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_primitive_desc) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create(&result, &adesc.data, aengine.get(), hint_fwd_primitive_desc.get()), "could not create a convolution backward data primitive descriptor"); reset(result); } memory::primitive_desc diff_src_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_src_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff_src primititve descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc weights_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a weights primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff_dst primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; convolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {diff_dst.data, weights.data}; const_mkldnn_primitive_t outputs[] = {diff_src.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a convolution backward data primitive"); reset(result); } }; struct convolution_backward_weights : public primitive { struct desc { mkldnn_convolution_desc_t data; desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api( mkldnn_convolution_backward_weights_desc_init( &data, convert_to_c(aalgorithm), &src_desc.data, &diff_weights_desc.data, &diff_bias_desc.data, &diff_dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a convolution backward weights descriptor"); } desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api( mkldnn_convolution_backward_weights_desc_init( &data, convert_to_c(aalgorithm), &src_desc.data, &diff_weights_desc.data, nullptr, &diff_dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a convolution backward weights descriptor"); } desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(dilates); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api( mkldnn_dilated_convolution_backward_weights_desc_init( &data, convert_to_c(aalgorithm), &src_desc.data, &diff_weights_desc.data, &diff_bias_desc.data, &diff_dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a convolution backward weights descriptor"); } desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(dilates); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api( mkldnn_dilated_convolution_backward_weights_desc_init( &data, convert_to_c(aalgorithm), &src_desc.data, &diff_weights_desc.data, nullptr, &diff_dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a convolution backward weights descriptor"); } }; struct primitive_desc : public handle { primitive_desc( const desc &adesc, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_primitive_desc) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create(&result, &adesc.data, aengine.get(), hint_fwd_primitive_desc.get()), "could not create a convolution backward weights primitive " "descriptor"); reset(result); } memory::primitive_desc src_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(src_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a src primititve descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_weights_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_weights_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff_weights primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_bias_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_weights_pd), 1); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff_bias primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff_dst primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data}; const_mkldnn_primitive_t outputs[] = {diff_weights.get(), diff_bias.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a convolution backward weights primitive"); reset(result); } convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data}; const_mkldnn_primitive_t outputs[] = {diff_weights.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a convolution backward weights primitive"); reset(result); } }; struct convolution_relu_forward : public primitive { struct desc { mkldnn_convolution_relu_desc_t data; desc(const convolution_forward::desc conv_desc, const float negative_slope) { error::wrap_c_api( mkldnn_convolution_relu_desc_init( &data, &conv_desc.data, negative_slope), "could not create a convolution_relu_forward descriptor"); } }; struct primitive_desc : public handle { primitive_desc(const desc &adesc, const engine &aengine) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create( &result, &adesc.data, aengine.get(), nullptr), "could not create a convolution relu forward descriptor"); reset(result); } engine get_engine() { return engine::query(*this); } }; convolution_relu_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, weights.data, bias.data}; const_mkldnn_primitive_t outputs[] = {dst.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a convolution relu forward primitive"); reset(result); } convolution_relu_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, weights.data}; const_mkldnn_primitive_t outputs[] = {dst.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a convolution relu forward primitive"); reset(result); } }; /// @} // /// @addtogroup cpp_api_deconvolution Deconvolution /// @{ struct deconvolution_forward : public primitive { struct desc { mkldnn_deconvolution_desc_t data; desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_deconvolution_forward_desc_init( &data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, &weights_desc.data, &bias_desc.data, &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a deconvolution forward descriptor"); } desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_deconvolution_forward_desc_init( &data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, &weights_desc.data, nullptr, &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a deconvolution forward descriptor"); } }; struct primitive_desc : public handle { primitive_desc(const desc &adesc, const engine &aengine) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create( &result, &adesc.data, aengine.get(), nullptr), "could not create a deconvolution forward primitive descriptor"); reset(result); } primitive_desc(const desc &adesc, const primitive_attr &aattr, const engine &aengine) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create_v2( &result, &adesc.data, aattr.get(), aengine.get(), nullptr), "could not create a deconvolution forward primitive descriptor"); reset(result); } memory::primitive_desc src_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(src_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a src primititve descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc weights_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a weights primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc bias_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 1); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a bias primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; deconvolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, weights.data, bias.data}; const_mkldnn_primitive_t outputs[] = {dst.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a deconvolution forward bias primitive"); reset(result); } deconvolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, weights.data}; const_mkldnn_primitive_t outputs[] = {dst.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a deconvolution forward primitive"); reset(result); } }; struct deconvolution_backward_data : public primitive { struct desc { mkldnn_deconvolution_desc_t data; desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api( mkldnn_deconvolution_backward_data_desc_init( &data, convert_to_c(aalgorithm), &diff_src_desc.data, &weights_desc.data, &diff_dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a deconvolution backward data descriptor"); } }; struct primitive_desc : public handle { primitive_desc( const desc &adesc, const engine &aengine, const deconvolution_forward::primitive_desc &hint_fwd_primitive_desc) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create(&result, &adesc.data, aengine.get(), hint_fwd_primitive_desc.get()), "could not create a deconvolution backward data primitive " "descriptor"); reset(result); } memory::primitive_desc diff_src_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_src_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff_src primititve descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc weights_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a weights primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff_dst primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; deconvolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {diff_dst.data, weights.data}; const_mkldnn_primitive_t outputs[] = {diff_src.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a deconvolution backward data primitive"); reset(result); } }; struct deconvolution_backward_weights : public primitive { struct desc { mkldnn_deconvolution_desc_t data; desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api( mkldnn_deconvolution_backward_weights_desc_init( &data, convert_to_c(aalgorithm), &src_desc.data, &diff_weights_desc.data, &diff_bias_desc.data, &diff_dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a deconvolution backward weights descriptor"); } desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api( mkldnn_deconvolution_backward_weights_desc_init( &data, convert_to_c(aalgorithm), &src_desc.data, &diff_weights_desc.data, nullptr, &diff_dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a deconvolution backward weights descriptor"); } }; struct primitive_desc : public handle { primitive_desc( const desc &adesc, const engine &aengine, const deconvolution_forward::primitive_desc &hint_fwd_primitive_desc) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create(&result, &adesc.data, aengine.get(), hint_fwd_primitive_desc.get()), "could not create a deconvolution backward weights primitive " "descriptor"); reset(result); } memory::primitive_desc src_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(src_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a src primititve descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_weights_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_weights_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff_weights primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_bias_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_weights_pd), 1); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff_bias primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff_dst primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; deconvolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data}; const_mkldnn_primitive_t outputs[] = {diff_weights.get(), diff_bias.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a deconvolution backward weights primitive"); reset(result); } deconvolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data}; const_mkldnn_primitive_t outputs[] = {diff_weights.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a deconvolution backward weights primitive"); reset(result); } }; /// @} /// @addtogroup cpp_api_lrn LRN /// @{ struct lrn_forward : public primitive { struct desc { mkldnn_lrn_desc_t data; desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta, float k) { error::wrap_c_api( mkldnn_lrn_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, local_size, alpha, beta, k), "could not create a lrn forward descriptor"); } desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta) { error::wrap_c_api( mkldnn_lrn_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, local_size, alpha, beta, float(1.0)), "could not create a lrn forward descriptor"); } }; struct primitive_desc : public handle { primitive_desc(const desc &adesc, const engine &aengine) { mkldnn_primitive_desc_t result; error::wrap_c_api(mkldnn_primitive_desc_create( &result, &adesc.data, aengine.get(), nullptr), "could not create a lrn forward primitive descriptor"); reset(result); } memory::primitive_desc src_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(src_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a src primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc workspace_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t ldesc; const_mkldnn_primitive_desc_t const_ldesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(workspace_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc), "could not clone a workspace primitive descriptor"); adesc.reset(ldesc); return adesc; } memory::primitive_desc dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &workspace, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data}; const_mkldnn_primitive_t outputs[] = {dst.get(), workspace.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a lrn forward primitive"); reset(result); } lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data}; const_mkldnn_primitive_t outputs[] = {dst.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a lrn forward primitive"); reset(result); } }; struct lrn_backward : public primitive { struct desc { mkldnn_lrn_desc_t data; desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta, float k) { error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data, convert_to_c(aalgorithm), &diff_data_desc.data, &data_desc.data, local_size, alpha, beta, k), "could not create a lrn backward descriptor"); } desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta) { error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data, convert_to_c(aalgorithm), &diff_data_desc.data, &data_desc.data, local_size, alpha, beta, float(1.0)), "could not create a lrn backward descriptor"); } }; struct primitive_desc : public handle { primitive_desc(const desc &adesc, const engine &aengine, const lrn_forward::primitive_desc &hint_fwd_primitive_desc) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create(&result, &adesc.data, aengine.get(), hint_fwd_primitive_desc.get()), "could not create a backward lrn primitive descriptor"); reset(result); } memory::primitive_desc diff_src_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_src_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff_src primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc workspace_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t ldesc; const_mkldnn_primitive_desc_t const_ldesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(workspace_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc), "could not clone a workspace primitive descriptor"); adesc.reset(ldesc); return adesc; } memory::primitive_desc diff_dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff_dst primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data, workspace.data}; const_mkldnn_primitive_t outputs[] = {diff_src.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a lrn backward primitive"); reset(result); } lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data}; const_mkldnn_primitive_t outputs[] = {diff_src.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a lrn backward primitive"); reset(result); } }; /// @} /// @addtogroup cpp_api_pooling Pooling /// @{ struct pooling_forward : public primitive { struct desc { mkldnn_pooling_desc_t data; desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims kernel, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(kernel); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api( mkldnn_pooling_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, &dst_desc.data, &strides[0], &kernel[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not init a forward pooling descriptor"); } }; struct primitive_desc : public handle { primitive_desc(const desc &adesc, const engine &aengine) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create( &result, &adesc.data, aengine.get(), nullptr), "could not create a forward pooling primitive descriptor"); reset(result); } memory::primitive_desc workspace_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(workspace_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a workspace primititve descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data}; const_mkldnn_primitive_t outputs[] = {dst.get(), nullptr}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a pooling forward primitive"); reset(result); } pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &workspace) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data}; const_mkldnn_primitive_t outputs[] = {dst.get(), workspace.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a pooling forward primitive"); reset(result); } }; struct pooling_backward : public primitive { struct desc { mkldnn_pooling_desc_t data; desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(kernel); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_pooling_backward_desc_init( &data, convert_to_c(aalgorithm), &diff_src_desc.data, &diff_dst_desc.data, &strides[0], &kernel[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not init a backward pooling descriptor"); } }; struct primitive_desc : public handle { primitive_desc( const desc &adesc, const engine &aengine, const pooling_forward::primitive_desc &hint_fwd_primitive_desc) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create(&result, &adesc.data, aengine.get(), hint_fwd_primitive_desc.get()), "could not create a backward pooling primitive descriptor"); reset(result); } memory::primitive_desc diff_src_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_src_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff src primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const memory &diff_src) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {diff_dst.data}; const_mkldnn_primitive_t outputs[] = {diff_src.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a pooling backward primitive"); reset(result); } pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {diff_dst.data, workspace.data}; const_mkldnn_primitive_t outputs[] = {diff_src.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a pooling backward primitive"); reset(result); } }; /// @} /// @addtogroup cpp_api_eltwise Eltwise /// @{ struct eltwise_forward : public primitive { struct desc { mkldnn_eltwise_desc_t data; template desc(prop_kind aprop_kind, algorithm alg_kind, const memory::desc &src_desc, T alpha = 0, T beta = 0) { error::wrap_c_api( mkldnn_eltwise_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), mkldnn::convert_to_c(alg_kind), &src_desc.data, static_cast(alpha), static_cast(beta)), "could not create a eltwise forward descriptor"); } /** @deprecated: api backward compatibility for relu */ template MKLDNN_DEPRECATED desc(prop_kind aprop_kind, const memory::desc &src_desc, T negative_slope) : desc(aprop_kind, eltwise_relu, src_desc, negative_slope) {} }; struct primitive_desc : public handle { primitive_desc(const desc &adesc, const engine &aengine) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create( &result, &adesc.data, aengine.get(), nullptr), "could not create a eltwise forward primitive descriptor"); reset(result); } memory::primitive_desc dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; eltwise_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data}; const_mkldnn_primitive_t outputs[] = {dst.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a eltwise forward primitive"); reset(result); } }; typedef eltwise_forward relu_forward; struct eltwise_backward : public primitive { struct desc { mkldnn_eltwise_desc_t data; template desc(algorithm alg_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T alpha = 0, T beta = 0) { error::wrap_c_api( mkldnn_eltwise_backward_desc_init(&data, mkldnn::convert_to_c(alg_kind), &diff_data_desc.data, &data_desc.data, static_cast(alpha), static_cast(beta)), "could not create a eltwise backward descriptor"); } /** @deprecated: api backward compatibility for relu */ template MKLDNN_DEPRECATED desc(const memory::desc &diff_data_desc, const memory::desc &data_desc, T negative_slope) : desc(eltwise_relu, diff_data_desc, data_desc, negative_slope) {} }; struct primitive_desc : public handle { primitive_desc( const desc &adesc, const engine &aengine, const eltwise_forward::primitive_desc &hint_fwd_primitive_desc) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create(&result, &adesc.data, aengine.get(), hint_fwd_primitive_desc.get()), "could not create a eltwise backward primitive descriptor"); reset(result); } memory::primitive_desc diff_src_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_src_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff src primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; eltwise_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data}; const_mkldnn_primitive_t outputs[] = {diff_src.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a eltwise backward primitive"); reset(result); } }; typedef eltwise_backward relu_backward; /// @} /// @addtogroup cpp_api_softmax Softmax /// @{ struct softmax_forward : public primitive { struct desc { mkldnn_softmax_desc_t data; desc(prop_kind aprop_kind, const memory::desc &data_desc, int softmax_axis) { error::wrap_c_api( mkldnn_softmax_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), &data_desc.data, softmax_axis), "could not create a softmax forward descriptor"); } }; struct primitive_desc : public handle { primitive_desc(const desc &adesc, const engine &aengine) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create( &result, &adesc.data, aengine.get(), nullptr), "could not create a softmax forward primitive descriptor"); reset(result); } engine get_engine() { return engine::query(*this); } }; softmax_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data}; const_mkldnn_primitive_t outputs[] = {dst.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a softmax forward primitive"); reset(result); } }; /// @} /// @addtogroup cpp_api_batch_norm Batch normalization /// @{ struct batch_normalization_forward : public primitive { struct desc { mkldnn_batch_normalization_desc_t data; template desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon, unsigned flags) { error::wrap_c_api( mkldnn_batch_normalization_forward_desc_init( &data, mkldnn::convert_to_c(aprop_kind), &src_desc.data, static_cast(epsilon), flags), "could not create a batch normalization forward descriptor"); } }; struct primitive_desc : public handle { primitive_desc(const desc &adesc, const engine &aengine) { mkldnn_primitive_desc_t result; error::wrap_c_api(mkldnn_primitive_desc_create( &result, &adesc.data, aengine.get(), nullptr), "could not create a batch normalization forward " "primitive descriptor"); reset(result); } primitive_desc(const desc &adesc, const primitive_attr &aattr, const engine &aengine) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create_v2( &result, &adesc.data, aattr.get(), aengine.get(), nullptr), "could not create a batch normalization forward " "primitive descriptor"); reset(result); } memory::primitive_desc weights_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t bndesc; const_mkldnn_primitive_desc_t const_bndesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc), "could not clone a weights primitive descriptor"); adesc.reset(bndesc); return adesc; } memory::primitive_desc mean_primitive_desc() const { memory::primitive_desc aprimitive_desc; mkldnn_primitive_desc_t bndesc; mkldnn_batch_normalization_desc_t *p; error::wrap_c_api( mkldnn_primitive_desc_query( get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p), "could not get a batch-normalization descriptor"); const_mkldnn_primitive_desc_t const_bndesc = (p->flags & use_global_stats) ? mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(src_pd), 1) : mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 1); error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc), "could not clone a mean primitive descriptor"); aprimitive_desc.reset(bndesc); return aprimitive_desc; } memory::primitive_desc variance_primitive_desc() const { memory::primitive_desc aprimitive_desc; mkldnn_primitive_desc_t bndesc; mkldnn_batch_normalization_desc_t *p; error::wrap_c_api( mkldnn_primitive_desc_query( get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p), "could not get a batch-normalization descriptor"); const_mkldnn_primitive_desc_t const_bndesc = (p->flags & use_global_stats) ? mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(src_pd), 2) : mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 2); error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc), "could not clone a variance primitive descriptor"); aprimitive_desc.reset(bndesc); return aprimitive_desc; } memory::primitive_desc workspace_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(workspace_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a workspace primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &weights, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data, weights.data}; const_mkldnn_primitive_t outputs[] = {dst.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a batch normalization forward primitive"); reset(result); } batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, mean.data, variance.data}; const_mkldnn_primitive_t outputs[] = {dst.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a batch normalization forward primitive"); reset(result); } /// @warning batch_normalization_forward has 2 constructors with very /// similar signatures: /// - (pd, src, weights, dst, mean, variance) // 2 in, 3 out /// - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out /// The only way to distinguish between those is to explicitly /// cast all input parameters to their type, i.e. to /// const primitive:at &. batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, weights.data}; const_mkldnn_primitive_t outputs[] = { dst.get(), mean.get(), variance.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a batch normalization forward primitive"); reset(result); } batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance, const memory &workspace) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, weights.data}; const_mkldnn_primitive_t outputs[] = { dst.get(), mean.get(), variance.get(), workspace.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a batch normalization forward primitive"); reset(result); } batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data}; const_mkldnn_primitive_t outputs[] = { dst.get(), mean.get(), variance.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a batch normalization forward primitive"); reset(result); } /// @warning batch_normalization_forward has 2 constructors with very /// similar signatures: /// - (pd, src, weights, dst, mean, variance) // 2 in, 3 out /// - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out /// The only way to distinguish between those is to explicitly /// cast all input parameters to their type, i.e. to /// const primitive:at &. /// @note to make users' experience a little bit better this constructor /// checks if whether parameters match corresponding primitive /// descriptor, and if they are not -- call the other (proper) /// constructor. Yeah, this is still very ugly... batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance, const memory &workspace) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[2] = {src.data}; const_mkldnn_primitive_t outputs[4] = { dst.get(), mean.get(), variance.get(), workspace.get()}; if (1) { // check whether this is the `wrong` constructor const int n_inputs_expected = mkldnn_primitive_desc_query_s32( aprimitive_desc.get(), mkldnn_query_num_of_inputs_s32, 0); const int n_outputs_expected = mkldnn_primitive_desc_query_s32( aprimitive_desc.get(), mkldnn_query_num_of_outputs_s32, 0); if (n_inputs_expected == 2 && n_outputs_expected == 3) { // shift parameters, get rid of workspace, and add weights... auto _weights = dst; inputs[1] = {_weights.get(), 0}; auto _dst = mean, _mean = variance, _variance = workspace; outputs[0] = _dst.get(); outputs[1] = _mean.get(); outputs[2] = _variance.get(); outputs[3] = nullptr; } } error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a batch normalization forward primitive"); reset(result); } batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, weights.data}; const_mkldnn_primitive_t outputs[] = {dst.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a batch normalization forward primitive"); reset(result); } batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data}; const_mkldnn_primitive_t outputs[] = {dst.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a batch normalization forward primitive"); reset(result); } }; struct batch_normalization_backward : public primitive { struct desc { mkldnn_batch_normalization_desc_t data; template desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T epsilon, unsigned flags) { error::wrap_c_api( mkldnn_batch_normalization_backward_desc_init( &data, mkldnn::convert_to_c(aprop_kind), &diff_data_desc.data, &data_desc.data, static_cast(epsilon), flags), "could not create a batch normalization backward descriptor"); } }; struct primitive_desc : public handle { primitive_desc(const desc &adesc, const engine &aengine, const batch_normalization_forward::primitive_desc &hint_fwd_primitive_desc) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create(&result, &adesc.data, aengine.get(), hint_fwd_primitive_desc.get()), "could not create a batch normalization backward primitive " "descriptor"); reset(result); } memory::primitive_desc weights_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t bndesc; const_mkldnn_primitive_desc_t const_bndesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc), "could not clone a weights primitive descriptor"); adesc.reset(bndesc); return adesc; } memory::primitive_desc diff_weights_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t bndesc; const_mkldnn_primitive_desc_t const_bndesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_weights_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc), "could not clone a diff_weights primitive descriptor"); adesc.reset(bndesc); return adesc; } memory::primitive_desc mean_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t bndesc; const_mkldnn_primitive_desc_t const_bndesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(src_pd), 1); error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc), "could not clone a mean primitive descriptor"); adesc.reset(bndesc); return adesc; } memory::primitive_desc variance_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t bndesc; const_mkldnn_primitive_desc_t const_bndesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(src_pd), 2); error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc), "could not clone a variance primitive descriptor"); adesc.reset(bndesc); return adesc; } memory::primitive_desc workspace_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(workspace_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a workspace primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; // Prop_kind == backward batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src, const memory &diff_weights) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data, diff_dst.data, weights.data}; const_mkldnn_primitive_t outputs[] = {diff_src.get(), diff_weights.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a batch normalization backward primitive"); reset(result); } // Prop_kind == backward (+ws) batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const primitive::at &workspace, const memory &diff_src, const memory &diff_weights) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, mean.data, variance.data, diff_dst.data, weights.data, workspace.data}; const_mkldnn_primitive_t outputs[] = {diff_src.get(), diff_weights.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a batch normalization backward primitive"); reset(result); } // Prop_kind == backward_data (+ws or +weights) /// @warning This constructor works for backward_data propagation /// - w/ weights but w/o workspace, or /// - w/ workspace but w/o weights batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights_or_workspace, const memory &diff_src) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, mean.data, variance.data, diff_dst.data, weights_or_workspace.data}; const_mkldnn_primitive_t outputs[] = {diff_src.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a batch normalization backward primitive"); reset(result); } // Prop_kind == backward_data batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const memory &diff_src) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data, diff_dst.data}; const_mkldnn_primitive_t outputs[] = {diff_src.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a batch normalization backward primitive"); reset(result); } }; /// @} /// @addtogroup cpp_api_inner_product Inner Product /// @{ struct inner_product_forward : public primitive { struct desc { mkldnn_inner_product_desc_t data; desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc) { error::wrap_c_api(mkldnn_inner_product_forward_desc_init( &data, mkldnn::convert_to_c(aprop_kind), &src_desc.data, &weights_desc.data, &bias_desc.data, &dst_desc.data), "could not create a inner product forward descriptor"); } desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc) { error::wrap_c_api(mkldnn_inner_product_forward_desc_init( &data, mkldnn::convert_to_c(aprop_kind), &src_desc.data, &weights_desc.data, nullptr, &dst_desc.data), "could not create a inner product forward descriptor"); } }; struct primitive_desc : public handle { primitive_desc(const desc &adesc, const engine &aengine) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create( &result, &adesc.data, aengine.get(), nullptr), "could not create a inner product forward primitive descriptor"); reset(result); } primitive_desc(const desc &adesc, const primitive_attr &aattr, const engine &aengine) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create_v2( &result, &adesc.data, aattr.get(), aengine.get(), nullptr), "could not create a inner product " "forward primitive descriptor"); reset(result); } memory::primitive_desc src_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(src_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a src primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc weights_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a weights primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc bias_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 1); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a bias primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const primitive::at &bias, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, weights.data, bias.data}; const_mkldnn_primitive_t outputs[] = {dst.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a inner product forward primitive"); reset(result); } inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const memory &dst) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, weights.data}; const_mkldnn_primitive_t outputs[] = {dst.get()}; error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a inner product forward primitive"); reset(result); } }; struct inner_product_backward_data : public primitive { struct desc { mkldnn_inner_product_desc_t data; desc(const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc) { error::wrap_c_api( mkldnn_inner_product_backward_data_desc_init(&data, &diff_src_desc.data, &weights_desc.data, &diff_dst_desc.data), "could not create a inner product backward data descriptor"); } }; struct primitive_desc : public handle { primitive_desc( const desc &adesc, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_primitive_desc) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create(&result, &adesc.data, aengine.get(), hint_fwd_primitive_desc.get()), "could not create a inner product backward data primitive " "descriptor"); reset(result); } memory::primitive_desc diff_dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff dst primititve descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc weights_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a weights primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_src_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_src_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff src primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; inner_product_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at weights, const memory &diff_src) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {diff_dst.data, weights.data}; const_mkldnn_primitive_t outputs[] = {diff_src.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a inner product backward data primitive"); reset(result); } }; struct inner_product_backward_weights : public primitive { struct desc { mkldnn_inner_product_desc_t data; desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc) { error::wrap_c_api( mkldnn_inner_product_backward_weights_desc_init( &data, &src_desc.data, &diff_weights_desc.data, &diff_bias_desc.data, &diff_dst_desc.data), "could not create a inner product backward weights descriptor"); } desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc) { error::wrap_c_api( mkldnn_inner_product_backward_weights_desc_init( &data, &src_desc.data, &diff_weights_desc.data, nullptr, &diff_dst_desc.data), "could not create a inner product backward weights descriptor"); } }; struct primitive_desc : public handle { primitive_desc( const desc &adesc, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_primitive_desc) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create(&result, &adesc.data, aengine.get(), hint_fwd_primitive_desc.get()), "could not create a inner product backward weights primitive " "descriptor"); reset(result); } memory::primitive_desc diff_dst_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_dst_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff dst primititve descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_weights_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_weights_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff weights primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_bias_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_weights_pd), 1); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a diff bias primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc src_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(src_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a src primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data}; const_mkldnn_primitive_t outputs[] = {diff_weights.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a inner product backward weights primitive"); reset(result); } inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights, const memory &diff_bias) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data}; const_mkldnn_primitive_t outputs[] = {diff_weights.get(), diff_bias.get()}; error::wrap_c_api( mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create a inner product backward weights primitive"); reset(result); } }; /// @} /// @addtogroup cpp_api_rnn RNN /// @{ struct rnn_cell { struct desc { mkldnn_rnn_cell_desc_t c_rnn_cell_; desc(algorithm kind, algorithm activation_f) { error::wrap_c_api( mkldnn_rnn_cell_desc_init(&c_rnn_cell_, mkldnn::convert_to_c(kind), mkldnn::convert_to_c(activation_f), 0U, 0, 0), "could not init an rnn cell descriptor"); } desc(algorithm kind) : desc(kind, algorithm::algorithm_undef) {} operator const mkldnn_rnn_cell_desc_t *() const { return &c_rnn_cell_; } algorithm get_cell_kind() const { return algorithm(c_rnn_cell_.cell_kind); } algorithm get_activation() const { return algorithm(c_rnn_cell_.activation_kind); } float get_alpha() const { return c_rnn_cell_.alpha; } void set_alpha(float alpha) { c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu; c_rnn_cell_.alpha = alpha; } float get_clipping() const { return c_rnn_cell_.clipping; } void set_clipping(float clipping) { c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping; c_rnn_cell_.clipping = clipping; } int get_gates_count() const { return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_); } int get_state_count() const { return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_); } }; }; struct rnn_forward : public primitive { struct desc { mkldnn_rnn_desc_t data; desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc) { error::wrap_c_api( mkldnn_rnn_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), cell, mkldnn::convert_to_c(direction), &src_layer_desc.data, &src_iter_desc.data, &weights_layer_desc.data, &weights_iter_desc.data, &bias_desc.data, &dst_layer_desc.data, &dst_iter_desc.data), "could not create an RNN forward descriptor"); } }; struct primitive_desc : public handle { primitive_desc(const desc &adesc, const engine &aengine) { mkldnn_primitive_desc_t result; error::wrap_c_api(mkldnn_primitive_desc_create( &result, &adesc.data, aengine.get(), nullptr), "could not create an RNN forward primitive descriptor"); reset(result); } memory::primitive_desc src_layer_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(src_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone an src layer primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc src_iter_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(src_pd), 1); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a src iter primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc weights_layer_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a weights primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc weights_src_iter_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 1); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a weights primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc bias_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 2); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a bias primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc workspace_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t ldesc; const_mkldnn_primitive_desc_t const_ldesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(workspace_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc), "could not clone a workspace primitive descriptor"); adesc.reset(ldesc); return adesc; } memory::primitive_desc dst_layer_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 0); error::wrap_c_api( mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst last layer primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc dst_iter_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 1); error::wrap_c_api( mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst last iteration primitive descriptor"); adesc.reset(cdesc); return adesc; } engine get_engine() { return engine::query(*this); } }; rnn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src_layer, const primitive::at &src_iter, const primitive::at &weights_layer, const primitive::at &weights_iter, const primitive::at &bias, const memory &dst_layer, const memory &dst_iter, const memory &workspace) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[5]; const_mkldnn_primitive_t outputs[3]; int idx = 0; inputs[idx++] = src_layer.data; if (!is_null_memory(src_iter.data.primitive)) inputs[idx++] = src_iter.data; inputs[idx++] = weights_layer.data; inputs[idx++] = weights_iter.data; if (!is_null_memory(bias.data.primitive)) inputs[idx++] = bias.data; idx = 0; outputs[idx++] = dst_layer.get(); if (!is_null_memory(dst_iter.get())) outputs[idx++] = dst_iter.get(); if (!is_null_memory(workspace.get())) outputs[idx++] = workspace.get(); error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create an RNN forward primitive"); reset(result); } }; struct rnn_backward : public primitive { struct desc { mkldnn_rnn_desc_t data; desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc) { error::wrap_c_api( mkldnn_rnn_backward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), cell, mkldnn::convert_to_c(direction), &src_layer_desc.data, &src_iter_desc.data, &weights_layer_desc.data, &weights_iter_desc.data, &bias_desc.data, &dst_layer_desc.data, &dst_iter_desc.data, &diff_src_layer_desc.data, &diff_src_iter_desc.data, &diff_weights_layer_desc.data, &diff_weights_iter_desc.data, &diff_bias_desc.data, &diff_dst_layer_desc.data, &diff_dst_iter_desc.data), "could not create an RNN backward descriptor"); } }; struct primitive_desc : public handle { primitive_desc(const desc &adesc, const engine &aengine) { mkldnn_primitive_desc_t result; error::wrap_c_api( mkldnn_primitive_desc_create( &result, &adesc.data, aengine.get(), nullptr), "could not create an RNN backward primitive descriptor"); reset(result); } memory::primitive_desc src_layer_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(src_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone an src layer primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc src_iter_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(src_pd), 1); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a src iter primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc weights_layer_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a weights primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc weights_iter_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 1); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a weights primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc bias_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(weights_pd), 2); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a bias primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc dst_layer_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 0); error::wrap_c_api( mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst last layer primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc dst_iter_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(dst_pd), 1); error::wrap_c_api( mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst last iteration primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_src_layer_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_src_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone an src_layer primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_src_iter_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_src_pd), 1); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a src iter primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_weights_layer_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_weights_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a weights primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_weights_iter_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_weights_pd), 1); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a weights primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_bias_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_weights_pd), 2); error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a bias primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_dst_layer_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_dst_pd), 0); error::wrap_c_api( mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst last layer primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc diff_dst_iter_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t cdesc; const_mkldnn_primitive_desc_t const_cdesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(diff_dst_pd), 1); error::wrap_c_api( mkldnn_primitive_desc_clone(&cdesc, const_cdesc), "could not clone a dst last iteration primitive descriptor"); adesc.reset(cdesc); return adesc; } memory::primitive_desc workspace_primitive_desc() const { memory::primitive_desc adesc; mkldnn_primitive_desc_t ldesc; const_mkldnn_primitive_desc_t const_ldesc = mkldnn_primitive_desc_query_pd( get(), mkldnn::convert_to_c(workspace_pd), 0); error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc), "could not clone a workspace primitive descriptor"); adesc.reset(ldesc); return adesc; } engine get_engine() { return engine::query(*this); } }; // With last iteration (with and without input src_iter) rnn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src_layer, const primitive::at &src_iter, const primitive::at &weights_layer, const primitive::at &weights_iter, const primitive::at &bias, const primitive::at &dst_layer, const primitive::at &dst_iter, const memory &diff_src_layer, const memory &diff_src_iter, const memory &diff_weights_layer, const memory &diff_weights_iter, const memory &diff_bias, const primitive::at &diff_dst_layer, const primitive::at &diff_dst_iter, const primitive::at &workspace) { mkldnn_primitive_t result; mkldnn_primitive_at_t inputs[10]; const_mkldnn_primitive_t outputs[5]; int idx = 0; inputs[idx] = src_layer.data; if (!is_null_memory(src_iter.data.primitive)) inputs[idx++] = src_iter.data; inputs[idx++] = weights_layer.data; inputs[idx++] = weights_iter.data; if (!is_null_memory(bias.data.primitive)) inputs[idx++] = bias.data; inputs[idx] = dst_layer.data; if (!is_null_memory(dst_iter.data.primitive)) inputs[idx++] = dst_iter.data; inputs[idx] = diff_dst_layer.data; if (!is_null_memory(diff_dst_iter.data.primitive)) inputs[idx++] = diff_dst_iter.data; inputs[idx] = workspace.data; idx = 0; outputs[idx] = diff_src_layer.get(); if (!is_null_memory(diff_src_iter.get())) outputs[idx++] = diff_src_iter.get(); outputs[idx] = diff_weights_layer.get(); outputs[idx] = diff_weights_iter.get(); if (!is_null_memory(diff_bias.get())) outputs[idx] = diff_bias.get(); error::wrap_c_api(mkldnn_primitive_create( &result, aprimitive_desc.get(), inputs, outputs), "could not create an RNN backward primitive"); reset(result); } }; /// @} /// @} Primitives /// @addtogroup cpp_api_stream Stream /// @{ #ifndef DOXYGEN_SHOULD_SKIP_THIS template <> struct handle_traits { static constexpr auto destructor = &mkldnn_stream_destroy; }; #endif struct stream : public handle { using handle::handle; enum kind { any = mkldnn_stream_kind_t::mkldnn_any_stream, eager = mkldnn_stream_kind_t::mkldnn_eager, lazy = mkldnn_stream_kind_t::mkldnn_lazy }; static mkldnn_stream_kind_t convert_to_c(kind akind) { return static_cast(akind); } /// Constructs a stream. stream(kind akind) { mkldnn_stream_t astream; error::wrap_c_api(mkldnn_stream_create(&astream, convert_to_c(akind)), "could not create a stream"); reset(astream); } /// Submits a vector of primitives to a stream for computations. /// /// @param primitives The vector of primitives to submit. /// @returns The stream. stream &submit(std::vector primitives) { // TODO: find a proper way to convert vector to // vector if (primitives.size() == 0) return *this; std::vector c_api_primitives; c_api_primitives.reserve(primitives.size()); auto convert_to_c = [](primitive p) { return p.get(); }; std::transform(primitives.begin(), primitives.end(), std::back_inserter(c_api_primitives), convert_to_c); mkldnn_primitive_t c_api_error_primitive; error::wrap_c_api(mkldnn_stream_submit(get(), c_api_primitives.size(), &c_api_primitives[0], &c_api_error_primitive), "could not submit primitives to a stream", &c_api_error_primitive); return *this; } /// Waits for all computations submitted to the stream to complete. /// /// @param block Specifies whether the operation should wait indefinitely or /// return /// immediately. /// @returns @c true if all computations completed. /// @returns @c false if not all computations completed. bool wait(bool block = true) { mkldnn_primitive_t c_api_error_primitive; mkldnn_status_t status = mkldnn_stream_wait(get(), block, &c_api_error_primitive); if (status != mkldnn_success && status != mkldnn_try_again) error::wrap_c_api( status, "could not wait on a stream", &c_api_error_primitive); return (status == mkldnn_success); } stream &rerun() { mkldnn_primitive_t c_api_error_primitive; error::wrap_c_api(mkldnn_stream_rerun(get(), &c_api_error_primitive), "could not rerun a stream", &c_api_error_primitive); return *this; } }; /// @} /// @} C++ API } // namespace mkldnn #endif