Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c72a4f4d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c72a4f4d
编写于
5月 14, 2018
作者:
Y
yuyang18
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into feature/exec_strategy
上级
08295f98
8231960f
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
4264 addition
and
7 deletion
+4264
-7
cmake/external/mkldnn.cmake
cmake/external/mkldnn.cmake
+2
-0
patches/mkldnn.hpp
patches/mkldnn.hpp
+4252
-0
python/paddle/fluid/tests/unittests/test_network_with_dtype.py
...n/paddle/fluid/tests/unittests/test_network_with_dtype.py
+10
-7
未找到文件。
cmake/external/mkldnn.cmake
浏览文件 @
c72a4f4d
...
...
@@ -56,6 +56,8 @@ ExternalProject_Add(
GIT_TAG
"v0.14"
PREFIX
${
MKLDNN_SOURCES_DIR
}
UPDATE_COMMAND
""
# Patch MKLDNN to compile with gcc 4.8, the related issue is in intel/mkl-dnn#237.
PATCH_COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
CMAKE_CURRENT_SOURCE_DIR
}
/patches/mkldnn.hpp
${
MKLDNN_SOURCES_DIR
}
/src/extern_mkldnn/include/mkldnn.hpp
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=
${
MKLDNN_INSTALL_DIR
}
CMAKE_ARGS -DCMAKE_BUILD_TYPE=
${
CMAKE_BUILD_TYPE
}
CMAKE_ARGS -DMKLROOT=
${
MKLML_ROOT
}
...
...
patches/mkldnn.hpp
0 → 100644
浏览文件 @
c72a4f4d
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef MKLDNN_HPP
#define MKLDNN_HPP
#ifndef DOXYGEN_SHOULD_SKIP_THIS
#include <stdlib.h>
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <vector>
#include "mkldnn.h"
#endif
namespace
mkldnn
{
/// @addtogroup cpp_api C++ API
/// @{
/// @addtogroup cpp_api_utils Utils
/// @{
/// A class that provides the destructor for an Intel(R) MKL-DNN C handle
template
<
typename
T
>
class
handle_traits
{};
/// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base
/// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and
/// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class
/// can be passed by value. This class enables wrapping:
/// - Newly constructed handles.
/// @n In this case, the constructed handle uses reference counting provided
/// by @p std::shared_ptr with a proper deleter function specified through
/// the @p handle_traits class.
/// - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for
/// example, through #mkldnn_primitive_get_output()).
/// @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a
/// deleter because it is assumed that the handle wrapper for the original
/// object deletes the handle (this model is similar to @p std::weak_ptr).
template
<
typename
T
,
typename
traits
=
handle_traits
<
T
>
>
class
handle
{
private:
std
::
shared_ptr
<
typename
std
::
remove_pointer
<
T
>::
type
>
_data
;
handle
(
const
handle
&&
)
=
delete
;
handle
&
operator
=
(
const
handle
&&
other
)
=
delete
;
protected:
/// Constructs a C handle wrapper.
/// @param t The C handle to wrap.
/// @param weak A flag to specify whether to construct a weak wrapper.
handle
(
T
t
=
0
,
bool
weak
=
false
)
:
_data
(
0
)
{
reset
(
t
,
weak
);
}
bool
operator
==
(
const
T
other
)
const
{
return
other
==
_data
.
get
();
}
bool
operator
!=
(
const
T
other
)
const
{
return
!
(
*
this
==
other
);
}
public:
handle
(
const
handle
&
other
)
:
_data
(
other
.
_data
)
{}
handle
&
operator
=
(
const
handle
&
other
)
{
_data
=
other
.
_data
;
return
*
this
;
}
/// Resets the value of a C handle.
/// @param t The new value of the C handle.
/// @param weak A flag to specify whether the wrapper should be weak.
void
reset
(
T
t
,
bool
weak
=
false
)
{
auto
dummy_destructor
=
[](
T
)
{
return
decltype
(
traits
::
destructor
(
0
))(
0
);
};
_data
.
reset
(
t
,
weak
?
dummy_destructor
:
traits
::
destructor
);
}
/// Returns the value of the underlying C handle.
T
get
()
const
{
return
_data
.
get
();
}
bool
operator
==
(
const
handle
&
other
)
const
{
return
other
.
_data
.
get
()
==
_data
.
get
();
}
bool
operator
!=
(
const
handle
&
other
)
const
{
return
!
(
*
this
==
other
);
}
};
#ifndef DOXYGEN_SHOULD_SKIP_THIS
template
<
>
struct
handle_traits
<
mkldnn_primitive_desc_t
>
{
static
constexpr
auto
destructor
=
&
mkldnn_primitive_desc_destroy
;
};
template
<
>
struct
handle_traits
<
mkldnn_primitive_t
>
{
static
constexpr
auto
destructor
=
&
mkldnn_primitive_destroy
;
};
#endif
/// Base class for all computational primitives.
class
primitive
:
public
handle
<
mkldnn_primitive_t
>
{
friend
struct
error
;
friend
struct
stream
;
friend
class
primitive_at
;
using
handle
::
handle
;
public:
/// A proxy to C primitive kind enum
enum
class
kind
{
undefined_primitive
=
mkldnn_undefined_primitive
,
memory
=
mkldnn_memory
,
view
=
mkldnn_view
,
reorder
=
mkldnn_reorder
,
concat
=
mkldnn_concat
,
concat_inplace
=
mkldnn_concat_inplace
,
sum
=
mkldnn_sum
,
convolution
=
mkldnn_convolution
,
deconvolution
=
mkldnn_deconvolution
,
eltwise
=
mkldnn_eltwise
,
relu
=
mkldnn_relu
,
softmax
=
mkldnn_softmax
,
pooling
=
mkldnn_pooling
,
lrn
=
mkldnn_lrn
,
batch_normalization
=
mkldnn_batch_normalization
,
inner_product
=
mkldnn_inner_product
,
convolution_relu
=
mkldnn_convolution_relu
,
rnn
=
mkldnn_rnn
,
};
/// A wrapper structure to specify a particular output of a primitive.
struct
at
{
/// The underlying C API structure.
mkldnn_primitive_at_t
data
;
/// Constructs a wrapper specifying @p aprimitive output with index @p
/// at.
///
/// @param aprimitive The target primitive.
/// @param at The output index.
at
(
const
primitive
&
aprimitive
,
size_t
at
=
0
)
:
data
(
mkldnn_primitive_at
(
aprimitive
.
get
(),
at
))
{}
/// Returns the specified output.
inline
operator
primitive
()
const
;
};
/// Returns the descriptor of the underlying C API primitive
inline
const_mkldnn_primitive_desc_t
get_primitive_desc
()
const
;
// TODO: use the C++ API wrapper structure.
};
inline
mkldnn_primitive_kind_t
convert_to_c
(
primitive
::
kind
akind
)
{
return
static_cast
<
mkldnn_primitive_kind_t
>
(
akind
);
}
/// Intel(R) MKL-DNN exception class.
///
/// This class captures the status returned by the failed C API function, error
/// message, and, optionally, handle of the primitive that caused the error.
struct
error
:
public
std
::
exception
{
mkldnn_status_t
status
;
std
::
string
message
;
primitive
error_primitive
;
/// Constructs an error instance.
///
/// @param astatus The error status returned by the C API.
/// @param amessage The error message.
/// @param aerror_primitive (optional) A C handle of the primitive that
/// caused the error.
error
(
mkldnn_status_t
astatus
,
std
::
string
amessage
,
mkldnn_primitive_t
aerror_primitive
=
0
)
:
status
(
astatus
),
message
(
amessage
),
error_primitive
(
aerror_primitive
,
true
)
{}
/// A convenience function for wrapping calls to the C API. Checks the
/// return status and throws an #error in case of failure.
///
/// @param status The error status returned by the C API.
/// @param message The error message.
/// @param error_primitive (optional) A C handle of the primitive that
/// caused the error.
static
void
wrap_c_api
(
mkldnn_status_t
status
,
std
::
string
message
,
mkldnn_primitive_t
*
error_primitive
=
0
)
{
if
(
status
!=
mkldnn_success
)
{
if
(
nullptr
!=
error_primitive
)
throw
error
(
status
,
message
,
*
error_primitive
);
else
throw
error
(
status
,
message
,
nullptr
);
}
}
};
inline
primitive
::
at
::
operator
primitive
()
const
{
const_mkldnn_primitive_t
output
;
error
::
wrap_c_api
(
mkldnn_primitive_get_output
(
data
.
primitive
,
data
.
output_index
,
&
output
),
"could not get an output primitive"
);
return
primitive
(
const_cast
<
mkldnn_primitive_t
>
(
output
),
true
);
}
const_mkldnn_primitive_desc_t
primitive
::
get_primitive_desc
()
const
{
const_mkldnn_primitive_desc_t
pd
;
error
::
wrap_c_api
(
mkldnn_primitive_get_primitive_desc
(
get
(),
&
pd
),
"could not get primitive descriptor by primitive"
);
return
pd
;
}
/// @}
/// @addtogroup cpp_api_enums Common data types and enumerations
/// @{
enum
round_mode
{
round_nearest
=
mkldnn_round_nearest
,
round_down
=
mkldnn_round_down
,
};
inline
mkldnn_round_mode_t
convert_to_c
(
round_mode
mode
)
{
return
static_cast
<
mkldnn_round_mode_t
>
(
mode
);
}
enum
padding_kind
{
zero
=
mkldnn_padding_zero
};
inline
mkldnn_padding_kind_t
convert_to_c
(
padding_kind
kind
)
{
return
static_cast
<
mkldnn_padding_kind_t
>
(
kind
);
}
enum
prop_kind
{
forward_training
=
mkldnn_forward_training
,
forward_scoring
=
mkldnn_forward_scoring
,
forward_inference
=
mkldnn_forward_inference
,
forward
=
mkldnn_forward
,
backward
=
mkldnn_backward
,
backward_data
=
mkldnn_backward_data
,
backward_weights
=
mkldnn_backward_weights
,
backward_bias
=
mkldnn_backward_bias
};
inline
mkldnn_prop_kind_t
convert_to_c
(
prop_kind
kind
)
{
return
static_cast
<
mkldnn_prop_kind_t
>
(
kind
);
}
enum
algorithm
{
algorithm_undef
=
mkldnn_alg_kind_undef
,
convolution_direct
=
mkldnn_convolution_direct
,
convolution_winograd
=
mkldnn_convolution_winograd
,
deconvolution_direct
=
mkldnn_deconvolution_direct
,
deconvolution_winograd
=
mkldnn_deconvolution_winograd
,
eltwise_relu
=
mkldnn_eltwise_relu
,
eltwise_tanh
=
mkldnn_eltwise_tanh
,
eltwise_elu
=
mkldnn_eltwise_elu
,
eltwise_square
=
mkldnn_eltwise_square
,
eltwise_abs
=
mkldnn_eltwise_abs
,
eltwise_sqrt
=
mkldnn_eltwise_sqrt
,
eltwise_linear
=
mkldnn_eltwise_linear
,
eltwise_bounded_relu
=
mkldnn_eltwise_bounded_relu
,
eltwise_soft_relu
=
mkldnn_eltwise_soft_relu
,
eltwise_logistic
=
mkldnn_eltwise_logistic
,
lrn_across_channels
=
mkldnn_lrn_across_channels
,
lrn_within_channel
=
mkldnn_lrn_within_channel
,
pooling_max
=
mkldnn_pooling_max
,
pooling_avg
=
mkldnn_pooling_avg
,
pooling_avg_include_padding
=
mkldnn_pooling_avg_include_padding
,
pooling_avg_exclude_padding
=
mkldnn_pooling_avg_exclude_padding
,
vanilla_rnn
=
mkldnn_vanilla_rnn
,
vanilla_lstm
=
mkldnn_vanilla_lstm
,
vanilla_gru
=
mkldnn_vanilla_gru
,
};
inline
mkldnn_alg_kind_t
convert_to_c
(
algorithm
aalgorithm
)
{
return
static_cast
<
mkldnn_alg_kind_t
>
(
aalgorithm
);
}
enum
batch_normalization_flag
{
use_global_stats
=
mkldnn_use_global_stats
,
use_scale_shift
=
mkldnn_use_scaleshift
,
omit_stats
=
mkldnn_omit_stats
,
fuse_bn_relu
=
mkldnn_fuse_bn_relu
};
inline
mkldnn_batch_normalization_flag_t
convert_to_c
(
batch_normalization_flag
aflag
)
{
return
static_cast
<
mkldnn_batch_normalization_flag_t
>
(
aflag
);
}
enum
rnn_direction
{
unidirectional_left2right
=
mkldnn_unidirectional_left2right
,
unidirectional_right2left
=
mkldnn_unidirectional_right2left
,
unidirectional
=
mkldnn_unidirectional
,
bidirectional_concat
=
mkldnn_bidirectional_concat
,
bidirectional_sum
=
mkldnn_bidirectional_sum
,
};
inline
mkldnn_rnn_direction_t
convert_to_c
(
rnn_direction
adir
)
{
return
static_cast
<
mkldnn_rnn_direction_t
>
(
adir
);
}
enum
query
{
undef
=
mkldnn_query_undef
,
eengine
=
mkldnn_query_engine
,
primitive_kind
=
mkldnn_query_primitive_kind
,
num_of_inputs_s32
=
mkldnn_query_num_of_inputs_s32
,
num_of_outputs_s32
=
mkldnn_query_num_of_outputs_s32
,
time_estimate_f64
=
mkldnn_query_time_estimate_f64
,
memory_consumption_s64
=
mkldnn_query_memory_consumption_s64
,
impl_info_str
=
mkldnn_query_impl_info_str
,
memory_d
=
mkldnn_query_memory_d
,
convolution_d
=
mkldnn_query_convolution_d
,
deconvolution_d
=
mkldnn_query_deconvolution_d
,
eltwise_d
=
mkldnn_query_eltwise_d
,
relu_d
=
mkldnn_query_relu_d
,
softmax_d
=
mkldnn_query_softmax_d
,
pooling_d
=
mkldnn_query_pooling_d
,
lrn_d
=
mkldnn_query_lrn_d
,
batch_normalization_d
=
mkldnn_query_batch_normalization_d
,
inner_product_d
=
mkldnn_query_inner_product_d
,
convolution_relu_d
=
mkldnn_query_convolution_relu_d
,
rnn_d
=
mkldnn_query_rnn_d
,
input_pd
=
mkldnn_query_input_pd
,
output_pd
=
mkldnn_query_output_pd
,
src_pd
=
mkldnn_query_src_pd
,
diff_src_pd
=
mkldnn_query_diff_src_pd
,
weights_pd
=
mkldnn_query_weights_pd
,
diff_weights_pd
=
mkldnn_query_diff_weights_pd
,
dst_pd
=
mkldnn_query_dst_pd
,
diff_dst_pd
=
mkldnn_query_diff_dst_pd
,
workspace_pd
=
mkldnn_query_workspace_pd
,
};
inline
mkldnn_query_t
convert_to_c
(
query
aquery
)
{
return
static_cast
<
mkldnn_query_t
>
(
aquery
);
}
/// @}
/// @addtogroup cpp_api_attr Attributes
/// @{
#ifndef DOXYGEN_SHOULD_SKIP_THIS
template
<
>
struct
handle_traits
<
mkldnn_post_ops_t
>
{
static
constexpr
auto
destructor
=
&
mkldnn_post_ops_destroy
;
};
#endif
struct
post_ops
:
public
handle
<
mkldnn_post_ops_t
>
{
post_ops
()
{
mkldnn_post_ops_t
result
;
error
::
wrap_c_api
(
mkldnn_post_ops_create
(
&
result
),
"could not create post operation sequence"
);
reset
(
result
);
}
int
len
()
const
{
return
mkldnn_post_ops_len
(
get
());
}
primitive
::
kind
kind
(
int
index
)
const
{
error
::
wrap_c_api
(
index
<
len
()
?
mkldnn_success
:
mkldnn_invalid_arguments
,
"post_ops index is out of range"
);
return
static_cast
<
primitive
::
kind
>
(
mkldnn_post_ops_get_kind
(
get
(),
index
));
}
void
append_sum
(
float
scale
=
1.
)
{
error
::
wrap_c_api
(
mkldnn_post_ops_append_sum
(
get
(),
scale
),
"could not append sum"
);
}
void
get_params_sum
(
int
index
,
float
&
scale
)
const
{
error
::
wrap_c_api
(
mkldnn_post_ops_get_params_sum
(
get
(),
index
,
&
scale
),
"could not get sum params"
);
}
void
append_eltwise
(
float
scale
,
algorithm
alg
,
float
alpha
,
float
beta
)
{
error
::
wrap_c_api
(
mkldnn_post_ops_append_eltwise
(
get
(),
scale
,
convert_to_c
(
alg
),
alpha
,
beta
),
"could not append eltwise"
);
}
void
get_params_eltwise
(
int
index
,
float
&
scale
,
algorithm
&
alg
,
float
&
alpha
,
float
&
beta
)
const
{
mkldnn_alg_kind_t
c_alg
;
error
::
wrap_c_api
(
mkldnn_post_ops_get_params_eltwise
(
get
(),
index
,
&
scale
,
&
c_alg
,
&
alpha
,
&
beta
),
"could not get eltwise params"
);
alg
=
static_cast
<
algorithm
>
(
c_alg
);
}
};
#ifndef DOXYGEN_SHOULD_SKIP_THIS
template
<
>
struct
handle_traits
<
mkldnn_primitive_attr_t
>
{
static
constexpr
auto
destructor
=
&
mkldnn_primitive_attr_destroy
;
};
#endif
struct
primitive_attr
:
public
handle
<
mkldnn_primitive_attr_t
>
{
primitive_attr
()
{
mkldnn_primitive_attr_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_attr_create
(
&
result
),
"could not create a primitive attr"
);
reset
(
result
);
}
round_mode
get_int_output_round_mode
()
const
{
mkldnn_round_mode_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_attr_get_int_output_round_mode
(
get
(),
&
result
),
"could not get int output round mode"
);
return
round_mode
(
result
);
}
void
set_int_output_round_mode
(
round_mode
mode
)
{
error
::
wrap_c_api
(
mkldnn_primitive_attr_set_int_output_round_mode
(
get
(),
mkldnn
::
convert_to_c
(
mode
)),
"could not set int output round mode"
);
}
void
get_output_scales
(
int
&
mask
,
std
::
vector
<
float
>
&
scales
)
const
{
int
count
,
c_mask
;
const
float
*
c_scales
;
error
::
wrap_c_api
(
mkldnn_primitive_attr_get_output_scales
(
get
(),
&
count
,
&
c_mask
,
&
c_scales
),
"could not get int output scales"
);
scales
.
resize
(
count
);
mask
=
c_mask
;
for
(
int
c
=
0
;
c
<
count
;
++
c
)
scales
[
c
]
=
c_scales
[
c
];
}
void
set_output_scales
(
int
mask
,
const
std
::
vector
<
float
>
&
scales
)
{
error
::
wrap_c_api
(
mkldnn_primitive_attr_set_output_scales
(
get
(),
(
int
)
scales
.
size
(),
mask
,
&
scales
[
0
]),
"could not set int output scales"
);
}
const
post_ops
get_post_ops
()
const
{
post_ops
result
;
const_mkldnn_post_ops_t
c_result
;
error
::
wrap_c_api
(
mkldnn_primitive_attr_get_post_ops
(
get
(),
&
c_result
),
"could not get post operation sequence"
);
result
.
reset
(
const_cast
<
mkldnn_post_ops_t
>
(
c_result
),
true
);
return
result
;
}
void
set_post_ops
(
post_ops
ops
)
{
error
::
wrap_c_api
(
mkldnn_primitive_attr_set_post_ops
(
get
(),
ops
.
get
()),
"could not set post operation sequence"
);
}
};
/// @}
/// @addtogroup cpp_api_engine Engine
/// @{
#ifndef DOXYGEN_SHOULD_SKIP_THIS
template
<
>
struct
handle_traits
<
mkldnn_engine_t
>
{
static
constexpr
auto
destructor
=
&
mkldnn_engine_destroy
;
};
#endif
/// An execution engine.
struct
engine
:
public
handle
<
mkldnn_engine_t
>
{
friend
class
primitive
;
// gcc bug??? using handle::handle;
/// Kinds of engines
enum
kind
{
/// An unspecified engine
any
=
mkldnn_any_engine
,
/// CPU engine
cpu
=
mkldnn_cpu
,
};
/// Returns the number of engines of a certain kind.
///
/// @param akind The kind of engines to count.
static
size_t
get_count
(
kind
akind
)
{
return
mkldnn_engine_get_count
(
convert_to_c
(
akind
));
}
/// Constructs an engine.
///
/// @param akind The kind of engine to construct.
/// @param index The index of the engine. Must be less than the value
/// returned by #get_count() for this particular kind of engine.
engine
(
kind
akind
,
size_t
index
)
{
mkldnn_engine_t
aengine
;
error
::
wrap_c_api
(
mkldnn_engine_create
(
&
aengine
,
convert_to_c
(
akind
),
index
),
"could not create an engine"
);
reset
(
aengine
);
}
explicit
engine
(
const
mkldnn_engine_t
&
aengine
)
:
handle
(
aengine
,
true
)
{}
engine
(
const
handle
<
mkldnn_primitive_desc_t
>
&
pd
)
{
mkldnn_engine_t
engine_q
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_query
(
pd
.
get
(),
mkldnn
::
convert_to_c
(
eengine
),
0
,
&
engine_q
),
"could not get engine from primitive_desc"
);
reset
(
engine_q
,
true
);
}
template
<
class
primitive_desc
>
static
engine
query
(
const
primitive_desc
&
pd
)
{
mkldnn_engine_t
engine_q
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_query
(
pd
.
get
(),
mkldnn
::
convert_to_c
(
eengine
),
0
,
&
engine_q
),
"could not get engine from primitive_desc"
);
return
engine
(
engine_q
);
}
private:
static
mkldnn_engine_kind_t
convert_to_c
(
kind
akind
)
{
return
static_cast
<
mkldnn_engine_kind_t
>
(
akind
);
}
};
/// @}
/// @addtogroup cpp_api_primitives Primitives
/// @{
/// @addtogroup cpp_api_memory Memory
/// @{
/// Memory primitive that describes the data.
struct
memory
:
public
primitive
{
private:
std
::
shared_ptr
<
char
>
_handle
;
public:
typedef
std
::
vector
<
std
::
remove_extent
<
mkldnn_dims_t
>::
type
>
dims
;
template
<
typename
T
>
static
void
validate_dims
(
std
::
vector
<
T
>
v
)
{
if
(
v
.
size
()
>
TENSOR_MAX_DIMS
)
throw
error
(
mkldnn_invalid_arguments
,
"invalid dimensions"
);
}
/// Data type specification. See #mkldnn_data_type_t for a detailed
/// description.
enum
data_type
{
data_undef
=
mkldnn_data_type_undef
,
f32
=
mkldnn_f32
,
s32
=
mkldnn_s32
,
s16
=
mkldnn_s16
,
s8
=
mkldnn_s8
,
u8
=
mkldnn_u8
,
};
/// Memory format specification. See #mkldnn_memory_format_t
/// for a detailed description.
enum
format
{
format_undef
=
mkldnn_format_undef
,
any
=
mkldnn_any
,
blocked
=
mkldnn_blocked
,
x
=
mkldnn_x
,
nc
=
mkldnn_nc
,
nchw
=
mkldnn_nchw
,
nhwc
=
mkldnn_nhwc
,
chwn
=
mkldnn_chwn
,
nChw8c
=
mkldnn_nChw8c
,
nChw16c
=
mkldnn_nChw16c
,
ncdhw
=
mkldnn_ncdhw
,
ndhwc
=
mkldnn_ndhwc
,
nCdhw16c
=
mkldnn_nCdhw16c
,
oi
=
mkldnn_oi
,
io
=
mkldnn_io
,
oihw
=
mkldnn_oihw
,
ihwo
=
mkldnn_ihwo
,
hwio
=
mkldnn_hwio
,
oidhw
=
mkldnn_oidhw
,
OIdhw16i16o
=
mkldnn_OIdhw16i16o
,
OIdhw16o16i
=
mkldnn_OIdhw16o16i
,
Oidhw16o
=
mkldnn_Oidhw16o
,
Odhwi16o
=
mkldnn_Odhwi16o
,
oIhw8i
=
mkldnn_oIhw8i
,
oIhw16i
=
mkldnn_oIhw16i
,
OIhw8i8o
=
mkldnn_OIhw8i8o
,
OIhw16i16o
=
mkldnn_OIhw16i16o
,
OIhw8o8i
=
mkldnn_OIhw8o8i
,
OIhw16o16i
=
mkldnn_OIhw16o16i
,
IOhw16o16i
=
mkldnn_IOhw16o16i
,
OIhw8i16o2i
=
mkldnn_OIhw8i16o2i
,
OIhw8o16i2o
=
mkldnn_OIhw8o16i2o
,
OIhw4i16o4i
=
mkldnn_OIhw4i16o4i
,
Oihw8o
=
mkldnn_Oihw8o
,
Oihw16o
=
mkldnn_Oihw16o
,
Ohwi8o
=
mkldnn_Ohwi8o
,
Ohwi16o
=
mkldnn_Ohwi16o
,
OhIw16o4i
=
mkldnn_OhIw16o4i
,
goihw
=
mkldnn_goihw
,
hwigo
=
mkldnn_hwigo
,
gOIhw8i8o
=
mkldnn_gOIhw8i8o
,
gOIhw16i16o
=
mkldnn_gOIhw16i16o
,
gOIhw8i16o2i
=
mkldnn_gOIhw8i16o2i
,
gOIhw8o16i2o
=
mkldnn_gOIhw8o16i2o
,
gOIhw4i16o4i
=
mkldnn_gOIhw4i16o4i
,
gOihw8o
=
mkldnn_gOihw8o
,
gOihw16o
=
mkldnn_gOihw16o
,
gOhwi8o
=
mkldnn_gOhwi8o
,
gOhwi16o
=
mkldnn_gOhwi16o
,
Goihw8g
=
mkldnn_Goihw8g
,
Goihw16g
=
mkldnn_Goihw16g
,
gOIhw8o8i
=
mkldnn_gOIhw8o8i
,
gOIhw16o16i
=
mkldnn_gOIhw16o16i
,
gIOhw16o16i
=
mkldnn_gIOhw16o16i
,
gOhIw16o4i
=
mkldnn_gOhIw16o4i
,
goidhw
=
mkldnn_goidhw
,
gOIdhw16i16o
=
mkldnn_gOIdhw16i16o
,
gOIdhw16o16i
=
mkldnn_gOIdhw16o16i
,
gOidhw16o
=
mkldnn_gOidhw16o
,
gOdhwi16o
=
mkldnn_gOdhwi16o
,
ntc
=
mkldnn_ntc
,
tnc
=
mkldnn_tnc
,
ldsnc
=
mkldnn_ldsnc
,
ldigo
=
mkldnn_ldigo
,
ldigo_p
=
mkldnn_ldigo_p
,
ldgoi
=
mkldnn_ldgoi
,
ldgoi_p
=
mkldnn_ldgoi_p
,
ldgo
=
mkldnn_ldgo
,
wino_fmt
=
mkldnn_wino_fmt
,
format_last
=
mkldnn_format_last
,
};
/// A memory descriptor.
struct
desc
{
friend
struct
memory
;
/// The underlying C API data structure.
mkldnn_memory_desc_t
data
;
/// Constructs a memory descriptor.
///
/// @param adims Data dimensions
/// @param adata_type Data precision/type.
/// @param aformat Data layout format.
desc
(
dims
adims
,
data_type
adata_type
,
format
aformat
)
{
validate_dims
(
adims
);
error
::
wrap_c_api
(
mkldnn_memory_desc_init
(
&
data
,
(
int
)
adims
.
size
(),
adims
.
size
()
==
0
?
nullptr
:
&
adims
[
0
],
convert_to_c
(
adata_type
),
convert_to_c
(
aformat
)),
"could not initialize a memory descriptor"
);
}
/// Constructs a memory descriptor from a C API data structure.
///
/// @param adata A C API #mkldnn_memory_desc_t structure.
desc
(
const
mkldnn_memory_desc_t
&
adata
)
:
data
(
adata
)
{}
};
/// A memory primitive descriptor.
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
friend
struct
memory
;
// TODO: make private
primitive_desc
()
{}
/// Constructs a memory primitive descriptor.
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_memory_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
()),
"could not initialize a memory primitive descriptor"
);
reset
(
result
);
}
/// Returns the memory primitive descriptor.
memory
::
desc
desc
()
{
auto
memory_d
=
mkldnn_primitive_desc_query_memory_d
(
get
());
return
memory
::
desc
(
*
memory_d
);
}
/// Returns the number of bytes required to allocate the memory described
/// including the padding area.
size_t
get_size
()
const
{
return
mkldnn_memory_primitive_desc_get_size
(
get
());
}
bool
operator
==
(
const
primitive_desc
&
other
)
const
{
return
mkldnn_memory_primitive_desc_equal
(
get
(),
other
.
get
());
}
bool
operator
!=
(
const
primitive_desc
&
other
)
const
{
return
!
operator
==
(
other
);
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
/// Constructs a memory primitive from a generic primitive.
///
/// @param aprimitive The primitive to treat as memory.
memory
(
const
primitive
&
aprimitive
)
:
primitive
(
aprimitive
)
{}
/// Constructs a memory primitive.
///
/// @param adesc Memory primitive descriptor.
memory
(
const
primitive_desc
&
adesc
)
{
mkldnn_primitive_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
adesc
.
get
(),
nullptr
,
nullptr
),
"could not create a memory primitive"
);
reset
(
result
);
auto
_malloc
=
[](
size_t
size
,
int
alignment
)
{
void
*
ptr
;
#ifdef _WIN32
ptr
=
_aligned_malloc
(
size
,
alignment
);
int
rc
=
((
ptr
)
?
0
:
errno
);
#else
int
rc
=
::
posix_memalign
(
&
ptr
,
alignment
,
size
);
#endif
/* _WIN32 */
return
(
rc
==
0
)
?
(
char
*
)
ptr
:
nullptr
;
};
auto
_free
=
[](
char
*
p
)
{
#ifdef _WIN32
_aligned_free
((
void
*
)
p
);
#else
::
free
((
void
*
)
p
);
#endif
/* _WIN32 */
};
_handle
.
reset
(
_malloc
(
adesc
.
get_size
(),
4096
),
_free
);
set_data_handle
(
_handle
.
get
());
}
memory
(
const
primitive_desc
&
adesc
,
void
*
ahandle
)
{
mkldnn_primitive_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
adesc
.
get
(),
nullptr
,
nullptr
),
"could not create a memory primitive"
);
reset
(
result
);
set_data_handle
(
ahandle
);
}
/// Returns the descriptor of the memory primitive.
primitive_desc
get_primitive_desc
()
const
{
primitive_desc
adesc
;
const_mkldnn_primitive_desc_t
cdesc
;
error
::
wrap_c_api
(
mkldnn_primitive_get_primitive_desc
(
get
(),
&
cdesc
),
"could not get primitive descriptor from a memory primitive"
);
/* FIXME: no const_cast should be here */
adesc
.
reset
(
const_cast
<
mkldnn_primitive_desc_t
>
(
cdesc
),
true
);
return
adesc
;
}
/// Returns a handle of the data contained in the memory primitive. On
/// the CPU engine, this is a pointer to the allocated memory.
inline
void
*
get_data_handle
()
const
{
void
*
handle
;
error
::
wrap_c_api
(
mkldnn_memory_get_data_handle
(
get
(),
&
handle
),
"could not get native handle"
);
return
handle
;
}
inline
void
set_data_handle
(
void
*
handle
)
const
{
error
::
wrap_c_api
(
mkldnn_memory_set_data_handle
(
get
(),
handle
),
"could not set native handle"
);
}
// Must go away or be private:
static
mkldnn_data_type_t
convert_to_c
(
data_type
adata_type
)
{
return
static_cast
<
mkldnn_data_type_t
>
(
adata_type
);
}
static
mkldnn_memory_format_t
convert_to_c
(
format
aformat
)
{
return
static_cast
<
mkldnn_memory_format_t
>
(
aformat
);
}
};
inline
memory
::
desc
zero_md
()
{
mkldnn_memory_desc_t
zero
;
zero
.
primitive_kind
=
mkldnn_memory
;
return
memory
::
desc
(
zero
);
}
inline
memory
null_memory
(
engine
eng
)
{
mkldnn
::
memory
::
desc
zero
=
zero_md
();
return
memory
({
zero
,
eng
},
nullptr
);
}
inline
bool
is_null_memory
(
const
const_mkldnn_primitive_t
&
aprimitive
)
{
const_mkldnn_primitive_desc_t
aprimitive_pd
;
mkldnn_primitive_get_primitive_desc
(
aprimitive
,
&
aprimitive_pd
);
const
mkldnn_memory_desc_t
*
aprimitive_md
=
mkldnn_primitive_desc_query_memory_d
(
aprimitive_pd
);
return
((
aprimitive_md
!=
nullptr
)
&&
(
aprimitive_md
->
ndims
==
0
));
}
inline
bool
operator
==
(
mkldnn_data_type_t
a
,
memory
::
data_type
b
)
{
return
a
==
memory
::
convert_to_c
(
b
);
}
inline
bool
operator
!=
(
mkldnn_data_type_t
a
,
memory
::
data_type
b
)
{
return
!
(
a
==
b
);
}
inline
bool
operator
==
(
memory
::
data_type
a
,
mkldnn_data_type_t
b
)
{
return
b
==
a
;
}
inline
bool
operator
!=
(
memory
::
data_type
a
,
mkldnn_data_type_t
b
)
{
return
!
(
a
==
b
);
}
inline
bool
operator
==
(
mkldnn_memory_format_t
a
,
memory
::
format
b
)
{
return
a
==
memory
::
convert_to_c
(
b
);
}
inline
bool
operator
!=
(
mkldnn_memory_format_t
a
,
memory
::
format
b
)
{
return
!
(
a
==
b
);
}
inline
bool
operator
==
(
memory
::
format
a
,
mkldnn_memory_format_t
b
)
{
return
b
==
a
;
}
inline
bool
operator
!=
(
memory
::
format
a
,
mkldnn_memory_format_t
b
)
{
return
!
(
a
==
b
);
}
/// @}
/// @addtogroup cpp_api_reorder Reorder
/// @{
struct
reorder
:
public
primitive
{
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
memory
::
primitive_desc
&
input
,
const
memory
::
primitive_desc
&
output
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_reorder_primitive_desc_create
(
&
result
,
input
.
get
(),
output
.
get
()),
"could not create a reorder primitive descriptor"
);
reset
(
result
);
}
primitive_desc
(
const
memory
::
primitive_desc
&
input
,
const
memory
::
primitive_desc
&
output
,
const
primitive_attr
&
aattr
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_reorder_primitive_desc_create_v2
(
&
result
,
input
.
get
(),
output
.
get
(),
aattr
.
get
()),
"could not create a reorder primitive descriptor"
);
reset
(
result
);
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
reorder
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
input
,
const
memory
&
output
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
input
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
output
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a reorder primitive"
);
reset
(
result
);
}
reorder
(
const
primitive
::
at
&
input
,
const
memory
&
output
)
{
auto
input_mpd
=
memory
(
input
).
get_primitive_desc
();
auto
output_mpd
=
output
.
get_primitive_desc
();
auto
reorder_d
=
primitive_desc
(
input_mpd
,
output_mpd
);
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
input
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
output
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
reorder_d
.
get
(),
inputs
,
outputs
),
"could not create a reorder primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_view View
/// @{
struct
view
:
public
primitive
{
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
memory
::
primitive_desc
&
input
,
memory
::
dims
dims
,
memory
::
dims
offsets
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_view_primitive_desc_create
(
&
result
,
input
.
get
(),
&
dims
[
0
],
&
offsets
[
0
]),
"could not create a view primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
view
(
const
primitive_desc
&
view_pd
,
primitive
::
at
input
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
input
.
data
};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
view_pd
.
get
(),
inputs
,
nullptr
),
"could not create a view primitive"
);
reset
(
result
);
}
view
(
memory
input
,
memory
::
dims
dims
,
memory
::
dims
offsets
)
{
mkldnn_primitive_t
result
;
primitive_desc
view_pd
(
input
.
get_primitive_desc
(),
dims
,
offsets
);
mkldnn_primitive_at_t
inputs
[]
=
{
primitive
::
at
(
input
).
data
};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
view_pd
.
get
(),
inputs
,
nullptr
),
"could not create a view primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_concat Concat
/// @{
struct
concat
:
public
primitive
{
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
std
::
vector
<
const_mkldnn_primitive_desc_t
>
cpp_to_c
(
std
::
vector
<
memory
::
primitive_desc
>
inputs
)
{
std
::
vector
<
const_mkldnn_primitive_desc_t
>
c_api_inputs
;
c_api_inputs
.
reserve
(
inputs
.
size
());
auto
convert_to_c
=
[](
memory
::
primitive_desc
d
)
{
return
d
.
get
();
};
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
c_api_inputs
),
convert_to_c
);
return
c_api_inputs
;
}
primitive_desc
(
const
memory
::
desc
&
output
,
int
concat_dimension
,
std
::
vector
<
memory
::
primitive_desc
>
inputs
)
{
mkldnn_primitive_desc_t
result
;
auto
c_api_inputs
=
cpp_to_c
(
inputs
);
error
::
wrap_c_api
(
mkldnn_concat_primitive_desc_create
(
&
result
,
&
output
.
data
,
(
int
)
c_api_inputs
.
size
(),
concat_dimension
,
&
c_api_inputs
[
0
]),
"could not create a concat primitive descriptor"
);
reset
(
result
);
}
primitive_desc
(
int
concat_dimension
,
std
::
vector
<
memory
::
primitive_desc
>
inputs
)
{
mkldnn_primitive_desc_t
result
;
auto
c_api_inputs
=
cpp_to_c
(
inputs
);
error
::
wrap_c_api
(
mkldnn_concat_primitive_desc_create
(
&
result
,
nullptr
,
(
int
)
c_api_inputs
.
size
(),
concat_dimension
,
&
c_api_inputs
[
0
]),
"could not create a concat primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
concat
(
const
primitive_desc
&
concat_pd
,
std
::
vector
<
primitive
::
at
>
&
inputs
,
const
memory
&
output
)
{
mkldnn_primitive_t
result
;
std
::
vector
<
mkldnn_primitive_at_t
>
p_inputs
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
i
++
)
p_inputs
.
push_back
(
inputs
[
i
].
data
);
const_mkldnn_primitive_t
outputs
[]
=
{
output
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
concat_pd
.
get
(),
&
p_inputs
[
0
],
outputs
),
"could not create a concat primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_sum Sum
/// @{
struct
sum
:
public
primitive
{
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
std
::
vector
<
const_mkldnn_primitive_desc_t
>
cpp_to_c
(
std
::
vector
<
memory
::
primitive_desc
>
inputs
)
{
std
::
vector
<
const_mkldnn_primitive_desc_t
>
c_api_inputs
;
c_api_inputs
.
reserve
(
inputs
.
size
());
auto
convert_to_c
=
[](
memory
::
primitive_desc
d
)
{
return
d
.
get
();
};
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
c_api_inputs
),
convert_to_c
);
return
c_api_inputs
;
}
primitive_desc
(
const
memory
::
desc
&
output
,
const
std
::
vector
<
float
>
&
scales
,
std
::
vector
<
memory
::
primitive_desc
>
inputs
)
{
mkldnn_primitive_desc_t
result
;
auto
c_api_inputs
=
cpp_to_c
(
inputs
);
error
::
wrap_c_api
(
mkldnn_sum_primitive_desc_create
(
&
result
,
&
output
.
data
,
(
int
)
c_api_inputs
.
size
(),
&
scales
[
0
],
&
c_api_inputs
[
0
]),
"could not create a sum primitive descriptor"
);
reset
(
result
);
}
primitive_desc
(
const
std
::
vector
<
float
>
&
scales
,
std
::
vector
<
memory
::
primitive_desc
>
inputs
)
{
mkldnn_primitive_desc_t
result
;
auto
c_api_inputs
=
cpp_to_c
(
inputs
);
error
::
wrap_c_api
(
mkldnn_sum_primitive_desc_create
(
&
result
,
nullptr
,
(
int
)
c_api_inputs
.
size
(),
&
scales
[
0
],
&
c_api_inputs
[
0
]),
"could not create a sum primitive descriptor"
);
reset
(
result
);
}
/** @deprecated: api backwards compatibility for double scales type */
MKLDNN_DEPRECATED
primitive_desc
(
const
memory
::
desc
&
output
,
std
::
vector
<
double
>
scale
,
std
::
vector
<
memory
::
primitive_desc
>
inputs
)
{
mkldnn_primitive_desc_t
result
;
auto
c_api_inputs
=
cpp_to_c
(
inputs
);
auto
scale_f
=
scale_to_float
(
scale
);
error
::
wrap_c_api
(
mkldnn_sum_primitive_desc_create
(
&
result
,
&
output
.
data
,
(
int
)
c_api_inputs
.
size
(),
&
scale_f
[
0
],
&
c_api_inputs
[
0
]),
"could not create a sum primitive descriptor"
);
reset
(
result
);
}
/** @deprecated: api backwards compatibility for double scales type */
MKLDNN_DEPRECATED
primitive_desc
(
std
::
vector
<
double
>
scale
,
std
::
vector
<
memory
::
primitive_desc
>
inputs
)
{
mkldnn_primitive_desc_t
result
;
auto
c_api_inputs
=
cpp_to_c
(
inputs
);
auto
scale_f
=
scale_to_float
(
scale
);
error
::
wrap_c_api
(
mkldnn_sum_primitive_desc_create
(
&
result
,
nullptr
,
(
int
)
c_api_inputs
.
size
(),
&
scale_f
[
0
],
&
c_api_inputs
[
0
]),
"could not create a sum primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
sum
(
const
primitive_desc
&
sum_pd
,
std
::
vector
<
primitive
::
at
>
&
inputs
,
const
memory
&
output
)
{
mkldnn_primitive_t
result
;
std
::
vector
<
mkldnn_primitive_at_t
>
p_inputs
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
i
++
)
p_inputs
.
push_back
(
inputs
[
i
].
data
);
const_mkldnn_primitive_t
outputs
[]
=
{
output
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
sum_pd
.
get
(),
&
p_inputs
[
0
],
outputs
),
"could not create a sum primitive"
);
reset
(
result
);
}
private:
static
std
::
vector
<
float
>
scale_to_float
(
const
std
::
vector
<
double
>
&
vd
)
{
std
::
vector
<
float
>
vf
(
vd
.
size
());
std
::
transform
(
vd
.
begin
(),
vd
.
end
(),
vf
.
begin
(),
[
=
](
double
x
)
{
return
(
float
)
x
;
});
return
vf
;
}
};
/// @}
/// @addtogroup cpp_api_convolution Convolution
/// @{
struct
convolution_forward
:
public
primitive
{
struct
desc
{
mkldnn_convolution_desc_t
data
;
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
bias_desc
,
const
memory
::
desc
&
dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_convolution_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
weights_desc
.
data
,
&
bias_desc
.
data
,
&
dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a convolution forward descriptor"
);
}
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_convolution_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
weights_desc
.
data
,
nullptr
,
&
dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a convolution forward descriptor"
);
}
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
bias_desc
,
const
memory
::
desc
&
dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
dilates
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
dilates
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_dilated_convolution_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
weights_desc
.
data
,
&
bias_desc
.
data
,
&
dst_desc
.
data
,
&
strides
[
0
],
&
dilates
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a dilated convolution forward descriptor"
);
}
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
dilates
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
dilates
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_dilated_convolution_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
weights_desc
.
data
,
nullptr
,
&
dst_desc
.
data
,
&
strides
[
0
],
&
dilates
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a dilated convolution forward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a convolution forward primitive descriptor"
);
reset
(
result
);
}
primitive_desc
(
const
desc
&
adesc
,
const
primitive_attr
&
aattr
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create_v2
(
&
result
,
&
adesc
.
data
,
aattr
.
get
(),
aengine
.
get
(),
nullptr
),
"could not create a convolution forward primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
convolution_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
primitive
::
at
&
bias
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
,
bias
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a convolution forward bias primitive"
);
reset
(
result
);
}
convolution_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a convolution forward primitive"
);
reset
(
result
);
}
};
struct
convolution_backward_data
:
public
primitive
{
struct
desc
{
mkldnn_convolution_desc_t
data
;
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
diff_src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_convolution_backward_data_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
diff_src_desc
.
data
,
&
weights_desc
.
data
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a convolution backward data descriptor"
);
}
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
diff_src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
dilates
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
dilates
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_dilated_convolution_backward_data_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
diff_src_desc
.
data
,
&
weights_desc
.
data
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
dilates
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a convolution backward data descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
convolution_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a convolution backward data primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
diff_src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_src primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
convolution_backward_data
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
diff_dst
,
const
primitive
::
at
&
weights
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
diff_dst
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a convolution backward data primitive"
);
reset
(
result
);
}
};
struct
convolution_backward_weights
:
public
primitive
{
struct
desc
{
mkldnn_convolution_desc_t
data
;
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
diff_weights_desc
,
const
memory
::
desc
&
diff_bias_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_convolution_backward_weights_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
diff_weights_desc
.
data
,
&
diff_bias_desc
.
data
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a convolution backward weights descriptor"
);
}
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
diff_weights_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_convolution_backward_weights_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
diff_weights_desc
.
data
,
nullptr
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a convolution backward weights descriptor"
);
}
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
diff_weights_desc
,
const
memory
::
desc
&
diff_bias_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
dilates
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
dilates
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_dilated_convolution_backward_weights_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
diff_weights_desc
.
data
,
&
diff_bias_desc
.
data
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
dilates
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a convolution backward weights descriptor"
);
}
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
diff_weights_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
dilates
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
dilates
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_dilated_convolution_backward_weights_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
diff_weights_desc
.
data
,
nullptr
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
dilates
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a convolution backward weights descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
convolution_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a convolution backward weights primitive "
"descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
convolution_backward_weights
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
diff_dst
,
const
memory
&
diff_weights
,
const
memory
&
diff_bias
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_weights
.
get
(),
diff_bias
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a convolution backward weights primitive"
);
reset
(
result
);
}
convolution_backward_weights
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
diff_dst
,
const
memory
&
diff_weights
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_weights
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a convolution backward weights primitive"
);
reset
(
result
);
}
};
struct
convolution_relu_forward
:
public
primitive
{
struct
desc
{
mkldnn_convolution_relu_desc_t
data
;
desc
(
const
convolution_forward
::
desc
conv_desc
,
const
float
negative_slope
)
{
error
::
wrap_c_api
(
mkldnn_convolution_relu_desc_init
(
&
data
,
&
conv_desc
.
data
,
negative_slope
),
"could not create a convolution_relu_forward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a convolution relu forward descriptor"
);
reset
(
result
);
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
convolution_relu_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
primitive
::
at
&
bias
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
,
bias
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a convolution relu forward primitive"
);
reset
(
result
);
}
convolution_relu_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a convolution relu forward primitive"
);
reset
(
result
);
}
};
/// @}
//
/// @addtogroup cpp_api_deconvolution Deconvolution
/// @{
struct
deconvolution_forward
:
public
primitive
{
struct
desc
{
mkldnn_deconvolution_desc_t
data
;
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
bias_desc
,
const
memory
::
desc
&
dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_deconvolution_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
weights_desc
.
data
,
&
bias_desc
.
data
,
&
dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a deconvolution forward descriptor"
);
}
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_deconvolution_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
weights_desc
.
data
,
nullptr
,
&
dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a deconvolution forward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a deconvolution forward primitive descriptor"
);
reset
(
result
);
}
primitive_desc
(
const
desc
&
adesc
,
const
primitive_attr
&
aattr
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create_v2
(
&
result
,
&
adesc
.
data
,
aattr
.
get
(),
aengine
.
get
(),
nullptr
),
"could not create a deconvolution forward primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
deconvolution_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
primitive
::
at
&
bias
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
,
bias
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a deconvolution forward bias primitive"
);
reset
(
result
);
}
deconvolution_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a deconvolution forward primitive"
);
reset
(
result
);
}
};
struct
deconvolution_backward_data
:
public
primitive
{
struct
desc
{
mkldnn_deconvolution_desc_t
data
;
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
diff_src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_deconvolution_backward_data_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
diff_src_desc
.
data
,
&
weights_desc
.
data
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a deconvolution backward data descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
deconvolution_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a deconvolution backward data primitive "
"descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
diff_src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_src primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
deconvolution_backward_data
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
diff_dst
,
const
primitive
::
at
&
weights
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
diff_dst
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a deconvolution backward data primitive"
);
reset
(
result
);
}
};
struct
deconvolution_backward_weights
:
public
primitive
{
struct
desc
{
mkldnn_deconvolution_desc_t
data
;
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
diff_weights_desc
,
const
memory
::
desc
&
diff_bias_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_deconvolution_backward_weights_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
diff_weights_desc
.
data
,
&
diff_bias_desc
.
data
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a deconvolution backward weights descriptor"
);
}
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
diff_weights_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_deconvolution_backward_weights_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
diff_weights_desc
.
data
,
nullptr
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a deconvolution backward weights descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
deconvolution_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a deconvolution backward weights primitive "
"descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
deconvolution_backward_weights
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
diff_dst
,
const
memory
&
diff_weights
,
const
memory
&
diff_bias
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_weights
.
get
(),
diff_bias
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a deconvolution backward weights primitive"
);
reset
(
result
);
}
deconvolution_backward_weights
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
diff_dst
,
const
memory
&
diff_weights
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_weights
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a deconvolution backward weights primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_lrn LRN
/// @{
struct
lrn_forward
:
public
primitive
{
struct
desc
{
mkldnn_lrn_desc_t
data
;
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
int
local_size
,
float
alpha
,
float
beta
,
float
k
)
{
error
::
wrap_c_api
(
mkldnn_lrn_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
local_size
,
alpha
,
beta
,
k
),
"could not create a lrn forward descriptor"
);
}
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
int
local_size
,
float
alpha
,
float
beta
)
{
error
::
wrap_c_api
(
mkldnn_lrn_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
local_size
,
alpha
,
beta
,
float
(
1.0
)),
"could not create a lrn forward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a lrn forward primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
workspace_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
ldesc
;
const_mkldnn_primitive_desc_t
const_ldesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
workspace_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
ldesc
,
const_ldesc
),
"could not clone a workspace primitive descriptor"
);
adesc
.
reset
(
ldesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
lrn_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
workspace
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
(),
workspace
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a lrn forward primitive"
);
reset
(
result
);
}
lrn_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a lrn forward primitive"
);
reset
(
result
);
}
};
struct
lrn_backward
:
public
primitive
{
struct
desc
{
mkldnn_lrn_desc_t
data
;
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
data_desc
,
const
memory
::
desc
&
diff_data_desc
,
int
local_size
,
float
alpha
,
float
beta
,
float
k
)
{
error
::
wrap_c_api
(
mkldnn_lrn_backward_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
diff_data_desc
.
data
,
&
data_desc
.
data
,
local_size
,
alpha
,
beta
,
k
),
"could not create a lrn backward descriptor"
);
}
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
data_desc
,
const
memory
::
desc
&
diff_data_desc
,
int
local_size
,
float
alpha
,
float
beta
)
{
error
::
wrap_c_api
(
mkldnn_lrn_backward_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
diff_data_desc
.
data
,
&
data_desc
.
data
,
local_size
,
alpha
,
beta
,
float
(
1.0
)),
"could not create a lrn backward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
lrn_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a backward lrn primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
diff_src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_src primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
workspace_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
ldesc
;
const_mkldnn_primitive_desc_t
const_ldesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
workspace_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
ldesc
,
const_ldesc
),
"could not clone a workspace primitive descriptor"
);
adesc
.
reset
(
ldesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
lrn_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
diff_dst
,
const
primitive
::
at
&
workspace
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
,
workspace
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a lrn backward primitive"
);
reset
(
result
);
}
lrn_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
diff_dst
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a lrn backward primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_pooling Pooling
/// @{
struct
pooling_forward
:
public
primitive
{
struct
desc
{
mkldnn_pooling_desc_t
data
;
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
kernel
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
kernel
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_pooling_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
dst_desc
.
data
,
&
strides
[
0
],
&
kernel
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not init a forward pooling descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a forward pooling primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
workspace_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
workspace_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a workspace primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
pooling_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
(),
nullptr
};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a pooling forward primitive"
);
reset
(
result
);
}
pooling_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
dst
,
const
memory
&
workspace
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
(),
workspace
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a pooling forward primitive"
);
reset
(
result
);
}
};
struct
pooling_backward
:
public
primitive
{
struct
desc
{
mkldnn_pooling_desc_t
data
;
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
diff_src_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
&
strides
,
const
memory
::
dims
&
kernel
,
const
memory
::
dims
&
padding_l
,
const
memory
::
dims
&
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
kernel
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_pooling_backward_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
diff_src_desc
.
data
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
kernel
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not init a backward pooling descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
pooling_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a backward pooling primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
diff_src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff src primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
pooling_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
diff_dst
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a pooling backward primitive"
);
reset
(
result
);
}
pooling_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
diff_dst
,
const
primitive
::
at
&
workspace
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
diff_dst
.
data
,
workspace
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a pooling backward primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_eltwise Eltwise
/// @{
struct
eltwise_forward
:
public
primitive
{
struct
desc
{
mkldnn_eltwise_desc_t
data
;
template
<
typename
T
>
desc
(
prop_kind
aprop_kind
,
algorithm
alg_kind
,
const
memory
::
desc
&
src_desc
,
T
alpha
=
0
,
T
beta
=
0
)
{
error
::
wrap_c_api
(
mkldnn_eltwise_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
mkldnn
::
convert_to_c
(
alg_kind
),
&
src_desc
.
data
,
static_cast
<
float
>
(
alpha
),
static_cast
<
float
>
(
beta
)),
"could not create a eltwise forward descriptor"
);
}
/** @deprecated: api backward compatibility for relu */
template
<
typename
T
>
MKLDNN_DEPRECATED
desc
(
prop_kind
aprop_kind
,
const
memory
::
desc
&
src_desc
,
T
negative_slope
)
:
desc
(
aprop_kind
,
eltwise_relu
,
src_desc
,
negative_slope
)
{}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a eltwise forward primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
eltwise_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a eltwise forward primitive"
);
reset
(
result
);
}
};
typedef
eltwise_forward
relu_forward
;
struct
eltwise_backward
:
public
primitive
{
struct
desc
{
mkldnn_eltwise_desc_t
data
;
template
<
typename
T
>
desc
(
algorithm
alg_kind
,
const
memory
::
desc
&
diff_data_desc
,
const
memory
::
desc
&
data_desc
,
T
alpha
=
0
,
T
beta
=
0
)
{
error
::
wrap_c_api
(
mkldnn_eltwise_backward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
alg_kind
),
&
diff_data_desc
.
data
,
&
data_desc
.
data
,
static_cast
<
float
>
(
alpha
),
static_cast
<
float
>
(
beta
)),
"could not create a eltwise backward descriptor"
);
}
/** @deprecated: api backward compatibility for relu */
template
<
typename
T
>
MKLDNN_DEPRECATED
desc
(
const
memory
::
desc
&
diff_data_desc
,
const
memory
::
desc
&
data_desc
,
T
negative_slope
)
:
desc
(
eltwise_relu
,
diff_data_desc
,
data_desc
,
negative_slope
)
{}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
eltwise_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a eltwise backward primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
diff_src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff src primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
eltwise_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
diff_dst
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a eltwise backward primitive"
);
reset
(
result
);
}
};
typedef
eltwise_backward
relu_backward
;
/// @}
/// @addtogroup cpp_api_softmax Softmax
/// @{
struct
softmax_forward
:
public
primitive
{
struct
desc
{
mkldnn_softmax_desc_t
data
;
desc
(
prop_kind
aprop_kind
,
const
memory
::
desc
&
data_desc
,
int
softmax_axis
)
{
error
::
wrap_c_api
(
mkldnn_softmax_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
&
data_desc
.
data
,
softmax_axis
),
"could not create a softmax forward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a softmax forward primitive descriptor"
);
reset
(
result
);
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
softmax_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a softmax forward primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_batch_norm Batch normalization
/// @{
struct
batch_normalization_forward
:
public
primitive
{
struct
desc
{
mkldnn_batch_normalization_desc_t
data
;
template
<
typename
T
>
desc
(
prop_kind
aprop_kind
,
const
memory
::
desc
&
src_desc
,
T
epsilon
,
unsigned
flags
)
{
error
::
wrap_c_api
(
mkldnn_batch_normalization_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
&
src_desc
.
data
,
static_cast
<
float
>
(
epsilon
),
flags
),
"could not create a batch normalization forward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a batch normalization forward "
"primitive descriptor"
);
reset
(
result
);
}
primitive_desc
(
const
desc
&
adesc
,
const
primitive_attr
&
aattr
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create_v2
(
&
result
,
&
adesc
.
data
,
aattr
.
get
(),
aengine
.
get
(),
nullptr
),
"could not create a batch normalization forward "
"primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
bndesc
;
const_mkldnn_primitive_desc_t
const_bndesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
bndesc
,
const_bndesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
bndesc
);
return
adesc
;
}
memory
::
primitive_desc
mean_primitive_desc
()
const
{
memory
::
primitive_desc
aprimitive_desc
;
mkldnn_primitive_desc_t
bndesc
;
mkldnn_batch_normalization_desc_t
*
p
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_query
(
get
(),
mkldnn
::
convert_to_c
(
batch_normalization_d
),
0
,
&
p
),
"could not get a batch-normalization descriptor"
);
const_mkldnn_primitive_desc_t
const_bndesc
=
(
p
->
flags
&
use_global_stats
)
?
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
1
)
:
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
bndesc
,
const_bndesc
),
"could not clone a mean primitive descriptor"
);
aprimitive_desc
.
reset
(
bndesc
);
return
aprimitive_desc
;
}
memory
::
primitive_desc
variance_primitive_desc
()
const
{
memory
::
primitive_desc
aprimitive_desc
;
mkldnn_primitive_desc_t
bndesc
;
mkldnn_batch_normalization_desc_t
*
p
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_query
(
get
(),
mkldnn
::
convert_to_c
(
batch_normalization_d
),
0
,
&
p
),
"could not get a batch-normalization descriptor"
);
const_mkldnn_primitive_desc_t
const_bndesc
=
(
p
->
flags
&
use_global_stats
)
?
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
2
)
:
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
2
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
bndesc
,
const_bndesc
),
"could not clone a variance primitive descriptor"
);
aprimitive_desc
.
reset
(
bndesc
);
return
aprimitive_desc
;
}
memory
::
primitive_desc
workspace_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
workspace_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a workspace primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
batch_normalization_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
mean
,
const
primitive
::
at
&
variance
,
const
primitive
::
at
&
weights
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
mean
.
data
,
variance
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization forward primitive"
);
reset
(
result
);
}
batch_normalization_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
mean
,
const
primitive
::
at
&
variance
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
mean
.
data
,
variance
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization forward primitive"
);
reset
(
result
);
}
/// @warning batch_normalization_forward has 2 constructors with very
/// similar signatures:
/// - (pd, src, weights, dst, mean, variance) // 2 in, 3 out
/// - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out
/// The only way to distinguish between those is to explicitly
/// cast all input parameters to their type, i.e. to
/// const primitive:at &.
batch_normalization_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
memory
&
dst
,
const
memory
&
mean
,
const
memory
&
variance
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
(),
mean
.
get
(),
variance
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization forward primitive"
);
reset
(
result
);
}
batch_normalization_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
memory
&
dst
,
const
memory
&
mean
,
const
memory
&
variance
,
const
memory
&
workspace
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
(),
mean
.
get
(),
variance
.
get
(),
workspace
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization forward primitive"
);
reset
(
result
);
}
batch_normalization_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
dst
,
const
memory
&
mean
,
const
memory
&
variance
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
(),
mean
.
get
(),
variance
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization forward primitive"
);
reset
(
result
);
}
/// @warning batch_normalization_forward has 2 constructors with very
/// similar signatures:
/// - (pd, src, weights, dst, mean, variance) // 2 in, 3 out
/// - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out
/// The only way to distinguish between those is to explicitly
/// cast all input parameters to their type, i.e. to
/// const primitive:at &.
/// @note to make users' experience a little bit better this constructor
/// checks if whether parameters match corresponding primitive
/// descriptor, and if they are not -- call the other (proper)
/// constructor. Yeah, this is still very ugly...
batch_normalization_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
dst
,
const
memory
&
mean
,
const
memory
&
variance
,
const
memory
&
workspace
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[
2
]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[
4
]
=
{
dst
.
get
(),
mean
.
get
(),
variance
.
get
(),
workspace
.
get
()};
if
(
1
)
{
// check whether this is the `wrong` constructor
const
int
n_inputs_expected
=
mkldnn_primitive_desc_query_s32
(
aprimitive_desc
.
get
(),
mkldnn_query_num_of_inputs_s32
,
0
);
const
int
n_outputs_expected
=
mkldnn_primitive_desc_query_s32
(
aprimitive_desc
.
get
(),
mkldnn_query_num_of_outputs_s32
,
0
);
if
(
n_inputs_expected
==
2
&&
n_outputs_expected
==
3
)
{
// shift parameters, get rid of workspace, and add weights...
auto
_weights
=
dst
;
inputs
[
1
]
=
{
_weights
.
get
(),
0
};
auto
_dst
=
mean
,
_mean
=
variance
,
_variance
=
workspace
;
outputs
[
0
]
=
_dst
.
get
();
outputs
[
1
]
=
_mean
.
get
();
outputs
[
2
]
=
_variance
.
get
();
outputs
[
3
]
=
nullptr
;
}
}
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization forward primitive"
);
reset
(
result
);
}
batch_normalization_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization forward primitive"
);
reset
(
result
);
}
batch_normalization_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization forward primitive"
);
reset
(
result
);
}
};
struct
batch_normalization_backward
:
public
primitive
{
struct
desc
{
mkldnn_batch_normalization_desc_t
data
;
template
<
typename
T
>
desc
(
prop_kind
aprop_kind
,
const
memory
::
desc
&
diff_data_desc
,
const
memory
::
desc
&
data_desc
,
T
epsilon
,
unsigned
flags
)
{
error
::
wrap_c_api
(
mkldnn_batch_normalization_backward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
&
diff_data_desc
.
data
,
&
data_desc
.
data
,
static_cast
<
float
>
(
epsilon
),
flags
),
"could not create a batch normalization backward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
batch_normalization_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a batch normalization backward primitive "
"descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
bndesc
;
const_mkldnn_primitive_desc_t
const_bndesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
bndesc
,
const_bndesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
bndesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
bndesc
;
const_mkldnn_primitive_desc_t
const_bndesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
bndesc
,
const_bndesc
),
"could not clone a diff_weights primitive descriptor"
);
adesc
.
reset
(
bndesc
);
return
adesc
;
}
memory
::
primitive_desc
mean_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
bndesc
;
const_mkldnn_primitive_desc_t
const_bndesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
bndesc
,
const_bndesc
),
"could not clone a mean primitive descriptor"
);
adesc
.
reset
(
bndesc
);
return
adesc
;
}
memory
::
primitive_desc
variance_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
bndesc
;
const_mkldnn_primitive_desc_t
const_bndesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
2
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
bndesc
,
const_bndesc
),
"could not clone a variance primitive descriptor"
);
adesc
.
reset
(
bndesc
);
return
adesc
;
}
memory
::
primitive_desc
workspace_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
workspace_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a workspace primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
// Prop_kind == backward
batch_normalization_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
mean
,
const
primitive
::
at
&
variance
,
const
primitive
::
at
&
diff_dst
,
const
primitive
::
at
&
weights
,
const
memory
&
diff_src
,
const
memory
&
diff_weights
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
mean
.
data
,
variance
.
data
,
diff_dst
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
(),
diff_weights
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization backward primitive"
);
reset
(
result
);
}
// Prop_kind == backward (+ws)
batch_normalization_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
mean
,
const
primitive
::
at
&
variance
,
const
primitive
::
at
&
diff_dst
,
const
primitive
::
at
&
weights
,
const
primitive
::
at
&
workspace
,
const
memory
&
diff_src
,
const
memory
&
diff_weights
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
mean
.
data
,
variance
.
data
,
diff_dst
.
data
,
weights
.
data
,
workspace
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
(),
diff_weights
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization backward primitive"
);
reset
(
result
);
}
// Prop_kind == backward_data (+ws or +weights)
/// @warning This constructor works for backward_data propagation
/// - w/ weights but w/o workspace, or
/// - w/ workspace but w/o weights
batch_normalization_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
mean
,
const
primitive
::
at
&
variance
,
const
primitive
::
at
&
diff_dst
,
const
primitive
::
at
&
weights_or_workspace
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
mean
.
data
,
variance
.
data
,
diff_dst
.
data
,
weights_or_workspace
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization backward primitive"
);
reset
(
result
);
}
// Prop_kind == backward_data
batch_normalization_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
mean
,
const
primitive
::
at
&
variance
,
const
primitive
::
at
&
diff_dst
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
mean
.
data
,
variance
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization backward primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_inner_product Inner Product
/// @{
struct
inner_product_forward
:
public
primitive
{
struct
desc
{
mkldnn_inner_product_desc_t
data
;
desc
(
prop_kind
aprop_kind
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
bias_desc
,
const
memory
::
desc
&
dst_desc
)
{
error
::
wrap_c_api
(
mkldnn_inner_product_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
&
src_desc
.
data
,
&
weights_desc
.
data
,
&
bias_desc
.
data
,
&
dst_desc
.
data
),
"could not create a inner product forward descriptor"
);
}
desc
(
prop_kind
aprop_kind
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
dst_desc
)
{
error
::
wrap_c_api
(
mkldnn_inner_product_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
&
src_desc
.
data
,
&
weights_desc
.
data
,
nullptr
,
&
dst_desc
.
data
),
"could not create a inner product forward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a inner product forward primitive descriptor"
);
reset
(
result
);
}
primitive_desc
(
const
desc
&
adesc
,
const
primitive_attr
&
aattr
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create_v2
(
&
result
,
&
adesc
.
data
,
aattr
.
get
(),
aengine
.
get
(),
nullptr
),
"could not create a inner product "
"forward primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
inner_product_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
weights
,
const
primitive
::
at
&
bias
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
,
bias
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a inner product forward primitive"
);
reset
(
result
);
}
inner_product_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
weights
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a inner product forward primitive"
);
reset
(
result
);
}
};
struct
inner_product_backward_data
:
public
primitive
{
struct
desc
{
mkldnn_inner_product_desc_t
data
;
desc
(
const
memory
::
desc
&
diff_src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
diff_dst_desc
)
{
error
::
wrap_c_api
(
mkldnn_inner_product_backward_data_desc_init
(
&
data
,
&
diff_src_desc
.
data
,
&
weights_desc
.
data
,
&
diff_dst_desc
.
data
),
"could not create a inner product backward data descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
inner_product_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a inner product backward data primitive "
"descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
diff_dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff dst primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff src primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
inner_product_backward_data
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
diff_dst
,
const
primitive
::
at
weights
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
diff_dst
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a inner product backward data primitive"
);
reset
(
result
);
}
};
struct
inner_product_backward_weights
:
public
primitive
{
struct
desc
{
mkldnn_inner_product_desc_t
data
;
desc
(
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
diff_weights_desc
,
const
memory
::
desc
&
diff_bias_desc
,
const
memory
::
desc
&
diff_dst_desc
)
{
error
::
wrap_c_api
(
mkldnn_inner_product_backward_weights_desc_init
(
&
data
,
&
src_desc
.
data
,
&
diff_weights_desc
.
data
,
&
diff_bias_desc
.
data
,
&
diff_dst_desc
.
data
),
"could not create a inner product backward weights descriptor"
);
}
desc
(
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
diff_weights_desc
,
const
memory
::
desc
&
diff_dst_desc
)
{
error
::
wrap_c_api
(
mkldnn_inner_product_backward_weights_desc_init
(
&
data
,
&
src_desc
.
data
,
&
diff_weights_desc
.
data
,
nullptr
,
&
diff_dst_desc
.
data
),
"could not create a inner product backward weights descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
inner_product_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a inner product backward weights primitive "
"descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
diff_dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff dst primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
inner_product_backward_weights
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
diff_dst
,
const
memory
&
diff_weights
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_weights
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a inner product backward weights primitive"
);
reset
(
result
);
}
inner_product_backward_weights
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
diff_dst
,
const
memory
&
diff_weights
,
const
memory
&
diff_bias
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_weights
.
get
(),
diff_bias
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a inner product backward weights primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_rnn RNN
/// @{
struct
rnn_cell
{
struct
desc
{
mkldnn_rnn_cell_desc_t
c_rnn_cell_
;
desc
(
algorithm
kind
,
algorithm
activation_f
)
{
error
::
wrap_c_api
(
mkldnn_rnn_cell_desc_init
(
&
c_rnn_cell_
,
mkldnn
::
convert_to_c
(
kind
),
mkldnn
::
convert_to_c
(
activation_f
),
0U
,
0
,
0
),
"could not init an rnn cell descriptor"
);
}
desc
(
algorithm
kind
)
:
desc
(
kind
,
algorithm
::
algorithm_undef
)
{}
operator
const
mkldnn_rnn_cell_desc_t
*
()
const
{
return
&
c_rnn_cell_
;
}
algorithm
get_cell_kind
()
const
{
return
algorithm
(
c_rnn_cell_
.
cell_kind
);
}
algorithm
get_activation
()
const
{
return
algorithm
(
c_rnn_cell_
.
activation_kind
);
}
float
get_alpha
()
const
{
return
c_rnn_cell_
.
alpha
;
}
void
set_alpha
(
float
alpha
)
{
c_rnn_cell_
.
flags
|=
mkldnn_rnn_cell_with_relu
;
c_rnn_cell_
.
alpha
=
alpha
;
}
float
get_clipping
()
const
{
return
c_rnn_cell_
.
clipping
;
}
void
set_clipping
(
float
clipping
)
{
c_rnn_cell_
.
flags
|=
mkldnn_rnn_cell_with_clipping
;
c_rnn_cell_
.
clipping
=
clipping
;
}
int
get_gates_count
()
const
{
return
mkldnn_rnn_cell_get_gates_count
(
&
c_rnn_cell_
);
}
int
get_state_count
()
const
{
return
mkldnn_rnn_cell_get_states_count
(
&
c_rnn_cell_
);
}
};
};
struct
rnn_forward
:
public
primitive
{
struct
desc
{
mkldnn_rnn_desc_t
data
;
desc
(
prop_kind
aprop_kind
,
rnn_cell
::
desc
cell
,
const
rnn_direction
direction
,
const
memory
::
desc
&
src_layer_desc
,
const
memory
::
desc
&
src_iter_desc
,
const
memory
::
desc
&
weights_layer_desc
,
const
memory
::
desc
&
weights_iter_desc
,
const
memory
::
desc
&
bias_desc
,
const
memory
::
desc
&
dst_layer_desc
,
const
memory
::
desc
&
dst_iter_desc
)
{
error
::
wrap_c_api
(
mkldnn_rnn_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
cell
,
mkldnn
::
convert_to_c
(
direction
),
&
src_layer_desc
.
data
,
&
src_iter_desc
.
data
,
&
weights_layer_desc
.
data
,
&
weights_iter_desc
.
data
,
&
bias_desc
.
data
,
&
dst_layer_desc
.
data
,
&
dst_iter_desc
.
data
),
"could not create an RNN forward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create an RNN forward primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
src_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone an src layer primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
src_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src iter primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_src_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
2
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
workspace_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
ldesc
;
const_mkldnn_primitive_desc_t
const_ldesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
workspace_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
ldesc
,
const_ldesc
),
"could not clone a workspace primitive descriptor"
);
adesc
.
reset
(
ldesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst last layer primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst last iteration primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
rnn_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src_layer
,
const
primitive
::
at
&
src_iter
,
const
primitive
::
at
&
weights_layer
,
const
primitive
::
at
&
weights_iter
,
const
primitive
::
at
&
bias
,
const
memory
&
dst_layer
,
const
memory
&
dst_iter
,
const
memory
&
workspace
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[
5
];
const_mkldnn_primitive_t
outputs
[
3
];
int
idx
=
0
;
inputs
[
idx
++
]
=
src_layer
.
data
;
if
(
!
is_null_memory
(
src_iter
.
data
.
primitive
))
inputs
[
idx
++
]
=
src_iter
.
data
;
inputs
[
idx
++
]
=
weights_layer
.
data
;
inputs
[
idx
++
]
=
weights_iter
.
data
;
if
(
!
is_null_memory
(
bias
.
data
.
primitive
))
inputs
[
idx
++
]
=
bias
.
data
;
idx
=
0
;
outputs
[
idx
++
]
=
dst_layer
.
get
();
if
(
!
is_null_memory
(
dst_iter
.
get
()))
outputs
[
idx
++
]
=
dst_iter
.
get
();
if
(
!
is_null_memory
(
workspace
.
get
()))
outputs
[
idx
++
]
=
workspace
.
get
();
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create an RNN forward primitive"
);
reset
(
result
);
}
};
struct
rnn_backward
:
public
primitive
{
struct
desc
{
mkldnn_rnn_desc_t
data
;
desc
(
prop_kind
aprop_kind
,
rnn_cell
::
desc
cell
,
const
rnn_direction
direction
,
const
memory
::
desc
&
src_layer_desc
,
const
memory
::
desc
&
src_iter_desc
,
const
memory
::
desc
&
weights_layer_desc
,
const
memory
::
desc
&
weights_iter_desc
,
const
memory
::
desc
&
bias_desc
,
const
memory
::
desc
&
dst_layer_desc
,
const
memory
::
desc
&
dst_iter_desc
,
const
memory
::
desc
&
diff_src_layer_desc
,
const
memory
::
desc
&
diff_src_iter_desc
,
const
memory
::
desc
&
diff_weights_layer_desc
,
const
memory
::
desc
&
diff_weights_iter_desc
,
const
memory
::
desc
&
diff_bias_desc
,
const
memory
::
desc
&
diff_dst_layer_desc
,
const
memory
::
desc
&
diff_dst_iter_desc
)
{
error
::
wrap_c_api
(
mkldnn_rnn_backward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
cell
,
mkldnn
::
convert_to_c
(
direction
),
&
src_layer_desc
.
data
,
&
src_iter_desc
.
data
,
&
weights_layer_desc
.
data
,
&
weights_iter_desc
.
data
,
&
bias_desc
.
data
,
&
dst_layer_desc
.
data
,
&
dst_iter_desc
.
data
,
&
diff_src_layer_desc
.
data
,
&
diff_src_iter_desc
.
data
,
&
diff_weights_layer_desc
.
data
,
&
diff_weights_iter_desc
.
data
,
&
diff_bias_desc
.
data
,
&
diff_dst_layer_desc
.
data
,
&
diff_dst_iter_desc
.
data
),
"could not create an RNN backward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create an RNN backward primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
src_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone an src layer primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
src_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src iter primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
2
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst last layer primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst last iteration primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_src_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone an src_layer primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_src_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_src_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src iter primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_weights_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_weights_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
2
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_dst_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst last layer primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_dst_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst last iteration primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
workspace_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
ldesc
;
const_mkldnn_primitive_desc_t
const_ldesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
workspace_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
ldesc
,
const_ldesc
),
"could not clone a workspace primitive descriptor"
);
adesc
.
reset
(
ldesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
// With last iteration (with and without input src_iter)
rnn_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src_layer
,
const
primitive
::
at
&
src_iter
,
const
primitive
::
at
&
weights_layer
,
const
primitive
::
at
&
weights_iter
,
const
primitive
::
at
&
bias
,
const
primitive
::
at
&
dst_layer
,
const
primitive
::
at
&
dst_iter
,
const
memory
&
diff_src_layer
,
const
memory
&
diff_src_iter
,
const
memory
&
diff_weights_layer
,
const
memory
&
diff_weights_iter
,
const
memory
&
diff_bias
,
const
primitive
::
at
&
diff_dst_layer
,
const
primitive
::
at
&
diff_dst_iter
,
const
primitive
::
at
&
workspace
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[
10
];
const_mkldnn_primitive_t
outputs
[
5
];
int
idx
=
0
;
inputs
[
idx
]
=
src_layer
.
data
;
if
(
!
is_null_memory
(
src_iter
.
data
.
primitive
))
inputs
[
idx
++
]
=
src_iter
.
data
;
inputs
[
idx
++
]
=
weights_layer
.
data
;
inputs
[
idx
++
]
=
weights_iter
.
data
;
if
(
!
is_null_memory
(
bias
.
data
.
primitive
))
inputs
[
idx
++
]
=
bias
.
data
;
inputs
[
idx
]
=
dst_layer
.
data
;
if
(
!
is_null_memory
(
dst_iter
.
data
.
primitive
))
inputs
[
idx
++
]
=
dst_iter
.
data
;
inputs
[
idx
]
=
diff_dst_layer
.
data
;
if
(
!
is_null_memory
(
diff_dst_iter
.
data
.
primitive
))
inputs
[
idx
++
]
=
diff_dst_iter
.
data
;
inputs
[
idx
]
=
workspace
.
data
;
idx
=
0
;
outputs
[
idx
]
=
diff_src_layer
.
get
();
if
(
!
is_null_memory
(
diff_src_iter
.
get
()))
outputs
[
idx
++
]
=
diff_src_iter
.
get
();
outputs
[
idx
]
=
diff_weights_layer
.
get
();
outputs
[
idx
]
=
diff_weights_iter
.
get
();
if
(
!
is_null_memory
(
diff_bias
.
get
()))
outputs
[
idx
]
=
diff_bias
.
get
();
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create an RNN backward primitive"
);
reset
(
result
);
}
};
/// @}
/// @} Primitives
/// @addtogroup cpp_api_stream Stream
/// @{
#ifndef DOXYGEN_SHOULD_SKIP_THIS
template
<
>
struct
handle_traits
<
mkldnn_stream_t
>
{
static
constexpr
auto
destructor
=
&
mkldnn_stream_destroy
;
};
#endif
struct
stream
:
public
handle
<
mkldnn_stream_t
>
{
using
handle
::
handle
;
enum
kind
{
any
=
mkldnn_stream_kind_t
::
mkldnn_any_stream
,
eager
=
mkldnn_stream_kind_t
::
mkldnn_eager
,
lazy
=
mkldnn_stream_kind_t
::
mkldnn_lazy
};
static
mkldnn_stream_kind_t
convert_to_c
(
kind
akind
)
{
return
static_cast
<
mkldnn_stream_kind_t
>
(
akind
);
}
/// Constructs a stream.
stream
(
kind
akind
)
{
mkldnn_stream_t
astream
;
error
::
wrap_c_api
(
mkldnn_stream_create
(
&
astream
,
convert_to_c
(
akind
)),
"could not create a stream"
);
reset
(
astream
);
}
/// Submits a vector of primitives to a stream for computations.
///
/// @param primitives The vector of primitives to submit.
/// @returns The stream.
stream
&
submit
(
std
::
vector
<
primitive
>
primitives
)
{
// TODO: find a proper way to convert vector<primitive> to
// vector<mkldnn_primitive_t>
if
(
primitives
.
size
()
==
0
)
return
*
this
;
std
::
vector
<
mkldnn_primitive_t
>
c_api_primitives
;
c_api_primitives
.
reserve
(
primitives
.
size
());
auto
convert_to_c
=
[](
primitive
p
)
{
return
p
.
get
();
};
std
::
transform
(
primitives
.
begin
(),
primitives
.
end
(),
std
::
back_inserter
(
c_api_primitives
),
convert_to_c
);
mkldnn_primitive_t
c_api_error_primitive
;
error
::
wrap_c_api
(
mkldnn_stream_submit
(
get
(),
c_api_primitives
.
size
(),
&
c_api_primitives
[
0
],
&
c_api_error_primitive
),
"could not submit primitives to a stream"
,
&
c_api_error_primitive
);
return
*
this
;
}
/// Waits for all computations submitted to the stream to complete.
///
/// @param block Specifies whether the operation should wait indefinitely or
/// return
/// immediately.
/// @returns @c true if all computations completed.
/// @returns @c false if not all computations completed.
bool
wait
(
bool
block
=
true
)
{
mkldnn_primitive_t
c_api_error_primitive
;
mkldnn_status_t
status
=
mkldnn_stream_wait
(
get
(),
block
,
&
c_api_error_primitive
);
if
(
status
!=
mkldnn_success
&&
status
!=
mkldnn_try_again
)
error
::
wrap_c_api
(
status
,
"could not wait on a stream"
,
&
c_api_error_primitive
);
return
(
status
==
mkldnn_success
);
}
stream
&
rerun
()
{
mkldnn_primitive_t
c_api_error_primitive
;
error
::
wrap_c_api
(
mkldnn_stream_rerun
(
get
(),
&
c_api_error_primitive
),
"could not rerun a stream"
,
&
c_api_error_primitive
);
return
*
this
;
}
};
/// @}
/// @} C++ API
}
// namespace mkldnn
#endif
python/paddle/fluid/tests/unittests/test_network_with_dtype.py
浏览文件 @
c72a4f4d
...
...
@@ -27,12 +27,15 @@ class TestNetWithDtype(unittest.TestCase):
def
set_network
(
self
):
self
.
dtype
=
"float64"
self
.
init_dtype
()
main
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
):
self
.
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
13
],
dtype
=
self
.
dtype
)
self
.
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
self
.
dtype
)
y_predict
=
fluid
.
layers
.
fc
(
input
=
self
.
x
,
size
=
1
,
act
=
None
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
self
.
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
self
.
program
=
main
self
.
fetch_list
=
[
avg_cost
]
sgd_optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
...
...
@@ -45,7 +48,7 @@ class TestNetWithDtype(unittest.TestCase):
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
for
data
in
train_reader
():
exe
.
run
(
fluid
.
default_main_program
()
,
exe
.
run
(
self
.
program
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
self
.
fetch_list
)
# the main program is runable, the datatype is fully supported
...
...
@@ -68,7 +71,7 @@ class TestNetWithDtype(unittest.TestCase):
# TODO(dzhwinter): make sure the fp16 is runable
# class TestFloat16(
SimpleNet
):
# class TestFloat16(
TestNetWithDtype
):
# def init_dtype(self):
# self.dtype = "float16"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录