Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
4d2a2e75
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4d2a2e75
编写于
5月 17, 2018
作者:
B
baiyfbupt
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into develop
上级
728062a5
d0a62bfc
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
133 addition
and
4319 deletion
+133
-4319
cmake/external/boost.cmake
cmake/external/boost.cmake
+1
-1
cmake/external/eigen.cmake
cmake/external/eigen.cmake
+2
-1
cmake/external/mkldnn.cmake
cmake/external/mkldnn.cmake
+1
-3
cmake/external/mklml.cmake
cmake/external/mklml.cmake
+1
-1
cmake/inference_lib.cmake
cmake/inference_lib.cmake
+8
-0
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+7
-5
paddle/scripts/paddle_build.sh
paddle/scripts/paddle_build.sh
+32
-7
paddle/scripts/paddle_docker_build.sh
paddle/scripts/paddle_docker_build.sh
+1
-0
patches/mkldnn.hpp
patches/mkldnn.hpp
+0
-4252
python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py
...-level-api/recognize_digits/test_recognize_digits_conv.py
+12
-12
python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py
...h-level-api/recognize_digits/test_recognize_digits_mlp.py
+4
-10
python/paddle/fluid/trainer.py
python/paddle/fluid/trainer.py
+63
-26
tools/timeline.py
tools/timeline.py
+1
-1
未找到文件。
cmake/external/boost.cmake
浏览文件 @
4d2a2e75
...
...
@@ -24,7 +24,7 @@ set(BOOST_PROJECT "extern_boost")
# So we use 1.41.0 here.
set
(
BOOST_VER
"1.41.0"
)
set
(
BOOST_TAR
"boost_1_41_0"
)
set
(
BOOST_URL
"http://paddlepaddledeps.
bj
.bcebos.com/
${
BOOST_TAR
}
.tar.gz"
)
set
(
BOOST_URL
"http://paddlepaddledeps.
cdn
.bcebos.com/
${
BOOST_TAR
}
.tar.gz"
)
set
(
BOOST_SOURCES_DIR
${
THIRD_PARTY_PATH
}
/boost
)
set
(
BOOST_DOWNLOAD_DIR
"
${
BOOST_SOURCES_DIR
}
/src/
${
BOOST_PROJECT
}
"
)
set
(
BOOST_INCLUDE_DIR
"
${
BOOST_DOWNLOAD_DIR
}
/
${
BOOST_TAR
}
"
CACHE PATH
"boost include directory."
FORCE
)
...
...
cmake/external/eigen.cmake
浏览文件 @
4d2a2e75
...
...
@@ -21,11 +21,12 @@ else()
ExternalProject_Add
(
extern_eigen3
${
EXTERNAL_PROJECT_LOG_ARGS
}
GIT_REPOSITORY
"https://github.com/
RLovelett/eigen.git
"
GIT_REPOSITORY
"https://github.com/
eigenteam/eigen-git-mirror
"
# eigen on cuda9.1 missing header of math_funtions.hpp
# https://stackoverflow.com/questions/43113508/math-functions-hpp-not-found-when-using-cuda-with-eigen
GIT_TAG 917060c364181f33a735dc023818d5a54f60e54c
PREFIX
${
EIGEN_SOURCE_DIR
}
DOWNLOAD_NAME
"eigen"
UPDATE_COMMAND
""
CONFIGURE_COMMAND
""
BUILD_COMMAND
""
...
...
cmake/external/mkldnn.cmake
浏览文件 @
4d2a2e75
...
...
@@ -53,11 +53,9 @@ ExternalProject_Add(
${
EXTERNAL_PROJECT_LOG_ARGS
}
DEPENDS
${
MKLDNN_DEPENDS
}
GIT_REPOSITORY
"https://github.com/01org/mkl-dnn.git"
GIT_TAG
"
v0.14
"
GIT_TAG
"
db3424ad44901513c03a1ea31ccaacdf633fbe9f
"
PREFIX
${
MKLDNN_SOURCES_DIR
}
UPDATE_COMMAND
""
# Patch MKLDNN to compile with gcc 4.8, the related issue is in intel/mkl-dnn#237.
PATCH_COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
CMAKE_CURRENT_SOURCE_DIR
}
/patches/mkldnn.hpp
${
MKLDNN_SOURCES_DIR
}
/src/extern_mkldnn/include/mkldnn.hpp
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=
${
MKLDNN_INSTALL_DIR
}
CMAKE_ARGS -DCMAKE_BUILD_TYPE=
${
CMAKE_BUILD_TYPE
}
CMAKE_ARGS -DMKLROOT=
${
MKLML_ROOT
}
...
...
cmake/external/mklml.cmake
浏览文件 @
4d2a2e75
...
...
@@ -28,7 +28,7 @@ INCLUDE(ExternalProject)
SET
(
MKLML_PROJECT
"extern_mklml"
)
SET
(
MKLML_VER
"mklml_lnx_2018.0.3.20180406"
)
SET
(
MKLML_URL
"http://paddlepaddledeps.
bj
.bcebos.com/
${
MKLML_VER
}
.tgz"
)
SET
(
MKLML_URL
"http://paddlepaddledeps.
cdn
.bcebos.com/
${
MKLML_VER
}
.tgz"
)
SET
(
MKLML_SOURCE_DIR
"
${
THIRD_PARTY_PATH
}
/mklml"
)
SET
(
MKLML_DOWNLOAD_DIR
"
${
MKLML_SOURCE_DIR
}
/src/
${
MKLML_PROJECT
}
"
)
SET
(
MKLML_DST_DIR
"mklml"
)
...
...
cmake/inference_lib.cmake
浏览文件 @
4d2a2e75
...
...
@@ -98,6 +98,14 @@ elseif (WITH_MKLML)
)
endif
()
if
(
WITH_MKLDNN
)
set
(
dst_dir
"
${
CMAKE_INSTALL_PREFIX
}
/third_party/install/mkldnn"
)
copy
(
mkldnn_lib
SRCS
${
MKLDNN_INC_DIR
}
${
MKLDNN_SHARED_LIB
}
DSTS
${
dst_dir
}
${
dst_dir
}
/lib
)
endif
()
if
(
NOT MOBILE_INFERENCE AND NOT RPI
)
set
(
dst_dir
"
${
CMAKE_INSTALL_PREFIX
}
/third_party/install/snappy"
)
copy
(
snappy_lib
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
4d2a2e75
...
...
@@ -186,11 +186,7 @@ endif()
add_subdirectory
(
detail
)
if
(
WITH_DISTRIBUTE
)
if
(
WITH_GPU
)
op_library
(
gen_nccl_id_op DEPS nccl_common
)
else
()
set
(
DEPS_OPS
${
DEPS_OPS
}
gen_nccl_id_op
)
endif
()
set
(
DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
op_library
(
send_op DEPS
${
DISTRIBUTE_DEPS
}
)
...
...
@@ -208,6 +204,12 @@ if(WITH_DISTRIBUTE)
set_source_files_properties
(
send_recv_op_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor
)
cc_test
(
test_send_nccl_id SRCS test_send_nccl_id.cc DEPS send_op listen_and_serv_op executor
)
if
(
WITH_GPU
)
op_library
(
gen_nccl_id_op DEPS nccl_common sendrecvop_grpc
)
set_source_files_properties
(
gen_nccl_id_op.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
else
()
set
(
DEPS_OPS
${
DEPS_OPS
}
gen_nccl_id_op
)
endif
()
else
()
set
(
DEPS_OPS
${
DEPS_OPS
}
send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op
)
endif
()
...
...
paddle/scripts/paddle_build.sh
浏览文件 @
4d2a2e75
...
...
@@ -20,19 +20,15 @@
#=================================================
function
print_usage
()
{
RED
=
'\033[0;31m'
BLUE
=
'\033[0;34m'
BOLD
=
'\033[1m'
NONE
=
'\033[0m'
echo
-e
"
\n
${
RED
}
Usage
${
NONE
}
:
${
BOLD
}
$
0
${
NONE
}
[OPTION]"
${
BOLD
}$
{
SCRIPT_NAME
}
${
NONE
}
[OPTION]"
echo
-e
"
\n
${
RED
}
Options
${
NONE
}
:
${
BLUE
}
build
${
NONE
}
: run build for x86 platform
${
BLUE
}
build_android
${
NONE
}
: run build for android platform
${
BLUE
}
build_ios
${
NONE
}
: run build for ios platform
${
BLUE
}
test
${
NONE
}
: run all unit tests
${
BLUE
}
single_test
${
NONE
}
: run a single unit test
${
BLUE
}
bind_test
${
NONE
}
: parallel tests bind to different GPU
${
BLUE
}
doc
${
NONE
}
: generate paddle documents
${
BLUE
}
html
${
NONE
}
: convert C++ source code into HTML
...
...
@@ -45,7 +41,15 @@ function print_usage() {
}
function
init
()
{
RED
=
'\033[0;31m'
BLUE
=
'\033[0;34m'
BOLD
=
'\033[1m'
NONE
=
'\033[0m'
PADDLE_ROOT
=
"
$(
cd
"
$(
dirname
"
${
BASH_SOURCE
[0]
}
"
)
/../../"
&&
pwd
)
"
if
[
-z
"
${
SCRIPT_NAME
}
"
]
;
then
SCRIPT_NAME
=
$0
fi
}
function
cmake_gen
()
{
...
...
@@ -91,7 +95,6 @@ function cmake_gen() {
-DWITH_AVX=
${
WITH_AVX
:-
OFF
}
-DWITH_GOLANG=
${
WITH_GOLANG
:-
OFF
}
-DCUDA_ARCH_NAME=
${
CUDA_ARCH_NAME
:-
All
}
-DWITH_SWIG_PY=ON
-DWITH_C_API=
${
WITH_C_API
:-
OFF
}
-DWITH_PYTHON=
${
WITH_PYTHON
:-
ON
}
-DWITH_SWIG_PY=
${
WITH_SWIG_PY
:-
ON
}
...
...
@@ -309,6 +312,25 @@ EOF
fi
}
function
single_test
()
{
TEST_NAME
=
$1
if
[
-z
"
${
TEST_NAME
}
"
]
;
then
echo
-e
"
${
RED
}
Usage:
${
NONE
}
"
echo
-e
"
${
BOLD
}${
SCRIPT_NAME
}${
NONE
}
${
BLUE
}
single_test
${
NONE
}
[test_name]"
exit
1
fi
mkdir
-p
${
PADDLE_ROOT
}
/build
cd
${
PADDLE_ROOT
}
/build
if
[
${
WITH_TESTING
:-
ON
}
==
"ON"
]
;
then
cat
<<
EOF
========================================
Running
${
TEST_NAME
}
...
========================================
EOF
ctest
--output-on-failure
-R
${
TEST_NAME
}
fi
}
function
bind_test
()
{
# the number of process to run tests
NUM_PROC
=
6
...
...
@@ -491,6 +513,9 @@ function main() {
test
)
run_test
;;
single_test
)
single_test
$2
;;
bind_test
)
bind_test
;;
...
...
paddle/scripts/paddle_docker_build.sh
浏览文件 @
4d2a2e75
...
...
@@ -63,6 +63,7 @@ EOL
${
DOCKER_CMD
}
run
-it
\
--name
$CONTAINER_ID
\
${
DOCKER_ENV
}
\
-e
SCRIPT_NAME
=
$0
\
-v
$PADDLE_ROOT
:/paddle
\
-v
${
HOME
}
/.ccache:/root/.ccache
\
-w
/paddle
\
...
...
patches/mkldnn.hpp
已删除
100644 → 0
浏览文件 @
728062a5
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef MKLDNN_HPP
#define MKLDNN_HPP
#ifndef DOXYGEN_SHOULD_SKIP_THIS
#include <stdlib.h>
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <vector>
#include "mkldnn.h"
#endif
namespace
mkldnn
{
/// @addtogroup cpp_api C++ API
/// @{
/// @addtogroup cpp_api_utils Utils
/// @{
/// A class that provides the destructor for an Intel(R) MKL-DNN C handle
template
<
typename
T
>
class
handle_traits
{};
/// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base
/// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and
/// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class
/// can be passed by value. This class enables wrapping:
/// - Newly constructed handles.
/// @n In this case, the constructed handle uses reference counting provided
/// by @p std::shared_ptr with a proper deleter function specified through
/// the @p handle_traits class.
/// - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for
/// example, through #mkldnn_primitive_get_output()).
/// @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a
/// deleter because it is assumed that the handle wrapper for the original
/// object deletes the handle (this model is similar to @p std::weak_ptr).
template
<
typename
T
,
typename
traits
=
handle_traits
<
T
>
>
class
handle
{
private:
std
::
shared_ptr
<
typename
std
::
remove_pointer
<
T
>::
type
>
_data
;
handle
(
const
handle
&&
)
=
delete
;
handle
&
operator
=
(
const
handle
&&
other
)
=
delete
;
protected:
/// Constructs a C handle wrapper.
/// @param t The C handle to wrap.
/// @param weak A flag to specify whether to construct a weak wrapper.
handle
(
T
t
=
0
,
bool
weak
=
false
)
:
_data
(
0
)
{
reset
(
t
,
weak
);
}
bool
operator
==
(
const
T
other
)
const
{
return
other
==
_data
.
get
();
}
bool
operator
!=
(
const
T
other
)
const
{
return
!
(
*
this
==
other
);
}
public:
handle
(
const
handle
&
other
)
:
_data
(
other
.
_data
)
{}
handle
&
operator
=
(
const
handle
&
other
)
{
_data
=
other
.
_data
;
return
*
this
;
}
/// Resets the value of a C handle.
/// @param t The new value of the C handle.
/// @param weak A flag to specify whether the wrapper should be weak.
void
reset
(
T
t
,
bool
weak
=
false
)
{
auto
dummy_destructor
=
[](
T
)
{
return
decltype
(
traits
::
destructor
(
0
))(
0
);
};
_data
.
reset
(
t
,
weak
?
dummy_destructor
:
traits
::
destructor
);
}
/// Returns the value of the underlying C handle.
T
get
()
const
{
return
_data
.
get
();
}
bool
operator
==
(
const
handle
&
other
)
const
{
return
other
.
_data
.
get
()
==
_data
.
get
();
}
bool
operator
!=
(
const
handle
&
other
)
const
{
return
!
(
*
this
==
other
);
}
};
#ifndef DOXYGEN_SHOULD_SKIP_THIS
template
<
>
struct
handle_traits
<
mkldnn_primitive_desc_t
>
{
static
constexpr
auto
destructor
=
&
mkldnn_primitive_desc_destroy
;
};
template
<
>
struct
handle_traits
<
mkldnn_primitive_t
>
{
static
constexpr
auto
destructor
=
&
mkldnn_primitive_destroy
;
};
#endif
/// Base class for all computational primitives.
class
primitive
:
public
handle
<
mkldnn_primitive_t
>
{
friend
struct
error
;
friend
struct
stream
;
friend
class
primitive_at
;
using
handle
::
handle
;
public:
/// A proxy to C primitive kind enum
enum
class
kind
{
undefined_primitive
=
mkldnn_undefined_primitive
,
memory
=
mkldnn_memory
,
view
=
mkldnn_view
,
reorder
=
mkldnn_reorder
,
concat
=
mkldnn_concat
,
concat_inplace
=
mkldnn_concat_inplace
,
sum
=
mkldnn_sum
,
convolution
=
mkldnn_convolution
,
deconvolution
=
mkldnn_deconvolution
,
eltwise
=
mkldnn_eltwise
,
relu
=
mkldnn_relu
,
softmax
=
mkldnn_softmax
,
pooling
=
mkldnn_pooling
,
lrn
=
mkldnn_lrn
,
batch_normalization
=
mkldnn_batch_normalization
,
inner_product
=
mkldnn_inner_product
,
convolution_relu
=
mkldnn_convolution_relu
,
rnn
=
mkldnn_rnn
,
};
/// A wrapper structure to specify a particular output of a primitive.
struct
at
{
/// The underlying C API structure.
mkldnn_primitive_at_t
data
;
/// Constructs a wrapper specifying @p aprimitive output with index @p
/// at.
///
/// @param aprimitive The target primitive.
/// @param at The output index.
at
(
const
primitive
&
aprimitive
,
size_t
at
=
0
)
:
data
(
mkldnn_primitive_at
(
aprimitive
.
get
(),
at
))
{}
/// Returns the specified output.
inline
operator
primitive
()
const
;
};
/// Returns the descriptor of the underlying C API primitive
inline
const_mkldnn_primitive_desc_t
get_primitive_desc
()
const
;
// TODO: use the C++ API wrapper structure.
};
inline
mkldnn_primitive_kind_t
convert_to_c
(
primitive
::
kind
akind
)
{
return
static_cast
<
mkldnn_primitive_kind_t
>
(
akind
);
}
/// Intel(R) MKL-DNN exception class.
///
/// This class captures the status returned by the failed C API function, error
/// message, and, optionally, handle of the primitive that caused the error.
struct
error
:
public
std
::
exception
{
mkldnn_status_t
status
;
std
::
string
message
;
primitive
error_primitive
;
/// Constructs an error instance.
///
/// @param astatus The error status returned by the C API.
/// @param amessage The error message.
/// @param aerror_primitive (optional) A C handle of the primitive that
/// caused the error.
error
(
mkldnn_status_t
astatus
,
std
::
string
amessage
,
mkldnn_primitive_t
aerror_primitive
=
0
)
:
status
(
astatus
),
message
(
amessage
),
error_primitive
(
aerror_primitive
,
true
)
{}
/// A convenience function for wrapping calls to the C API. Checks the
/// return status and throws an #error in case of failure.
///
/// @param status The error status returned by the C API.
/// @param message The error message.
/// @param error_primitive (optional) A C handle of the primitive that
/// caused the error.
static
void
wrap_c_api
(
mkldnn_status_t
status
,
std
::
string
message
,
mkldnn_primitive_t
*
error_primitive
=
0
)
{
if
(
status
!=
mkldnn_success
)
{
if
(
nullptr
!=
error_primitive
)
throw
error
(
status
,
message
,
*
error_primitive
);
else
throw
error
(
status
,
message
,
nullptr
);
}
}
};
inline
primitive
::
at
::
operator
primitive
()
const
{
const_mkldnn_primitive_t
output
;
error
::
wrap_c_api
(
mkldnn_primitive_get_output
(
data
.
primitive
,
data
.
output_index
,
&
output
),
"could not get an output primitive"
);
return
primitive
(
const_cast
<
mkldnn_primitive_t
>
(
output
),
true
);
}
const_mkldnn_primitive_desc_t
primitive
::
get_primitive_desc
()
const
{
const_mkldnn_primitive_desc_t
pd
;
error
::
wrap_c_api
(
mkldnn_primitive_get_primitive_desc
(
get
(),
&
pd
),
"could not get primitive descriptor by primitive"
);
return
pd
;
}
/// @}
/// @addtogroup cpp_api_enums Common data types and enumerations
/// @{
enum
round_mode
{
round_nearest
=
mkldnn_round_nearest
,
round_down
=
mkldnn_round_down
,
};
inline
mkldnn_round_mode_t
convert_to_c
(
round_mode
mode
)
{
return
static_cast
<
mkldnn_round_mode_t
>
(
mode
);
}
enum
padding_kind
{
zero
=
mkldnn_padding_zero
};
inline
mkldnn_padding_kind_t
convert_to_c
(
padding_kind
kind
)
{
return
static_cast
<
mkldnn_padding_kind_t
>
(
kind
);
}
enum
prop_kind
{
forward_training
=
mkldnn_forward_training
,
forward_scoring
=
mkldnn_forward_scoring
,
forward_inference
=
mkldnn_forward_inference
,
forward
=
mkldnn_forward
,
backward
=
mkldnn_backward
,
backward_data
=
mkldnn_backward_data
,
backward_weights
=
mkldnn_backward_weights
,
backward_bias
=
mkldnn_backward_bias
};
inline
mkldnn_prop_kind_t
convert_to_c
(
prop_kind
kind
)
{
return
static_cast
<
mkldnn_prop_kind_t
>
(
kind
);
}
enum
algorithm
{
algorithm_undef
=
mkldnn_alg_kind_undef
,
convolution_direct
=
mkldnn_convolution_direct
,
convolution_winograd
=
mkldnn_convolution_winograd
,
deconvolution_direct
=
mkldnn_deconvolution_direct
,
deconvolution_winograd
=
mkldnn_deconvolution_winograd
,
eltwise_relu
=
mkldnn_eltwise_relu
,
eltwise_tanh
=
mkldnn_eltwise_tanh
,
eltwise_elu
=
mkldnn_eltwise_elu
,
eltwise_square
=
mkldnn_eltwise_square
,
eltwise_abs
=
mkldnn_eltwise_abs
,
eltwise_sqrt
=
mkldnn_eltwise_sqrt
,
eltwise_linear
=
mkldnn_eltwise_linear
,
eltwise_bounded_relu
=
mkldnn_eltwise_bounded_relu
,
eltwise_soft_relu
=
mkldnn_eltwise_soft_relu
,
eltwise_logistic
=
mkldnn_eltwise_logistic
,
lrn_across_channels
=
mkldnn_lrn_across_channels
,
lrn_within_channel
=
mkldnn_lrn_within_channel
,
pooling_max
=
mkldnn_pooling_max
,
pooling_avg
=
mkldnn_pooling_avg
,
pooling_avg_include_padding
=
mkldnn_pooling_avg_include_padding
,
pooling_avg_exclude_padding
=
mkldnn_pooling_avg_exclude_padding
,
vanilla_rnn
=
mkldnn_vanilla_rnn
,
vanilla_lstm
=
mkldnn_vanilla_lstm
,
vanilla_gru
=
mkldnn_vanilla_gru
,
};
inline
mkldnn_alg_kind_t
convert_to_c
(
algorithm
aalgorithm
)
{
return
static_cast
<
mkldnn_alg_kind_t
>
(
aalgorithm
);
}
enum
batch_normalization_flag
{
use_global_stats
=
mkldnn_use_global_stats
,
use_scale_shift
=
mkldnn_use_scaleshift
,
omit_stats
=
mkldnn_omit_stats
,
fuse_bn_relu
=
mkldnn_fuse_bn_relu
};
inline
mkldnn_batch_normalization_flag_t
convert_to_c
(
batch_normalization_flag
aflag
)
{
return
static_cast
<
mkldnn_batch_normalization_flag_t
>
(
aflag
);
}
enum
rnn_direction
{
unidirectional_left2right
=
mkldnn_unidirectional_left2right
,
unidirectional_right2left
=
mkldnn_unidirectional_right2left
,
unidirectional
=
mkldnn_unidirectional
,
bidirectional_concat
=
mkldnn_bidirectional_concat
,
bidirectional_sum
=
mkldnn_bidirectional_sum
,
};
inline
mkldnn_rnn_direction_t
convert_to_c
(
rnn_direction
adir
)
{
return
static_cast
<
mkldnn_rnn_direction_t
>
(
adir
);
}
enum
query
{
undef
=
mkldnn_query_undef
,
eengine
=
mkldnn_query_engine
,
primitive_kind
=
mkldnn_query_primitive_kind
,
num_of_inputs_s32
=
mkldnn_query_num_of_inputs_s32
,
num_of_outputs_s32
=
mkldnn_query_num_of_outputs_s32
,
time_estimate_f64
=
mkldnn_query_time_estimate_f64
,
memory_consumption_s64
=
mkldnn_query_memory_consumption_s64
,
impl_info_str
=
mkldnn_query_impl_info_str
,
memory_d
=
mkldnn_query_memory_d
,
convolution_d
=
mkldnn_query_convolution_d
,
deconvolution_d
=
mkldnn_query_deconvolution_d
,
eltwise_d
=
mkldnn_query_eltwise_d
,
relu_d
=
mkldnn_query_relu_d
,
softmax_d
=
mkldnn_query_softmax_d
,
pooling_d
=
mkldnn_query_pooling_d
,
lrn_d
=
mkldnn_query_lrn_d
,
batch_normalization_d
=
mkldnn_query_batch_normalization_d
,
inner_product_d
=
mkldnn_query_inner_product_d
,
convolution_relu_d
=
mkldnn_query_convolution_relu_d
,
rnn_d
=
mkldnn_query_rnn_d
,
input_pd
=
mkldnn_query_input_pd
,
output_pd
=
mkldnn_query_output_pd
,
src_pd
=
mkldnn_query_src_pd
,
diff_src_pd
=
mkldnn_query_diff_src_pd
,
weights_pd
=
mkldnn_query_weights_pd
,
diff_weights_pd
=
mkldnn_query_diff_weights_pd
,
dst_pd
=
mkldnn_query_dst_pd
,
diff_dst_pd
=
mkldnn_query_diff_dst_pd
,
workspace_pd
=
mkldnn_query_workspace_pd
,
};
inline
mkldnn_query_t
convert_to_c
(
query
aquery
)
{
return
static_cast
<
mkldnn_query_t
>
(
aquery
);
}
/// @}
/// @addtogroup cpp_api_attr Attributes
/// @{
#ifndef DOXYGEN_SHOULD_SKIP_THIS
template
<
>
struct
handle_traits
<
mkldnn_post_ops_t
>
{
static
constexpr
auto
destructor
=
&
mkldnn_post_ops_destroy
;
};
#endif
struct
post_ops
:
public
handle
<
mkldnn_post_ops_t
>
{
post_ops
()
{
mkldnn_post_ops_t
result
;
error
::
wrap_c_api
(
mkldnn_post_ops_create
(
&
result
),
"could not create post operation sequence"
);
reset
(
result
);
}
int
len
()
const
{
return
mkldnn_post_ops_len
(
get
());
}
primitive
::
kind
kind
(
int
index
)
const
{
error
::
wrap_c_api
(
index
<
len
()
?
mkldnn_success
:
mkldnn_invalid_arguments
,
"post_ops index is out of range"
);
return
static_cast
<
primitive
::
kind
>
(
mkldnn_post_ops_get_kind
(
get
(),
index
));
}
void
append_sum
(
float
scale
=
1.
)
{
error
::
wrap_c_api
(
mkldnn_post_ops_append_sum
(
get
(),
scale
),
"could not append sum"
);
}
void
get_params_sum
(
int
index
,
float
&
scale
)
const
{
error
::
wrap_c_api
(
mkldnn_post_ops_get_params_sum
(
get
(),
index
,
&
scale
),
"could not get sum params"
);
}
void
append_eltwise
(
float
scale
,
algorithm
alg
,
float
alpha
,
float
beta
)
{
error
::
wrap_c_api
(
mkldnn_post_ops_append_eltwise
(
get
(),
scale
,
convert_to_c
(
alg
),
alpha
,
beta
),
"could not append eltwise"
);
}
void
get_params_eltwise
(
int
index
,
float
&
scale
,
algorithm
&
alg
,
float
&
alpha
,
float
&
beta
)
const
{
mkldnn_alg_kind_t
c_alg
;
error
::
wrap_c_api
(
mkldnn_post_ops_get_params_eltwise
(
get
(),
index
,
&
scale
,
&
c_alg
,
&
alpha
,
&
beta
),
"could not get eltwise params"
);
alg
=
static_cast
<
algorithm
>
(
c_alg
);
}
};
#ifndef DOXYGEN_SHOULD_SKIP_THIS
template
<
>
struct
handle_traits
<
mkldnn_primitive_attr_t
>
{
static
constexpr
auto
destructor
=
&
mkldnn_primitive_attr_destroy
;
};
#endif
struct
primitive_attr
:
public
handle
<
mkldnn_primitive_attr_t
>
{
primitive_attr
()
{
mkldnn_primitive_attr_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_attr_create
(
&
result
),
"could not create a primitive attr"
);
reset
(
result
);
}
round_mode
get_int_output_round_mode
()
const
{
mkldnn_round_mode_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_attr_get_int_output_round_mode
(
get
(),
&
result
),
"could not get int output round mode"
);
return
round_mode
(
result
);
}
void
set_int_output_round_mode
(
round_mode
mode
)
{
error
::
wrap_c_api
(
mkldnn_primitive_attr_set_int_output_round_mode
(
get
(),
mkldnn
::
convert_to_c
(
mode
)),
"could not set int output round mode"
);
}
void
get_output_scales
(
int
&
mask
,
std
::
vector
<
float
>
&
scales
)
const
{
int
count
,
c_mask
;
const
float
*
c_scales
;
error
::
wrap_c_api
(
mkldnn_primitive_attr_get_output_scales
(
get
(),
&
count
,
&
c_mask
,
&
c_scales
),
"could not get int output scales"
);
scales
.
resize
(
count
);
mask
=
c_mask
;
for
(
int
c
=
0
;
c
<
count
;
++
c
)
scales
[
c
]
=
c_scales
[
c
];
}
void
set_output_scales
(
int
mask
,
const
std
::
vector
<
float
>
&
scales
)
{
error
::
wrap_c_api
(
mkldnn_primitive_attr_set_output_scales
(
get
(),
(
int
)
scales
.
size
(),
mask
,
&
scales
[
0
]),
"could not set int output scales"
);
}
const
post_ops
get_post_ops
()
const
{
post_ops
result
;
const_mkldnn_post_ops_t
c_result
;
error
::
wrap_c_api
(
mkldnn_primitive_attr_get_post_ops
(
get
(),
&
c_result
),
"could not get post operation sequence"
);
result
.
reset
(
const_cast
<
mkldnn_post_ops_t
>
(
c_result
),
true
);
return
result
;
}
void
set_post_ops
(
post_ops
ops
)
{
error
::
wrap_c_api
(
mkldnn_primitive_attr_set_post_ops
(
get
(),
ops
.
get
()),
"could not set post operation sequence"
);
}
};
/// @}
/// @addtogroup cpp_api_engine Engine
/// @{
#ifndef DOXYGEN_SHOULD_SKIP_THIS
template
<
>
struct
handle_traits
<
mkldnn_engine_t
>
{
static
constexpr
auto
destructor
=
&
mkldnn_engine_destroy
;
};
#endif
/// An execution engine.
struct
engine
:
public
handle
<
mkldnn_engine_t
>
{
friend
class
primitive
;
// gcc bug??? using handle::handle;
/// Kinds of engines
enum
kind
{
/// An unspecified engine
any
=
mkldnn_any_engine
,
/// CPU engine
cpu
=
mkldnn_cpu
,
};
/// Returns the number of engines of a certain kind.
///
/// @param akind The kind of engines to count.
static
size_t
get_count
(
kind
akind
)
{
return
mkldnn_engine_get_count
(
convert_to_c
(
akind
));
}
/// Constructs an engine.
///
/// @param akind The kind of engine to construct.
/// @param index The index of the engine. Must be less than the value
/// returned by #get_count() for this particular kind of engine.
engine
(
kind
akind
,
size_t
index
)
{
mkldnn_engine_t
aengine
;
error
::
wrap_c_api
(
mkldnn_engine_create
(
&
aengine
,
convert_to_c
(
akind
),
index
),
"could not create an engine"
);
reset
(
aengine
);
}
explicit
engine
(
const
mkldnn_engine_t
&
aengine
)
:
handle
(
aengine
,
true
)
{}
engine
(
const
handle
<
mkldnn_primitive_desc_t
>
&
pd
)
{
mkldnn_engine_t
engine_q
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_query
(
pd
.
get
(),
mkldnn
::
convert_to_c
(
eengine
),
0
,
&
engine_q
),
"could not get engine from primitive_desc"
);
reset
(
engine_q
,
true
);
}
template
<
class
primitive_desc
>
static
engine
query
(
const
primitive_desc
&
pd
)
{
mkldnn_engine_t
engine_q
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_query
(
pd
.
get
(),
mkldnn
::
convert_to_c
(
eengine
),
0
,
&
engine_q
),
"could not get engine from primitive_desc"
);
return
engine
(
engine_q
);
}
private:
static
mkldnn_engine_kind_t
convert_to_c
(
kind
akind
)
{
return
static_cast
<
mkldnn_engine_kind_t
>
(
akind
);
}
};
/// @}
/// @addtogroup cpp_api_primitives Primitives
/// @{
/// @addtogroup cpp_api_memory Memory
/// @{
/// Memory primitive that describes the data.
struct
memory
:
public
primitive
{
private:
std
::
shared_ptr
<
char
>
_handle
;
public:
typedef
std
::
vector
<
std
::
remove_extent
<
mkldnn_dims_t
>::
type
>
dims
;
template
<
typename
T
>
static
void
validate_dims
(
std
::
vector
<
T
>
v
)
{
if
(
v
.
size
()
>
TENSOR_MAX_DIMS
)
throw
error
(
mkldnn_invalid_arguments
,
"invalid dimensions"
);
}
/// Data type specification. See #mkldnn_data_type_t for a detailed
/// description.
enum
data_type
{
data_undef
=
mkldnn_data_type_undef
,
f32
=
mkldnn_f32
,
s32
=
mkldnn_s32
,
s16
=
mkldnn_s16
,
s8
=
mkldnn_s8
,
u8
=
mkldnn_u8
,
};
/// Memory format specification. See #mkldnn_memory_format_t
/// for a detailed description.
enum
format
{
format_undef
=
mkldnn_format_undef
,
any
=
mkldnn_any
,
blocked
=
mkldnn_blocked
,
x
=
mkldnn_x
,
nc
=
mkldnn_nc
,
nchw
=
mkldnn_nchw
,
nhwc
=
mkldnn_nhwc
,
chwn
=
mkldnn_chwn
,
nChw8c
=
mkldnn_nChw8c
,
nChw16c
=
mkldnn_nChw16c
,
ncdhw
=
mkldnn_ncdhw
,
ndhwc
=
mkldnn_ndhwc
,
nCdhw16c
=
mkldnn_nCdhw16c
,
oi
=
mkldnn_oi
,
io
=
mkldnn_io
,
oihw
=
mkldnn_oihw
,
ihwo
=
mkldnn_ihwo
,
hwio
=
mkldnn_hwio
,
oidhw
=
mkldnn_oidhw
,
OIdhw16i16o
=
mkldnn_OIdhw16i16o
,
OIdhw16o16i
=
mkldnn_OIdhw16o16i
,
Oidhw16o
=
mkldnn_Oidhw16o
,
Odhwi16o
=
mkldnn_Odhwi16o
,
oIhw8i
=
mkldnn_oIhw8i
,
oIhw16i
=
mkldnn_oIhw16i
,
OIhw8i8o
=
mkldnn_OIhw8i8o
,
OIhw16i16o
=
mkldnn_OIhw16i16o
,
OIhw8o8i
=
mkldnn_OIhw8o8i
,
OIhw16o16i
=
mkldnn_OIhw16o16i
,
IOhw16o16i
=
mkldnn_IOhw16o16i
,
OIhw8i16o2i
=
mkldnn_OIhw8i16o2i
,
OIhw8o16i2o
=
mkldnn_OIhw8o16i2o
,
OIhw4i16o4i
=
mkldnn_OIhw4i16o4i
,
Oihw8o
=
mkldnn_Oihw8o
,
Oihw16o
=
mkldnn_Oihw16o
,
Ohwi8o
=
mkldnn_Ohwi8o
,
Ohwi16o
=
mkldnn_Ohwi16o
,
OhIw16o4i
=
mkldnn_OhIw16o4i
,
goihw
=
mkldnn_goihw
,
hwigo
=
mkldnn_hwigo
,
gOIhw8i8o
=
mkldnn_gOIhw8i8o
,
gOIhw16i16o
=
mkldnn_gOIhw16i16o
,
gOIhw8i16o2i
=
mkldnn_gOIhw8i16o2i
,
gOIhw8o16i2o
=
mkldnn_gOIhw8o16i2o
,
gOIhw4i16o4i
=
mkldnn_gOIhw4i16o4i
,
gOihw8o
=
mkldnn_gOihw8o
,
gOihw16o
=
mkldnn_gOihw16o
,
gOhwi8o
=
mkldnn_gOhwi8o
,
gOhwi16o
=
mkldnn_gOhwi16o
,
Goihw8g
=
mkldnn_Goihw8g
,
Goihw16g
=
mkldnn_Goihw16g
,
gOIhw8o8i
=
mkldnn_gOIhw8o8i
,
gOIhw16o16i
=
mkldnn_gOIhw16o16i
,
gIOhw16o16i
=
mkldnn_gIOhw16o16i
,
gOhIw16o4i
=
mkldnn_gOhIw16o4i
,
goidhw
=
mkldnn_goidhw
,
gOIdhw16i16o
=
mkldnn_gOIdhw16i16o
,
gOIdhw16o16i
=
mkldnn_gOIdhw16o16i
,
gOidhw16o
=
mkldnn_gOidhw16o
,
gOdhwi16o
=
mkldnn_gOdhwi16o
,
ntc
=
mkldnn_ntc
,
tnc
=
mkldnn_tnc
,
ldsnc
=
mkldnn_ldsnc
,
ldigo
=
mkldnn_ldigo
,
ldigo_p
=
mkldnn_ldigo_p
,
ldgoi
=
mkldnn_ldgoi
,
ldgoi_p
=
mkldnn_ldgoi_p
,
ldgo
=
mkldnn_ldgo
,
wino_fmt
=
mkldnn_wino_fmt
,
format_last
=
mkldnn_format_last
,
};
/// A memory descriptor.
struct
desc
{
friend
struct
memory
;
/// The underlying C API data structure.
mkldnn_memory_desc_t
data
;
/// Constructs a memory descriptor.
///
/// @param adims Data dimensions
/// @param adata_type Data precision/type.
/// @param aformat Data layout format.
desc
(
dims
adims
,
data_type
adata_type
,
format
aformat
)
{
validate_dims
(
adims
);
error
::
wrap_c_api
(
mkldnn_memory_desc_init
(
&
data
,
(
int
)
adims
.
size
(),
adims
.
size
()
==
0
?
nullptr
:
&
adims
[
0
],
convert_to_c
(
adata_type
),
convert_to_c
(
aformat
)),
"could not initialize a memory descriptor"
);
}
/// Constructs a memory descriptor from a C API data structure.
///
/// @param adata A C API #mkldnn_memory_desc_t structure.
desc
(
const
mkldnn_memory_desc_t
&
adata
)
:
data
(
adata
)
{}
};
/// A memory primitive descriptor.
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
friend
struct
memory
;
// TODO: make private
primitive_desc
()
{}
/// Constructs a memory primitive descriptor.
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_memory_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
()),
"could not initialize a memory primitive descriptor"
);
reset
(
result
);
}
/// Returns the memory primitive descriptor.
memory
::
desc
desc
()
{
auto
memory_d
=
mkldnn_primitive_desc_query_memory_d
(
get
());
return
memory
::
desc
(
*
memory_d
);
}
/// Returns the number of bytes required to allocate the memory described
/// including the padding area.
size_t
get_size
()
const
{
return
mkldnn_memory_primitive_desc_get_size
(
get
());
}
bool
operator
==
(
const
primitive_desc
&
other
)
const
{
return
mkldnn_memory_primitive_desc_equal
(
get
(),
other
.
get
());
}
bool
operator
!=
(
const
primitive_desc
&
other
)
const
{
return
!
operator
==
(
other
);
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
/// Constructs a memory primitive from a generic primitive.
///
/// @param aprimitive The primitive to treat as memory.
memory
(
const
primitive
&
aprimitive
)
:
primitive
(
aprimitive
)
{}
/// Constructs a memory primitive.
///
/// @param adesc Memory primitive descriptor.
memory
(
const
primitive_desc
&
adesc
)
{
mkldnn_primitive_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
adesc
.
get
(),
nullptr
,
nullptr
),
"could not create a memory primitive"
);
reset
(
result
);
auto
_malloc
=
[](
size_t
size
,
int
alignment
)
{
void
*
ptr
;
#ifdef _WIN32
ptr
=
_aligned_malloc
(
size
,
alignment
);
int
rc
=
((
ptr
)
?
0
:
errno
);
#else
int
rc
=
::
posix_memalign
(
&
ptr
,
alignment
,
size
);
#endif
/* _WIN32 */
return
(
rc
==
0
)
?
(
char
*
)
ptr
:
nullptr
;
};
auto
_free
=
[](
char
*
p
)
{
#ifdef _WIN32
_aligned_free
((
void
*
)
p
);
#else
::
free
((
void
*
)
p
);
#endif
/* _WIN32 */
};
_handle
.
reset
(
_malloc
(
adesc
.
get_size
(),
4096
),
_free
);
set_data_handle
(
_handle
.
get
());
}
memory
(
const
primitive_desc
&
adesc
,
void
*
ahandle
)
{
mkldnn_primitive_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
adesc
.
get
(),
nullptr
,
nullptr
),
"could not create a memory primitive"
);
reset
(
result
);
set_data_handle
(
ahandle
);
}
/// Returns the descriptor of the memory primitive.
primitive_desc
get_primitive_desc
()
const
{
primitive_desc
adesc
;
const_mkldnn_primitive_desc_t
cdesc
;
error
::
wrap_c_api
(
mkldnn_primitive_get_primitive_desc
(
get
(),
&
cdesc
),
"could not get primitive descriptor from a memory primitive"
);
/* FIXME: no const_cast should be here */
adesc
.
reset
(
const_cast
<
mkldnn_primitive_desc_t
>
(
cdesc
),
true
);
return
adesc
;
}
/// Returns a handle of the data contained in the memory primitive. On
/// the CPU engine, this is a pointer to the allocated memory.
inline
void
*
get_data_handle
()
const
{
void
*
handle
;
error
::
wrap_c_api
(
mkldnn_memory_get_data_handle
(
get
(),
&
handle
),
"could not get native handle"
);
return
handle
;
}
inline
void
set_data_handle
(
void
*
handle
)
const
{
error
::
wrap_c_api
(
mkldnn_memory_set_data_handle
(
get
(),
handle
),
"could not set native handle"
);
}
// Must go away or be private:
static
mkldnn_data_type_t
convert_to_c
(
data_type
adata_type
)
{
return
static_cast
<
mkldnn_data_type_t
>
(
adata_type
);
}
static
mkldnn_memory_format_t
convert_to_c
(
format
aformat
)
{
return
static_cast
<
mkldnn_memory_format_t
>
(
aformat
);
}
};
inline
memory
::
desc
zero_md
()
{
mkldnn_memory_desc_t
zero
;
zero
.
primitive_kind
=
mkldnn_memory
;
return
memory
::
desc
(
zero
);
}
inline
memory
null_memory
(
engine
eng
)
{
mkldnn
::
memory
::
desc
zero
=
zero_md
();
return
memory
({
zero
,
eng
},
nullptr
);
}
inline
bool
is_null_memory
(
const
const_mkldnn_primitive_t
&
aprimitive
)
{
const_mkldnn_primitive_desc_t
aprimitive_pd
;
mkldnn_primitive_get_primitive_desc
(
aprimitive
,
&
aprimitive_pd
);
const
mkldnn_memory_desc_t
*
aprimitive_md
=
mkldnn_primitive_desc_query_memory_d
(
aprimitive_pd
);
return
((
aprimitive_md
!=
nullptr
)
&&
(
aprimitive_md
->
ndims
==
0
));
}
inline
bool
operator
==
(
mkldnn_data_type_t
a
,
memory
::
data_type
b
)
{
return
a
==
memory
::
convert_to_c
(
b
);
}
inline
bool
operator
!=
(
mkldnn_data_type_t
a
,
memory
::
data_type
b
)
{
return
!
(
a
==
b
);
}
inline
bool
operator
==
(
memory
::
data_type
a
,
mkldnn_data_type_t
b
)
{
return
b
==
a
;
}
inline
bool
operator
!=
(
memory
::
data_type
a
,
mkldnn_data_type_t
b
)
{
return
!
(
a
==
b
);
}
inline
bool
operator
==
(
mkldnn_memory_format_t
a
,
memory
::
format
b
)
{
return
a
==
memory
::
convert_to_c
(
b
);
}
inline
bool
operator
!=
(
mkldnn_memory_format_t
a
,
memory
::
format
b
)
{
return
!
(
a
==
b
);
}
inline
bool
operator
==
(
memory
::
format
a
,
mkldnn_memory_format_t
b
)
{
return
b
==
a
;
}
inline
bool
operator
!=
(
memory
::
format
a
,
mkldnn_memory_format_t
b
)
{
return
!
(
a
==
b
);
}
/// @}
/// @addtogroup cpp_api_reorder Reorder
/// @{
struct
reorder
:
public
primitive
{
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
memory
::
primitive_desc
&
input
,
const
memory
::
primitive_desc
&
output
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_reorder_primitive_desc_create
(
&
result
,
input
.
get
(),
output
.
get
()),
"could not create a reorder primitive descriptor"
);
reset
(
result
);
}
primitive_desc
(
const
memory
::
primitive_desc
&
input
,
const
memory
::
primitive_desc
&
output
,
const
primitive_attr
&
aattr
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_reorder_primitive_desc_create_v2
(
&
result
,
input
.
get
(),
output
.
get
(),
aattr
.
get
()),
"could not create a reorder primitive descriptor"
);
reset
(
result
);
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
reorder
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
input
,
const
memory
&
output
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
input
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
output
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a reorder primitive"
);
reset
(
result
);
}
reorder
(
const
primitive
::
at
&
input
,
const
memory
&
output
)
{
auto
input_mpd
=
memory
(
input
).
get_primitive_desc
();
auto
output_mpd
=
output
.
get_primitive_desc
();
auto
reorder_d
=
primitive_desc
(
input_mpd
,
output_mpd
);
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
input
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
output
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
reorder_d
.
get
(),
inputs
,
outputs
),
"could not create a reorder primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_view View
/// @{
struct
view
:
public
primitive
{
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
memory
::
primitive_desc
&
input
,
memory
::
dims
dims
,
memory
::
dims
offsets
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_view_primitive_desc_create
(
&
result
,
input
.
get
(),
&
dims
[
0
],
&
offsets
[
0
]),
"could not create a view primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
view
(
const
primitive_desc
&
view_pd
,
primitive
::
at
input
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
input
.
data
};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
view_pd
.
get
(),
inputs
,
nullptr
),
"could not create a view primitive"
);
reset
(
result
);
}
view
(
memory
input
,
memory
::
dims
dims
,
memory
::
dims
offsets
)
{
mkldnn_primitive_t
result
;
primitive_desc
view_pd
(
input
.
get_primitive_desc
(),
dims
,
offsets
);
mkldnn_primitive_at_t
inputs
[]
=
{
primitive
::
at
(
input
).
data
};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
view_pd
.
get
(),
inputs
,
nullptr
),
"could not create a view primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_concat Concat
/// @{
struct
concat
:
public
primitive
{
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
std
::
vector
<
const_mkldnn_primitive_desc_t
>
cpp_to_c
(
std
::
vector
<
memory
::
primitive_desc
>
inputs
)
{
std
::
vector
<
const_mkldnn_primitive_desc_t
>
c_api_inputs
;
c_api_inputs
.
reserve
(
inputs
.
size
());
auto
convert_to_c
=
[](
memory
::
primitive_desc
d
)
{
return
d
.
get
();
};
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
c_api_inputs
),
convert_to_c
);
return
c_api_inputs
;
}
primitive_desc
(
const
memory
::
desc
&
output
,
int
concat_dimension
,
std
::
vector
<
memory
::
primitive_desc
>
inputs
)
{
mkldnn_primitive_desc_t
result
;
auto
c_api_inputs
=
cpp_to_c
(
inputs
);
error
::
wrap_c_api
(
mkldnn_concat_primitive_desc_create
(
&
result
,
&
output
.
data
,
(
int
)
c_api_inputs
.
size
(),
concat_dimension
,
&
c_api_inputs
[
0
]),
"could not create a concat primitive descriptor"
);
reset
(
result
);
}
primitive_desc
(
int
concat_dimension
,
std
::
vector
<
memory
::
primitive_desc
>
inputs
)
{
mkldnn_primitive_desc_t
result
;
auto
c_api_inputs
=
cpp_to_c
(
inputs
);
error
::
wrap_c_api
(
mkldnn_concat_primitive_desc_create
(
&
result
,
nullptr
,
(
int
)
c_api_inputs
.
size
(),
concat_dimension
,
&
c_api_inputs
[
0
]),
"could not create a concat primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
concat
(
const
primitive_desc
&
concat_pd
,
std
::
vector
<
primitive
::
at
>
&
inputs
,
const
memory
&
output
)
{
mkldnn_primitive_t
result
;
std
::
vector
<
mkldnn_primitive_at_t
>
p_inputs
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
i
++
)
p_inputs
.
push_back
(
inputs
[
i
].
data
);
const_mkldnn_primitive_t
outputs
[]
=
{
output
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
concat_pd
.
get
(),
&
p_inputs
[
0
],
outputs
),
"could not create a concat primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_sum Sum
/// @{
struct
sum
:
public
primitive
{
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
std
::
vector
<
const_mkldnn_primitive_desc_t
>
cpp_to_c
(
std
::
vector
<
memory
::
primitive_desc
>
inputs
)
{
std
::
vector
<
const_mkldnn_primitive_desc_t
>
c_api_inputs
;
c_api_inputs
.
reserve
(
inputs
.
size
());
auto
convert_to_c
=
[](
memory
::
primitive_desc
d
)
{
return
d
.
get
();
};
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
c_api_inputs
),
convert_to_c
);
return
c_api_inputs
;
}
primitive_desc
(
const
memory
::
desc
&
output
,
const
std
::
vector
<
float
>
&
scales
,
std
::
vector
<
memory
::
primitive_desc
>
inputs
)
{
mkldnn_primitive_desc_t
result
;
auto
c_api_inputs
=
cpp_to_c
(
inputs
);
error
::
wrap_c_api
(
mkldnn_sum_primitive_desc_create
(
&
result
,
&
output
.
data
,
(
int
)
c_api_inputs
.
size
(),
&
scales
[
0
],
&
c_api_inputs
[
0
]),
"could not create a sum primitive descriptor"
);
reset
(
result
);
}
primitive_desc
(
const
std
::
vector
<
float
>
&
scales
,
std
::
vector
<
memory
::
primitive_desc
>
inputs
)
{
mkldnn_primitive_desc_t
result
;
auto
c_api_inputs
=
cpp_to_c
(
inputs
);
error
::
wrap_c_api
(
mkldnn_sum_primitive_desc_create
(
&
result
,
nullptr
,
(
int
)
c_api_inputs
.
size
(),
&
scales
[
0
],
&
c_api_inputs
[
0
]),
"could not create a sum primitive descriptor"
);
reset
(
result
);
}
/** @deprecated: api backwards compatibility for double scales type */
MKLDNN_DEPRECATED
primitive_desc
(
const
memory
::
desc
&
output
,
std
::
vector
<
double
>
scale
,
std
::
vector
<
memory
::
primitive_desc
>
inputs
)
{
mkldnn_primitive_desc_t
result
;
auto
c_api_inputs
=
cpp_to_c
(
inputs
);
auto
scale_f
=
scale_to_float
(
scale
);
error
::
wrap_c_api
(
mkldnn_sum_primitive_desc_create
(
&
result
,
&
output
.
data
,
(
int
)
c_api_inputs
.
size
(),
&
scale_f
[
0
],
&
c_api_inputs
[
0
]),
"could not create a sum primitive descriptor"
);
reset
(
result
);
}
/** @deprecated: api backwards compatibility for double scales type */
MKLDNN_DEPRECATED
primitive_desc
(
std
::
vector
<
double
>
scale
,
std
::
vector
<
memory
::
primitive_desc
>
inputs
)
{
mkldnn_primitive_desc_t
result
;
auto
c_api_inputs
=
cpp_to_c
(
inputs
);
auto
scale_f
=
scale_to_float
(
scale
);
error
::
wrap_c_api
(
mkldnn_sum_primitive_desc_create
(
&
result
,
nullptr
,
(
int
)
c_api_inputs
.
size
(),
&
scale_f
[
0
],
&
c_api_inputs
[
0
]),
"could not create a sum primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
sum
(
const
primitive_desc
&
sum_pd
,
std
::
vector
<
primitive
::
at
>
&
inputs
,
const
memory
&
output
)
{
mkldnn_primitive_t
result
;
std
::
vector
<
mkldnn_primitive_at_t
>
p_inputs
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
i
++
)
p_inputs
.
push_back
(
inputs
[
i
].
data
);
const_mkldnn_primitive_t
outputs
[]
=
{
output
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
sum_pd
.
get
(),
&
p_inputs
[
0
],
outputs
),
"could not create a sum primitive"
);
reset
(
result
);
}
private:
static
std
::
vector
<
float
>
scale_to_float
(
const
std
::
vector
<
double
>
&
vd
)
{
std
::
vector
<
float
>
vf
(
vd
.
size
());
std
::
transform
(
vd
.
begin
(),
vd
.
end
(),
vf
.
begin
(),
[
=
](
double
x
)
{
return
(
float
)
x
;
});
return
vf
;
}
};
/// @}
/// @addtogroup cpp_api_convolution Convolution
/// @{
struct
convolution_forward
:
public
primitive
{
struct
desc
{
mkldnn_convolution_desc_t
data
;
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
bias_desc
,
const
memory
::
desc
&
dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_convolution_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
weights_desc
.
data
,
&
bias_desc
.
data
,
&
dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a convolution forward descriptor"
);
}
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_convolution_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
weights_desc
.
data
,
nullptr
,
&
dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a convolution forward descriptor"
);
}
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
bias_desc
,
const
memory
::
desc
&
dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
dilates
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
dilates
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_dilated_convolution_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
weights_desc
.
data
,
&
bias_desc
.
data
,
&
dst_desc
.
data
,
&
strides
[
0
],
&
dilates
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a dilated convolution forward descriptor"
);
}
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
dilates
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
dilates
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_dilated_convolution_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
weights_desc
.
data
,
nullptr
,
&
dst_desc
.
data
,
&
strides
[
0
],
&
dilates
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a dilated convolution forward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a convolution forward primitive descriptor"
);
reset
(
result
);
}
primitive_desc
(
const
desc
&
adesc
,
const
primitive_attr
&
aattr
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create_v2
(
&
result
,
&
adesc
.
data
,
aattr
.
get
(),
aengine
.
get
(),
nullptr
),
"could not create a convolution forward primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
convolution_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
primitive
::
at
&
bias
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
,
bias
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a convolution forward bias primitive"
);
reset
(
result
);
}
convolution_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a convolution forward primitive"
);
reset
(
result
);
}
};
struct
convolution_backward_data
:
public
primitive
{
struct
desc
{
mkldnn_convolution_desc_t
data
;
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
diff_src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_convolution_backward_data_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
diff_src_desc
.
data
,
&
weights_desc
.
data
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a convolution backward data descriptor"
);
}
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
diff_src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
dilates
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
dilates
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_dilated_convolution_backward_data_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
diff_src_desc
.
data
,
&
weights_desc
.
data
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
dilates
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a convolution backward data descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
convolution_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a convolution backward data primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
diff_src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_src primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
convolution_backward_data
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
diff_dst
,
const
primitive
::
at
&
weights
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
diff_dst
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a convolution backward data primitive"
);
reset
(
result
);
}
};
struct
convolution_backward_weights
:
public
primitive
{
struct
desc
{
mkldnn_convolution_desc_t
data
;
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
diff_weights_desc
,
const
memory
::
desc
&
diff_bias_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_convolution_backward_weights_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
diff_weights_desc
.
data
,
&
diff_bias_desc
.
data
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a convolution backward weights descriptor"
);
}
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
diff_weights_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_convolution_backward_weights_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
diff_weights_desc
.
data
,
nullptr
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a convolution backward weights descriptor"
);
}
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
diff_weights_desc
,
const
memory
::
desc
&
diff_bias_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
dilates
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
dilates
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_dilated_convolution_backward_weights_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
diff_weights_desc
.
data
,
&
diff_bias_desc
.
data
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
dilates
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a convolution backward weights descriptor"
);
}
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
diff_weights_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
dilates
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
dilates
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_dilated_convolution_backward_weights_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
diff_weights_desc
.
data
,
nullptr
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
dilates
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a convolution backward weights descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
convolution_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a convolution backward weights primitive "
"descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
convolution_backward_weights
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
diff_dst
,
const
memory
&
diff_weights
,
const
memory
&
diff_bias
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_weights
.
get
(),
diff_bias
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a convolution backward weights primitive"
);
reset
(
result
);
}
convolution_backward_weights
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
diff_dst
,
const
memory
&
diff_weights
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_weights
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a convolution backward weights primitive"
);
reset
(
result
);
}
};
struct
convolution_relu_forward
:
public
primitive
{
struct
desc
{
mkldnn_convolution_relu_desc_t
data
;
desc
(
const
convolution_forward
::
desc
conv_desc
,
const
float
negative_slope
)
{
error
::
wrap_c_api
(
mkldnn_convolution_relu_desc_init
(
&
data
,
&
conv_desc
.
data
,
negative_slope
),
"could not create a convolution_relu_forward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a convolution relu forward descriptor"
);
reset
(
result
);
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
convolution_relu_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
primitive
::
at
&
bias
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
,
bias
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a convolution relu forward primitive"
);
reset
(
result
);
}
convolution_relu_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a convolution relu forward primitive"
);
reset
(
result
);
}
};
/// @}
//
/// @addtogroup cpp_api_deconvolution Deconvolution
/// @{
struct
deconvolution_forward
:
public
primitive
{
struct
desc
{
mkldnn_deconvolution_desc_t
data
;
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
bias_desc
,
const
memory
::
desc
&
dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_deconvolution_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
weights_desc
.
data
,
&
bias_desc
.
data
,
&
dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a deconvolution forward descriptor"
);
}
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_deconvolution_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
weights_desc
.
data
,
nullptr
,
&
dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a deconvolution forward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a deconvolution forward primitive descriptor"
);
reset
(
result
);
}
primitive_desc
(
const
desc
&
adesc
,
const
primitive_attr
&
aattr
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create_v2
(
&
result
,
&
adesc
.
data
,
aattr
.
get
(),
aengine
.
get
(),
nullptr
),
"could not create a deconvolution forward primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
deconvolution_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
primitive
::
at
&
bias
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
,
bias
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a deconvolution forward bias primitive"
);
reset
(
result
);
}
deconvolution_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a deconvolution forward primitive"
);
reset
(
result
);
}
};
struct
deconvolution_backward_data
:
public
primitive
{
struct
desc
{
mkldnn_deconvolution_desc_t
data
;
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
diff_src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_deconvolution_backward_data_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
diff_src_desc
.
data
,
&
weights_desc
.
data
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a deconvolution backward data descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
deconvolution_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a deconvolution backward data primitive "
"descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
diff_src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_src primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
deconvolution_backward_data
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
diff_dst
,
const
primitive
::
at
&
weights
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
diff_dst
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a deconvolution backward data primitive"
);
reset
(
result
);
}
};
struct
deconvolution_backward_weights
:
public
primitive
{
struct
desc
{
mkldnn_deconvolution_desc_t
data
;
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
diff_weights_desc
,
const
memory
::
desc
&
diff_bias_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_deconvolution_backward_weights_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
diff_weights_desc
.
data
,
&
diff_bias_desc
.
data
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a deconvolution backward weights descriptor"
);
}
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
diff_weights_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_deconvolution_backward_weights_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
diff_weights_desc
.
data
,
nullptr
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not create a deconvolution backward weights descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
deconvolution_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a deconvolution backward weights primitive "
"descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
deconvolution_backward_weights
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
diff_dst
,
const
memory
&
diff_weights
,
const
memory
&
diff_bias
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_weights
.
get
(),
diff_bias
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a deconvolution backward weights primitive"
);
reset
(
result
);
}
deconvolution_backward_weights
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
diff_dst
,
const
memory
&
diff_weights
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_weights
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a deconvolution backward weights primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_lrn LRN
/// @{
struct
lrn_forward
:
public
primitive
{
struct
desc
{
mkldnn_lrn_desc_t
data
;
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
int
local_size
,
float
alpha
,
float
beta
,
float
k
)
{
error
::
wrap_c_api
(
mkldnn_lrn_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
local_size
,
alpha
,
beta
,
k
),
"could not create a lrn forward descriptor"
);
}
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
int
local_size
,
float
alpha
,
float
beta
)
{
error
::
wrap_c_api
(
mkldnn_lrn_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
local_size
,
alpha
,
beta
,
float
(
1.0
)),
"could not create a lrn forward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a lrn forward primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
workspace_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
ldesc
;
const_mkldnn_primitive_desc_t
const_ldesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
workspace_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
ldesc
,
const_ldesc
),
"could not clone a workspace primitive descriptor"
);
adesc
.
reset
(
ldesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
lrn_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
workspace
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
(),
workspace
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a lrn forward primitive"
);
reset
(
result
);
}
lrn_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a lrn forward primitive"
);
reset
(
result
);
}
};
struct
lrn_backward
:
public
primitive
{
struct
desc
{
mkldnn_lrn_desc_t
data
;
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
data_desc
,
const
memory
::
desc
&
diff_data_desc
,
int
local_size
,
float
alpha
,
float
beta
,
float
k
)
{
error
::
wrap_c_api
(
mkldnn_lrn_backward_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
diff_data_desc
.
data
,
&
data_desc
.
data
,
local_size
,
alpha
,
beta
,
k
),
"could not create a lrn backward descriptor"
);
}
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
data_desc
,
const
memory
::
desc
&
diff_data_desc
,
int
local_size
,
float
alpha
,
float
beta
)
{
error
::
wrap_c_api
(
mkldnn_lrn_backward_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
diff_data_desc
.
data
,
&
data_desc
.
data
,
local_size
,
alpha
,
beta
,
float
(
1.0
)),
"could not create a lrn backward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
lrn_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a backward lrn primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
diff_src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_src primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
workspace_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
ldesc
;
const_mkldnn_primitive_desc_t
const_ldesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
workspace_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
ldesc
,
const_ldesc
),
"could not clone a workspace primitive descriptor"
);
adesc
.
reset
(
ldesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff_dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
lrn_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
diff_dst
,
const
primitive
::
at
&
workspace
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
,
workspace
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a lrn backward primitive"
);
reset
(
result
);
}
lrn_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
diff_dst
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a lrn backward primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_pooling Pooling
/// @{
struct
pooling_forward
:
public
primitive
{
struct
desc
{
mkldnn_pooling_desc_t
data
;
desc
(
prop_kind
aprop_kind
,
algorithm
aalgorithm
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
dst_desc
,
const
memory
::
dims
strides
,
const
memory
::
dims
kernel
,
const
memory
::
dims
padding_l
,
const
memory
::
dims
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
kernel
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_pooling_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
convert_to_c
(
aalgorithm
),
&
src_desc
.
data
,
&
dst_desc
.
data
,
&
strides
[
0
],
&
kernel
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not init a forward pooling descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a forward pooling primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
workspace_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
workspace_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a workspace primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
pooling_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
(),
nullptr
};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a pooling forward primitive"
);
reset
(
result
);
}
pooling_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
dst
,
const
memory
&
workspace
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
(),
workspace
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a pooling forward primitive"
);
reset
(
result
);
}
};
struct
pooling_backward
:
public
primitive
{
struct
desc
{
mkldnn_pooling_desc_t
data
;
desc
(
algorithm
aalgorithm
,
const
memory
::
desc
&
diff_src_desc
,
const
memory
::
desc
&
diff_dst_desc
,
const
memory
::
dims
&
strides
,
const
memory
::
dims
&
kernel
,
const
memory
::
dims
&
padding_l
,
const
memory
::
dims
&
padding_r
,
const
padding_kind
apadding_kind
)
{
memory
::
validate_dims
(
strides
);
memory
::
validate_dims
(
kernel
);
memory
::
validate_dims
(
padding_l
);
memory
::
validate_dims
(
padding_r
);
error
::
wrap_c_api
(
mkldnn_pooling_backward_desc_init
(
&
data
,
convert_to_c
(
aalgorithm
),
&
diff_src_desc
.
data
,
&
diff_dst_desc
.
data
,
&
strides
[
0
],
&
kernel
[
0
],
&
padding_l
[
0
],
&
padding_r
[
0
],
mkldnn
::
convert_to_c
(
apadding_kind
)),
"could not init a backward pooling descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
pooling_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a backward pooling primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
diff_src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff src primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
pooling_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
diff_dst
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a pooling backward primitive"
);
reset
(
result
);
}
pooling_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
diff_dst
,
const
primitive
::
at
&
workspace
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
diff_dst
.
data
,
workspace
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a pooling backward primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_eltwise Eltwise
/// @{
struct
eltwise_forward
:
public
primitive
{
struct
desc
{
mkldnn_eltwise_desc_t
data
;
template
<
typename
T
>
desc
(
prop_kind
aprop_kind
,
algorithm
alg_kind
,
const
memory
::
desc
&
src_desc
,
T
alpha
=
0
,
T
beta
=
0
)
{
error
::
wrap_c_api
(
mkldnn_eltwise_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
mkldnn
::
convert_to_c
(
alg_kind
),
&
src_desc
.
data
,
static_cast
<
float
>
(
alpha
),
static_cast
<
float
>
(
beta
)),
"could not create a eltwise forward descriptor"
);
}
/** @deprecated: api backward compatibility for relu */
template
<
typename
T
>
MKLDNN_DEPRECATED
desc
(
prop_kind
aprop_kind
,
const
memory
::
desc
&
src_desc
,
T
negative_slope
)
:
desc
(
aprop_kind
,
eltwise_relu
,
src_desc
,
negative_slope
)
{}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a eltwise forward primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
eltwise_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a eltwise forward primitive"
);
reset
(
result
);
}
};
typedef
eltwise_forward
relu_forward
;
struct
eltwise_backward
:
public
primitive
{
struct
desc
{
mkldnn_eltwise_desc_t
data
;
template
<
typename
T
>
desc
(
algorithm
alg_kind
,
const
memory
::
desc
&
diff_data_desc
,
const
memory
::
desc
&
data_desc
,
T
alpha
=
0
,
T
beta
=
0
)
{
error
::
wrap_c_api
(
mkldnn_eltwise_backward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
alg_kind
),
&
diff_data_desc
.
data
,
&
data_desc
.
data
,
static_cast
<
float
>
(
alpha
),
static_cast
<
float
>
(
beta
)),
"could not create a eltwise backward descriptor"
);
}
/** @deprecated: api backward compatibility for relu */
template
<
typename
T
>
MKLDNN_DEPRECATED
desc
(
const
memory
::
desc
&
diff_data_desc
,
const
memory
::
desc
&
data_desc
,
T
negative_slope
)
:
desc
(
eltwise_relu
,
diff_data_desc
,
data_desc
,
negative_slope
)
{}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
eltwise_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a eltwise backward primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
diff_src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff src primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
eltwise_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
diff_dst
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a eltwise backward primitive"
);
reset
(
result
);
}
};
typedef
eltwise_backward
relu_backward
;
/// @}
/// @addtogroup cpp_api_softmax Softmax
/// @{
struct
softmax_forward
:
public
primitive
{
struct
desc
{
mkldnn_softmax_desc_t
data
;
desc
(
prop_kind
aprop_kind
,
const
memory
::
desc
&
data_desc
,
int
softmax_axis
)
{
error
::
wrap_c_api
(
mkldnn_softmax_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
&
data_desc
.
data
,
softmax_axis
),
"could not create a softmax forward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a softmax forward primitive descriptor"
);
reset
(
result
);
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
softmax_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a softmax forward primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_batch_norm Batch normalization
/// @{
struct
batch_normalization_forward
:
public
primitive
{
struct
desc
{
mkldnn_batch_normalization_desc_t
data
;
template
<
typename
T
>
desc
(
prop_kind
aprop_kind
,
const
memory
::
desc
&
src_desc
,
T
epsilon
,
unsigned
flags
)
{
error
::
wrap_c_api
(
mkldnn_batch_normalization_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
&
src_desc
.
data
,
static_cast
<
float
>
(
epsilon
),
flags
),
"could not create a batch normalization forward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a batch normalization forward "
"primitive descriptor"
);
reset
(
result
);
}
primitive_desc
(
const
desc
&
adesc
,
const
primitive_attr
&
aattr
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create_v2
(
&
result
,
&
adesc
.
data
,
aattr
.
get
(),
aengine
.
get
(),
nullptr
),
"could not create a batch normalization forward "
"primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
bndesc
;
const_mkldnn_primitive_desc_t
const_bndesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
bndesc
,
const_bndesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
bndesc
);
return
adesc
;
}
memory
::
primitive_desc
mean_primitive_desc
()
const
{
memory
::
primitive_desc
aprimitive_desc
;
mkldnn_primitive_desc_t
bndesc
;
mkldnn_batch_normalization_desc_t
*
p
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_query
(
get
(),
mkldnn
::
convert_to_c
(
batch_normalization_d
),
0
,
&
p
),
"could not get a batch-normalization descriptor"
);
const_mkldnn_primitive_desc_t
const_bndesc
=
(
p
->
flags
&
use_global_stats
)
?
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
1
)
:
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
bndesc
,
const_bndesc
),
"could not clone a mean primitive descriptor"
);
aprimitive_desc
.
reset
(
bndesc
);
return
aprimitive_desc
;
}
memory
::
primitive_desc
variance_primitive_desc
()
const
{
memory
::
primitive_desc
aprimitive_desc
;
mkldnn_primitive_desc_t
bndesc
;
mkldnn_batch_normalization_desc_t
*
p
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_query
(
get
(),
mkldnn
::
convert_to_c
(
batch_normalization_d
),
0
,
&
p
),
"could not get a batch-normalization descriptor"
);
const_mkldnn_primitive_desc_t
const_bndesc
=
(
p
->
flags
&
use_global_stats
)
?
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
2
)
:
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
2
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
bndesc
,
const_bndesc
),
"could not clone a variance primitive descriptor"
);
aprimitive_desc
.
reset
(
bndesc
);
return
aprimitive_desc
;
}
memory
::
primitive_desc
workspace_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
workspace_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a workspace primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
batch_normalization_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
mean
,
const
primitive
::
at
&
variance
,
const
primitive
::
at
&
weights
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
mean
.
data
,
variance
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization forward primitive"
);
reset
(
result
);
}
batch_normalization_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
mean
,
const
primitive
::
at
&
variance
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
mean
.
data
,
variance
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization forward primitive"
);
reset
(
result
);
}
/// @warning batch_normalization_forward has 2 constructors with very
/// similar signatures:
/// - (pd, src, weights, dst, mean, variance) // 2 in, 3 out
/// - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out
/// The only way to distinguish between those is to explicitly
/// cast all input parameters to their type, i.e. to
/// const primitive:at &.
batch_normalization_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
memory
&
dst
,
const
memory
&
mean
,
const
memory
&
variance
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
(),
mean
.
get
(),
variance
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization forward primitive"
);
reset
(
result
);
}
batch_normalization_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
memory
&
dst
,
const
memory
&
mean
,
const
memory
&
variance
,
const
memory
&
workspace
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
(),
mean
.
get
(),
variance
.
get
(),
workspace
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization forward primitive"
);
reset
(
result
);
}
batch_normalization_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
dst
,
const
memory
&
mean
,
const
memory
&
variance
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
(),
mean
.
get
(),
variance
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization forward primitive"
);
reset
(
result
);
}
/// @warning batch_normalization_forward has 2 constructors with very
/// similar signatures:
/// - (pd, src, weights, dst, mean, variance) // 2 in, 3 out
/// - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out
/// The only way to distinguish between those is to explicitly
/// cast all input parameters to their type, i.e. to
/// const primitive:at &.
/// @note to make users' experience a little bit better this constructor
/// checks if whether parameters match corresponding primitive
/// descriptor, and if they are not -- call the other (proper)
/// constructor. Yeah, this is still very ugly...
batch_normalization_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
dst
,
const
memory
&
mean
,
const
memory
&
variance
,
const
memory
&
workspace
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[
2
]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[
4
]
=
{
dst
.
get
(),
mean
.
get
(),
variance
.
get
(),
workspace
.
get
()};
if
(
1
)
{
// check whether this is the `wrong` constructor
const
int
n_inputs_expected
=
mkldnn_primitive_desc_query_s32
(
aprimitive_desc
.
get
(),
mkldnn_query_num_of_inputs_s32
,
0
);
const
int
n_outputs_expected
=
mkldnn_primitive_desc_query_s32
(
aprimitive_desc
.
get
(),
mkldnn_query_num_of_outputs_s32
,
0
);
if
(
n_inputs_expected
==
2
&&
n_outputs_expected
==
3
)
{
// shift parameters, get rid of workspace, and add weights...
auto
_weights
=
dst
;
inputs
[
1
]
=
{
_weights
.
get
(),
0
};
auto
_dst
=
mean
,
_mean
=
variance
,
_variance
=
workspace
;
outputs
[
0
]
=
_dst
.
get
();
outputs
[
1
]
=
_mean
.
get
();
outputs
[
2
]
=
_variance
.
get
();
outputs
[
3
]
=
nullptr
;
}
}
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization forward primitive"
);
reset
(
result
);
}
batch_normalization_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
weights
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization forward primitive"
);
reset
(
result
);
}
batch_normalization_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization forward primitive"
);
reset
(
result
);
}
};
struct
batch_normalization_backward
:
public
primitive
{
struct
desc
{
mkldnn_batch_normalization_desc_t
data
;
template
<
typename
T
>
desc
(
prop_kind
aprop_kind
,
const
memory
::
desc
&
diff_data_desc
,
const
memory
::
desc
&
data_desc
,
T
epsilon
,
unsigned
flags
)
{
error
::
wrap_c_api
(
mkldnn_batch_normalization_backward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
&
diff_data_desc
.
data
,
&
data_desc
.
data
,
static_cast
<
float
>
(
epsilon
),
flags
),
"could not create a batch normalization backward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
batch_normalization_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a batch normalization backward primitive "
"descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
bndesc
;
const_mkldnn_primitive_desc_t
const_bndesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
bndesc
,
const_bndesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
bndesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
bndesc
;
const_mkldnn_primitive_desc_t
const_bndesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
bndesc
,
const_bndesc
),
"could not clone a diff_weights primitive descriptor"
);
adesc
.
reset
(
bndesc
);
return
adesc
;
}
memory
::
primitive_desc
mean_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
bndesc
;
const_mkldnn_primitive_desc_t
const_bndesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
bndesc
,
const_bndesc
),
"could not clone a mean primitive descriptor"
);
adesc
.
reset
(
bndesc
);
return
adesc
;
}
memory
::
primitive_desc
variance_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
bndesc
;
const_mkldnn_primitive_desc_t
const_bndesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
2
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
bndesc
,
const_bndesc
),
"could not clone a variance primitive descriptor"
);
adesc
.
reset
(
bndesc
);
return
adesc
;
}
memory
::
primitive_desc
workspace_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
workspace_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a workspace primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
// Prop_kind == backward
batch_normalization_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
mean
,
const
primitive
::
at
&
variance
,
const
primitive
::
at
&
diff_dst
,
const
primitive
::
at
&
weights
,
const
memory
&
diff_src
,
const
memory
&
diff_weights
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
mean
.
data
,
variance
.
data
,
diff_dst
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
(),
diff_weights
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization backward primitive"
);
reset
(
result
);
}
// Prop_kind == backward (+ws)
batch_normalization_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
mean
,
const
primitive
::
at
&
variance
,
const
primitive
::
at
&
diff_dst
,
const
primitive
::
at
&
weights
,
const
primitive
::
at
&
workspace
,
const
memory
&
diff_src
,
const
memory
&
diff_weights
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
mean
.
data
,
variance
.
data
,
diff_dst
.
data
,
weights
.
data
,
workspace
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
(),
diff_weights
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization backward primitive"
);
reset
(
result
);
}
// Prop_kind == backward_data (+ws or +weights)
/// @warning This constructor works for backward_data propagation
/// - w/ weights but w/o workspace, or
/// - w/ workspace but w/o weights
batch_normalization_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
mean
,
const
primitive
::
at
&
variance
,
const
primitive
::
at
&
diff_dst
,
const
primitive
::
at
&
weights_or_workspace
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
mean
.
data
,
variance
.
data
,
diff_dst
.
data
,
weights_or_workspace
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization backward primitive"
);
reset
(
result
);
}
// Prop_kind == backward_data
batch_normalization_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
&
mean
,
const
primitive
::
at
&
variance
,
const
primitive
::
at
&
diff_dst
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
mean
.
data
,
variance
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a batch normalization backward primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_inner_product Inner Product
/// @{
struct
inner_product_forward
:
public
primitive
{
struct
desc
{
mkldnn_inner_product_desc_t
data
;
desc
(
prop_kind
aprop_kind
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
bias_desc
,
const
memory
::
desc
&
dst_desc
)
{
error
::
wrap_c_api
(
mkldnn_inner_product_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
&
src_desc
.
data
,
&
weights_desc
.
data
,
&
bias_desc
.
data
,
&
dst_desc
.
data
),
"could not create a inner product forward descriptor"
);
}
desc
(
prop_kind
aprop_kind
,
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
dst_desc
)
{
error
::
wrap_c_api
(
mkldnn_inner_product_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
&
src_desc
.
data
,
&
weights_desc
.
data
,
nullptr
,
&
dst_desc
.
data
),
"could not create a inner product forward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create a inner product forward primitive descriptor"
);
reset
(
result
);
}
primitive_desc
(
const
desc
&
adesc
,
const
primitive_attr
&
aattr
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create_v2
(
&
result
,
&
adesc
.
data
,
aattr
.
get
(),
aengine
.
get
(),
nullptr
),
"could not create a inner product "
"forward primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
inner_product_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
weights
,
const
primitive
::
at
&
bias
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
,
bias
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a inner product forward primitive"
);
reset
(
result
);
}
inner_product_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
weights
,
const
memory
&
dst
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
dst
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a inner product forward primitive"
);
reset
(
result
);
}
};
struct
inner_product_backward_data
:
public
primitive
{
struct
desc
{
mkldnn_inner_product_desc_t
data
;
desc
(
const
memory
::
desc
&
diff_src_desc
,
const
memory
::
desc
&
weights_desc
,
const
memory
::
desc
&
diff_dst_desc
)
{
error
::
wrap_c_api
(
mkldnn_inner_product_backward_data_desc_init
(
&
data
,
&
diff_src_desc
.
data
,
&
weights_desc
.
data
,
&
diff_dst_desc
.
data
),
"could not create a inner product backward data descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
inner_product_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a inner product backward data primitive "
"descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
diff_dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff dst primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff src primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
inner_product_backward_data
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
diff_dst
,
const
primitive
::
at
weights
,
const
memory
&
diff_src
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
diff_dst
.
data
,
weights
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_src
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a inner product backward data primitive"
);
reset
(
result
);
}
};
struct
inner_product_backward_weights
:
public
primitive
{
struct
desc
{
mkldnn_inner_product_desc_t
data
;
desc
(
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
diff_weights_desc
,
const
memory
::
desc
&
diff_bias_desc
,
const
memory
::
desc
&
diff_dst_desc
)
{
error
::
wrap_c_api
(
mkldnn_inner_product_backward_weights_desc_init
(
&
data
,
&
src_desc
.
data
,
&
diff_weights_desc
.
data
,
&
diff_bias_desc
.
data
,
&
diff_dst_desc
.
data
),
"could not create a inner product backward weights descriptor"
);
}
desc
(
const
memory
::
desc
&
src_desc
,
const
memory
::
desc
&
diff_weights_desc
,
const
memory
::
desc
&
diff_dst_desc
)
{
error
::
wrap_c_api
(
mkldnn_inner_product_backward_weights_desc_init
(
&
data
,
&
src_desc
.
data
,
&
diff_weights_desc
.
data
,
nullptr
,
&
diff_dst_desc
.
data
),
"could not create a inner product backward weights descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
,
const
inner_product_forward
::
primitive_desc
&
hint_fwd_primitive_desc
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
hint_fwd_primitive_desc
.
get
()),
"could not create a inner product backward weights primitive "
"descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
diff_dst_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff dst primititve descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_weights_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a diff bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
src_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
inner_product_backward_weights
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
diff_dst
,
const
memory
&
diff_weights
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_weights
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a inner product backward weights primitive"
);
reset
(
result
);
}
inner_product_backward_weights
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src
,
const
primitive
::
at
diff_dst
,
const
memory
&
diff_weights
,
const
memory
&
diff_bias
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[]
=
{
src
.
data
,
diff_dst
.
data
};
const_mkldnn_primitive_t
outputs
[]
=
{
diff_weights
.
get
(),
diff_bias
.
get
()};
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create a inner product backward weights primitive"
);
reset
(
result
);
}
};
/// @}
/// @addtogroup cpp_api_rnn RNN
/// @{
struct
rnn_cell
{
struct
desc
{
mkldnn_rnn_cell_desc_t
c_rnn_cell_
;
desc
(
algorithm
kind
,
algorithm
activation_f
)
{
error
::
wrap_c_api
(
mkldnn_rnn_cell_desc_init
(
&
c_rnn_cell_
,
mkldnn
::
convert_to_c
(
kind
),
mkldnn
::
convert_to_c
(
activation_f
),
0U
,
0
,
0
),
"could not init an rnn cell descriptor"
);
}
desc
(
algorithm
kind
)
:
desc
(
kind
,
algorithm
::
algorithm_undef
)
{}
operator
const
mkldnn_rnn_cell_desc_t
*
()
const
{
return
&
c_rnn_cell_
;
}
algorithm
get_cell_kind
()
const
{
return
algorithm
(
c_rnn_cell_
.
cell_kind
);
}
algorithm
get_activation
()
const
{
return
algorithm
(
c_rnn_cell_
.
activation_kind
);
}
float
get_alpha
()
const
{
return
c_rnn_cell_
.
alpha
;
}
void
set_alpha
(
float
alpha
)
{
c_rnn_cell_
.
flags
|=
mkldnn_rnn_cell_with_relu
;
c_rnn_cell_
.
alpha
=
alpha
;
}
float
get_clipping
()
const
{
return
c_rnn_cell_
.
clipping
;
}
void
set_clipping
(
float
clipping
)
{
c_rnn_cell_
.
flags
|=
mkldnn_rnn_cell_with_clipping
;
c_rnn_cell_
.
clipping
=
clipping
;
}
int
get_gates_count
()
const
{
return
mkldnn_rnn_cell_get_gates_count
(
&
c_rnn_cell_
);
}
int
get_state_count
()
const
{
return
mkldnn_rnn_cell_get_states_count
(
&
c_rnn_cell_
);
}
};
};
struct
rnn_forward
:
public
primitive
{
struct
desc
{
mkldnn_rnn_desc_t
data
;
desc
(
prop_kind
aprop_kind
,
rnn_cell
::
desc
cell
,
const
rnn_direction
direction
,
const
memory
::
desc
&
src_layer_desc
,
const
memory
::
desc
&
src_iter_desc
,
const
memory
::
desc
&
weights_layer_desc
,
const
memory
::
desc
&
weights_iter_desc
,
const
memory
::
desc
&
bias_desc
,
const
memory
::
desc
&
dst_layer_desc
,
const
memory
::
desc
&
dst_iter_desc
)
{
error
::
wrap_c_api
(
mkldnn_rnn_forward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
cell
,
mkldnn
::
convert_to_c
(
direction
),
&
src_layer_desc
.
data
,
&
src_iter_desc
.
data
,
&
weights_layer_desc
.
data
,
&
weights_iter_desc
.
data
,
&
bias_desc
.
data
,
&
dst_layer_desc
.
data
,
&
dst_iter_desc
.
data
),
"could not create an RNN forward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create an RNN forward primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
src_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone an src layer primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
src_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src iter primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_src_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
2
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
workspace_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
ldesc
;
const_mkldnn_primitive_desc_t
const_ldesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
workspace_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
ldesc
,
const_ldesc
),
"could not clone a workspace primitive descriptor"
);
adesc
.
reset
(
ldesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst last layer primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst last iteration primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
rnn_forward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src_layer
,
const
primitive
::
at
&
src_iter
,
const
primitive
::
at
&
weights_layer
,
const
primitive
::
at
&
weights_iter
,
const
primitive
::
at
&
bias
,
const
memory
&
dst_layer
,
const
memory
&
dst_iter
,
const
memory
&
workspace
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[
5
];
const_mkldnn_primitive_t
outputs
[
3
];
int
idx
=
0
;
inputs
[
idx
++
]
=
src_layer
.
data
;
if
(
!
is_null_memory
(
src_iter
.
data
.
primitive
))
inputs
[
idx
++
]
=
src_iter
.
data
;
inputs
[
idx
++
]
=
weights_layer
.
data
;
inputs
[
idx
++
]
=
weights_iter
.
data
;
if
(
!
is_null_memory
(
bias
.
data
.
primitive
))
inputs
[
idx
++
]
=
bias
.
data
;
idx
=
0
;
outputs
[
idx
++
]
=
dst_layer
.
get
();
if
(
!
is_null_memory
(
dst_iter
.
get
()))
outputs
[
idx
++
]
=
dst_iter
.
get
();
if
(
!
is_null_memory
(
workspace
.
get
()))
outputs
[
idx
++
]
=
workspace
.
get
();
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create an RNN forward primitive"
);
reset
(
result
);
}
};
struct
rnn_backward
:
public
primitive
{
struct
desc
{
mkldnn_rnn_desc_t
data
;
desc
(
prop_kind
aprop_kind
,
rnn_cell
::
desc
cell
,
const
rnn_direction
direction
,
const
memory
::
desc
&
src_layer_desc
,
const
memory
::
desc
&
src_iter_desc
,
const
memory
::
desc
&
weights_layer_desc
,
const
memory
::
desc
&
weights_iter_desc
,
const
memory
::
desc
&
bias_desc
,
const
memory
::
desc
&
dst_layer_desc
,
const
memory
::
desc
&
dst_iter_desc
,
const
memory
::
desc
&
diff_src_layer_desc
,
const
memory
::
desc
&
diff_src_iter_desc
,
const
memory
::
desc
&
diff_weights_layer_desc
,
const
memory
::
desc
&
diff_weights_iter_desc
,
const
memory
::
desc
&
diff_bias_desc
,
const
memory
::
desc
&
diff_dst_layer_desc
,
const
memory
::
desc
&
diff_dst_iter_desc
)
{
error
::
wrap_c_api
(
mkldnn_rnn_backward_desc_init
(
&
data
,
mkldnn
::
convert_to_c
(
aprop_kind
),
cell
,
mkldnn
::
convert_to_c
(
direction
),
&
src_layer_desc
.
data
,
&
src_iter_desc
.
data
,
&
weights_layer_desc
.
data
,
&
weights_iter_desc
.
data
,
&
bias_desc
.
data
,
&
dst_layer_desc
.
data
,
&
dst_iter_desc
.
data
,
&
diff_src_layer_desc
.
data
,
&
diff_src_iter_desc
.
data
,
&
diff_weights_layer_desc
.
data
,
&
diff_weights_iter_desc
.
data
,
&
diff_bias_desc
.
data
,
&
diff_dst_layer_desc
.
data
,
&
diff_dst_iter_desc
.
data
),
"could not create an RNN backward descriptor"
);
}
};
struct
primitive_desc
:
public
handle
<
mkldnn_primitive_desc_t
>
{
primitive_desc
(
const
desc
&
adesc
,
const
engine
&
aengine
)
{
mkldnn_primitive_desc_t
result
;
error
::
wrap_c_api
(
mkldnn_primitive_desc_create
(
&
result
,
&
adesc
.
data
,
aengine
.
get
(),
nullptr
),
"could not create an RNN backward primitive descriptor"
);
reset
(
result
);
}
memory
::
primitive_desc
src_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone an src layer primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
src_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
src_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src iter primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
weights_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
weights_pd
),
2
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst last layer primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
dst_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
dst_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst last iteration primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_src_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_src_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone an src_layer primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_src_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_src_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a src iter primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_weights_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_weights_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a weights primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_bias_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_weights_pd
),
2
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a bias primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_dst_layer_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst last layer primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
diff_dst_iter_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
cdesc
;
const_mkldnn_primitive_desc_t
const_cdesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
diff_dst_pd
),
1
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
cdesc
,
const_cdesc
),
"could not clone a dst last iteration primitive descriptor"
);
adesc
.
reset
(
cdesc
);
return
adesc
;
}
memory
::
primitive_desc
workspace_primitive_desc
()
const
{
memory
::
primitive_desc
adesc
;
mkldnn_primitive_desc_t
ldesc
;
const_mkldnn_primitive_desc_t
const_ldesc
=
mkldnn_primitive_desc_query_pd
(
get
(),
mkldnn
::
convert_to_c
(
workspace_pd
),
0
);
error
::
wrap_c_api
(
mkldnn_primitive_desc_clone
(
&
ldesc
,
const_ldesc
),
"could not clone a workspace primitive descriptor"
);
adesc
.
reset
(
ldesc
);
return
adesc
;
}
engine
get_engine
()
{
return
engine
::
query
(
*
this
);
}
};
// With last iteration (with and without input src_iter)
rnn_backward
(
const
primitive_desc
&
aprimitive_desc
,
const
primitive
::
at
&
src_layer
,
const
primitive
::
at
&
src_iter
,
const
primitive
::
at
&
weights_layer
,
const
primitive
::
at
&
weights_iter
,
const
primitive
::
at
&
bias
,
const
primitive
::
at
&
dst_layer
,
const
primitive
::
at
&
dst_iter
,
const
memory
&
diff_src_layer
,
const
memory
&
diff_src_iter
,
const
memory
&
diff_weights_layer
,
const
memory
&
diff_weights_iter
,
const
memory
&
diff_bias
,
const
primitive
::
at
&
diff_dst_layer
,
const
primitive
::
at
&
diff_dst_iter
,
const
primitive
::
at
&
workspace
)
{
mkldnn_primitive_t
result
;
mkldnn_primitive_at_t
inputs
[
10
];
const_mkldnn_primitive_t
outputs
[
5
];
int
idx
=
0
;
inputs
[
idx
]
=
src_layer
.
data
;
if
(
!
is_null_memory
(
src_iter
.
data
.
primitive
))
inputs
[
idx
++
]
=
src_iter
.
data
;
inputs
[
idx
++
]
=
weights_layer
.
data
;
inputs
[
idx
++
]
=
weights_iter
.
data
;
if
(
!
is_null_memory
(
bias
.
data
.
primitive
))
inputs
[
idx
++
]
=
bias
.
data
;
inputs
[
idx
]
=
dst_layer
.
data
;
if
(
!
is_null_memory
(
dst_iter
.
data
.
primitive
))
inputs
[
idx
++
]
=
dst_iter
.
data
;
inputs
[
idx
]
=
diff_dst_layer
.
data
;
if
(
!
is_null_memory
(
diff_dst_iter
.
data
.
primitive
))
inputs
[
idx
++
]
=
diff_dst_iter
.
data
;
inputs
[
idx
]
=
workspace
.
data
;
idx
=
0
;
outputs
[
idx
]
=
diff_src_layer
.
get
();
if
(
!
is_null_memory
(
diff_src_iter
.
get
()))
outputs
[
idx
++
]
=
diff_src_iter
.
get
();
outputs
[
idx
]
=
diff_weights_layer
.
get
();
outputs
[
idx
]
=
diff_weights_iter
.
get
();
if
(
!
is_null_memory
(
diff_bias
.
get
()))
outputs
[
idx
]
=
diff_bias
.
get
();
error
::
wrap_c_api
(
mkldnn_primitive_create
(
&
result
,
aprimitive_desc
.
get
(),
inputs
,
outputs
),
"could not create an RNN backward primitive"
);
reset
(
result
);
}
};
/// @}
/// @} Primitives
/// @addtogroup cpp_api_stream Stream
/// @{
#ifndef DOXYGEN_SHOULD_SKIP_THIS
template
<
>
struct
handle_traits
<
mkldnn_stream_t
>
{
static
constexpr
auto
destructor
=
&
mkldnn_stream_destroy
;
};
#endif
struct
stream
:
public
handle
<
mkldnn_stream_t
>
{
using
handle
::
handle
;
enum
kind
{
any
=
mkldnn_stream_kind_t
::
mkldnn_any_stream
,
eager
=
mkldnn_stream_kind_t
::
mkldnn_eager
,
lazy
=
mkldnn_stream_kind_t
::
mkldnn_lazy
};
static
mkldnn_stream_kind_t
convert_to_c
(
kind
akind
)
{
return
static_cast
<
mkldnn_stream_kind_t
>
(
akind
);
}
/// Constructs a stream.
stream
(
kind
akind
)
{
mkldnn_stream_t
astream
;
error
::
wrap_c_api
(
mkldnn_stream_create
(
&
astream
,
convert_to_c
(
akind
)),
"could not create a stream"
);
reset
(
astream
);
}
/// Submits a vector of primitives to a stream for computations.
///
/// @param primitives The vector of primitives to submit.
/// @returns The stream.
stream
&
submit
(
std
::
vector
<
primitive
>
primitives
)
{
// TODO: find a proper way to convert vector<primitive> to
// vector<mkldnn_primitive_t>
if
(
primitives
.
size
()
==
0
)
return
*
this
;
std
::
vector
<
mkldnn_primitive_t
>
c_api_primitives
;
c_api_primitives
.
reserve
(
primitives
.
size
());
auto
convert_to_c
=
[](
primitive
p
)
{
return
p
.
get
();
};
std
::
transform
(
primitives
.
begin
(),
primitives
.
end
(),
std
::
back_inserter
(
c_api_primitives
),
convert_to_c
);
mkldnn_primitive_t
c_api_error_primitive
;
error
::
wrap_c_api
(
mkldnn_stream_submit
(
get
(),
c_api_primitives
.
size
(),
&
c_api_primitives
[
0
],
&
c_api_error_primitive
),
"could not submit primitives to a stream"
,
&
c_api_error_primitive
);
return
*
this
;
}
/// Waits for all computations submitted to the stream to complete.
///
/// @param block Specifies whether the operation should wait indefinitely or
/// return
/// immediately.
/// @returns @c true if all computations completed.
/// @returns @c false if not all computations completed.
bool
wait
(
bool
block
=
true
)
{
mkldnn_primitive_t
c_api_error_primitive
;
mkldnn_status_t
status
=
mkldnn_stream_wait
(
get
(),
block
,
&
c_api_error_primitive
);
if
(
status
!=
mkldnn_success
&&
status
!=
mkldnn_try_again
)
error
::
wrap_c_api
(
status
,
"could not wait on a stream"
,
&
c_api_error_primitive
);
return
(
status
==
mkldnn_success
);
}
stream
&
rerun
()
{
mkldnn_primitive_t
c_api_error_primitive
;
error
::
wrap_c_api
(
mkldnn_stream_rerun
(
get
(),
&
c_api_error_primitive
),
"could not rerun a stream"
,
&
c_api_error_primitive
);
return
*
this
;
}
};
/// @}
/// @} C++ API
}
// namespace mkldnn
#endif
python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py
浏览文件 @
4d2a2e75
...
...
@@ -62,31 +62,31 @@ def train(use_cuda, train_program, save_dirname):
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
0.001
)
trainer
=
fluid
.
Trainer
(
train_func
=
train_program
,
place
=
place
,
optimizer
=
optimizer
)
train_func
=
train_program
,
place
=
place
,
optimizer
=
optimizer
,
parallel
=
True
)
def
event_handler
(
event
):
if
isinstance
(
event
,
fluid
.
EndEpochEvent
):
test_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
BATCH_SIZE
)
test_metrics
=
trainer
.
test
(
avg_cost
,
acc
=
trainer
.
test
(
reader
=
test_reader
,
feed_order
=
[
'img'
,
'label'
])
avg_cost_set
=
test_metrics
[
0
]
acc_set
=
test_metrics
[
1
]
# get test acc and loss
acc
=
numpy
.
array
(
acc_set
).
mean
()
avg_cost
=
numpy
.
array
(
avg_cost_set
).
mean
()
print
(
"avg_cost: %s"
%
avg_cost
)
print
(
"acc : %s"
%
acc
)
if
float
(
acc
)
>
0.2
:
# Smaller value to increase CI speed
if
acc
>
0.2
:
# Smaller value to increase CI speed
trainer
.
save_params
(
save_dirname
)
else
:
print
(
'BatchID {0}, Test Loss {1:0.2}, Acc {2:0.2}'
.
format
(
event
.
epoch
+
1
,
float
(
avg_cost
),
float
(
acc
)
))
if
math
.
isnan
(
float
(
avg_cost
)
):
event
.
epoch
+
1
,
avg_cost
,
acc
))
if
math
.
isnan
(
avg_cost
):
sys
.
exit
(
"got NaN loss, training failed."
)
elif
isinstance
(
event
,
fluid
.
EndStepEvent
):
print
(
"Step {0}, Epoch {1} Metrics {2}"
.
format
(
event
.
step
,
event
.
epoch
,
map
(
numpy
.
array
,
event
.
metrics
)))
train_reader
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
...
...
@@ -131,4 +131,4 @@ def main(use_cuda):
if
__name__
==
'__main__'
:
# for use_cuda in (False, True):
main
(
use_cuda
=
Fals
e
)
main
(
use_cuda
=
Tru
e
)
python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py
浏览文件 @
4d2a2e75
...
...
@@ -55,24 +55,18 @@ def train(use_cuda, train_program, save_dirname):
if
isinstance
(
event
,
fluid
.
EndEpochEvent
):
test_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
BATCH_SIZE
)
test_metrics
=
trainer
.
test
(
avg_cost
,
acc
=
trainer
.
test
(
reader
=
test_reader
,
feed_order
=
[
'img'
,
'label'
])
avg_cost_set
=
test_metrics
[
0
]
acc_set
=
test_metrics
[
1
]
# get test acc and loss
acc
=
numpy
.
array
(
acc_set
).
mean
()
avg_cost
=
numpy
.
array
(
avg_cost_set
).
mean
()
print
(
"avg_cost: %s"
%
avg_cost
)
print
(
"acc : %s"
%
acc
)
if
float
(
acc
)
>
0.2
:
# Smaller value to increase CI speed
if
acc
>
0.2
:
# Smaller value to increase CI speed
trainer
.
save_params
(
save_dirname
)
else
:
print
(
'BatchID {0}, Test Loss {1:0.2}, Acc {2:0.2}'
.
format
(
event
.
epoch
+
1
,
float
(
avg_cost
),
float
(
acc
)
))
if
math
.
isnan
(
float
(
avg_cost
)
):
event
.
epoch
+
1
,
avg_cost
,
acc
))
if
math
.
isnan
(
avg_cost
):
sys
.
exit
(
"got NaN loss, training failed."
)
train_reader
=
paddle
.
batch
(
...
...
python/paddle/fluid/trainer.py
浏览文件 @
4d2a2e75
...
...
@@ -20,6 +20,7 @@ import data_feeder
import
contextlib
import
io
import
unique_name
import
parallel_executor
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
import
optimizer
as
opt_module
...
...
@@ -48,12 +49,14 @@ class BeginStepEvent(object):
def
__init__
(
self
,
epoch_id
,
step_id
):
self
.
epoch
=
epoch_id
self
.
step
=
step_id
self
.
fetch_metrics
=
True
class
EndStepEvent
(
object
):
def
__init__
(
self
,
epoch_id
,
step_id
):
def
__init__
(
self
,
epoch_id
,
step_id
,
metrics
):
self
.
epoch
=
epoch_id
self
.
step
=
step_id
self
.
metrics
=
metrics
def
check_and_get_place
(
place
):
...
...
@@ -87,12 +90,17 @@ class Trainer(object):
Args:
train_func(callable): A function which will return loss. The loss must be a scalar.
infer_func(callable): A function which will return predict, used to save inference model
optimizer(optimizer.Optimizer): The optimizer should be an instance of Optimizer
place: The device place of this trainer.
"""
def
__init__
(
self
,
train_func
,
optimizer
,
param_path
=
None
,
place
=
None
):
def
__init__
(
self
,
train_func
,
optimizer
,
param_path
=
None
,
place
=
None
,
parallel
=
False
):
self
.
parallel
=
parallel
# 1. we need to generate a framework.Program by calling
# program_func. Reference: fluid.program_guard in
# test_word2vec.py
...
...
@@ -106,14 +114,14 @@ class Trainer(object):
with
framework
.
program_guard
(
self
.
train_program
,
self
.
startup_program
):
program_func_outs
=
train_func
()
self
.
t
est
_outputs
=
program_func_outs
if
isinstance
(
self
.
t
rain_func
_outputs
=
program_func_outs
if
isinstance
(
program_func_outs
,
list
)
else
[
program_func_outs
]
self
.
test_program
=
self
.
train_program
.
clone
()
if
not
isinstance
(
optimizer
,
opt_module
.
Optimizer
):
raise
TypeError
(
"The optimizer should be an instance of Optimizer"
)
# The fisrt element of program_func_outs is loss.
loss
=
self
.
t
est
_outputs
[
0
]
loss
=
self
.
t
rain_func
_outputs
[
0
]
optimize_ops
,
params_grads
=
optimizer
.
minimize
(
loss
)
self
.
place
=
check_and_get_place
(
place
)
...
...
@@ -202,12 +210,7 @@ class Trainer(object):
'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
)
def
train
(
self
,
num_epochs
,
event_handler
,
reader
,
feed_order
,
parallel
=
False
):
def
train
(
self
,
num_epochs
,
event_handler
,
reader
=
None
,
feed_order
=
None
):
"""
Train the model.
...
...
@@ -215,25 +218,24 @@ class Trainer(object):
num_epochs: The number of epoch. An epoch will process all data in reader
event_handler: The event handler. A function with type (ev:Event)->void
reader:
parallel: True if use multi-CPUs or multi-GPUs
feed_order: Feeding order of reader. None will following the defining
order in program
Returns:
"""
if
parallel
:
raise
NotImplementedError
(
"Parallel Executor version of trainer is not implemented"
)
training_role
=
os
.
getenv
(
"PADDLE_TRAINING_ROLE"
,
""
)
if
training_role
==
"PSERVER"
:
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
self
.
place
)
exe
.
run
()
return
self
.
_train_by_executor
(
num_epochs
,
event_handler
,
reader
,
feed_order
)
if
self
.
parallel
:
self
.
_train_by_parallel_executor
(
num_epochs
,
event_handler
,
reader
,
feed_order
)
else
:
self
.
_train_by_executor
(
num_epochs
,
event_handler
,
reader
,
feed_order
)
def
test
(
self
,
reader
,
feed_order
):
"""
...
...
@@ -245,7 +247,8 @@ class Trainer(object):
order in program
"""
return
self
.
_test_by_executor
(
reader
,
feed_order
,
self
.
test_outputs
)
return
self
.
_test_by_executor
(
reader
,
feed_order
,
self
.
train_func_outputs
)
def
save_params
(
self
,
param_path
):
# reference: save_persistables in io.py
...
...
@@ -279,13 +282,25 @@ class Trainer(object):
feeder
=
data_feeder
.
DataFeeder
(
feed_list
=
feed_var_list
,
place
=
self
.
place
)
exe
=
executor
.
Executor
(
self
.
place
)
for
epoch_id
in
range
(
num_epochs
):
event_handler
(
BeginEpochEvent
(
epoch_id
))
for
step_id
,
data
in
enumerate
(
reader
()):
event_handler
(
BeginStepEvent
(
epoch_id
,
step_id
))
exe
.
run
(
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[])
event_handler
(
EndStepEvent
(
epoch_id
,
step_id
))
event_handler
(
EndEpochEvent
(
epoch_id
))
reader
=
feeder
.
decorate_reader
(
reader
,
multi_devices
=
False
)
self
.
_train_by_any_executor
(
event_handler
,
exe
,
num_epochs
,
reader
)
def
_train_by_any_executor
(
self
,
event_handler
,
exe
,
num_epochs
,
reader
):
for
epoch_id
in
range
(
num_epochs
):
event_handler
(
BeginEpochEvent
(
epoch_id
))
for
step_id
,
data
in
enumerate
(
reader
()):
begin_event
=
BeginStepEvent
(
epoch_id
,
step_id
)
event_handler
(
begin_event
)
if
begin_event
.
fetch_metrics
:
metrics
=
exe
.
run
(
feed
=
data
,
fetch_list
=
[
var
.
name
for
var
in
self
.
train_func_outputs
])
else
:
metrics
=
exe
.
run
(
feed
=
data
,
fetch_list
=
[])
event_handler
(
EndStepEvent
(
epoch_id
,
step_id
,
metrics
))
event_handler
(
EndEpochEvent
(
epoch_id
))
def
_test_by_executor
(
self
,
reader
,
feed_order
,
fetch_list
):
with
executor
.
scope_guard
(
self
.
scope
):
...
...
@@ -304,6 +319,28 @@ class Trainer(object):
return
[
x
/
count
for
x
in
accumulated
]
def
_train_by_parallel_executor
(
self
,
num_epochs
,
event_handler
,
reader
,
feed_order
):
with
self
.
_prog_and_scope_guard
():
pe
=
self
.
_get_or_create_parallel_executor
()
feed_var_list
=
build_feed_var_list
(
self
.
train_program
,
feed_order
)
feeder
=
data_feeder
.
DataFeeder
(
feed_list
=
feed_var_list
,
place
=
self
.
place
)
reader
=
feeder
.
decorate_reader
(
reader
,
multi_devices
=
True
)
for
epoch_id
in
range
(
num_epochs
):
self
.
_train_by_any_executor
(
event_handler
,
pe
,
num_epochs
,
reader
)
def
_get_parallel_executor
(
self
):
return
getattr
(
self
,
'parallel_executor'
,
None
)
def
_get_or_create_parallel_executor
(
self
):
if
self
.
_get_parallel_executor
()
is
None
:
self
.
parallel_executor
=
parallel_executor
.
ParallelExecutor
(
use_cuda
=
isinstance
(
self
.
place
,
core
.
CUDAPlace
),
loss_name
=
self
.
train_func_outputs
[
0
].
name
)
return
self
.
_get_parallel_executor
()
def
build_feed_var_list
(
program
,
feed_order
):
if
not
isinstance
(
program
,
framework
.
Program
):
...
...
tools/timeline.py
浏览文件 @
4d2a2e75
...
...
@@ -171,7 +171,7 @@ if args.timeline_path:
profile_paths
=
profile_path
.
split
(
','
)
profile_dict
=
dict
()
if
len
(
profile_path
)
==
1
:
if
len
(
profile_path
s
)
==
1
:
with
open
(
profile_path
,
'r'
)
as
f
:
profile_s
=
f
.
read
()
profile_pb
=
profiler_pb2
.
Profile
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录