Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
428b2b9e
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
428b2b9e
编写于
9月 10, 2019
作者:
A
Adam
提交者:
Tao Luo
9月 10, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MKLDNN handler cleanup (#19713)
* MKLDNN handler cleanup * MKLDNN handler cleanup test=develop
上级
2c30e64b
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
115 addition
and
137 deletion
+115
-137
paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
...operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
+1
-1
paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
+7
-9
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+2
-2
paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc
paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc
+4
-4
paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc
paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc
+4
-4
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
+2
-2
paddle/fluid/platform/mkldnn_helper.h
paddle/fluid/platform/mkldnn_helper.h
+28
-0
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+67
-115
未找到文件。
paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
浏览文件 @
428b2b9e
...
...
@@ -136,7 +136,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
std
::
vector
<
memory
::
primitive_desc
>
srcs_pd
;
std
::
vector
<
float
>
scales
=
{
1.0
f
,
1.0
f
};
const
std
::
string
key
=
platform
::
MKLDNNHandler
::
GetHash
(
const
std
::
string
key
=
platform
::
GetHash
(
src_x_tz
,
ctx
.
op
().
Output
(
"Out"
)
+
std
::
to_string
(
x
->
format
())
+
std
::
to_string
(
y
->
format
()));
...
...
paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
浏览文件 @
428b2b9e
...
...
@@ -72,19 +72,17 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
std
::
string
key
;
key
.
reserve
(
platform
::
MKLDNNHandler
::
MaxKeyLength
);
for
(
size_t
i
=
0
;
i
<
multi_input
.
size
();
i
++
)
{
platform
::
MKLDNNHandler
::
AppendKeyDims
(
platform
::
AppendKeyDims
(
&
key
,
paddle
::
framework
::
vectorize
<
int
>
(
multi_input
[
i
]
->
dims
()));
}
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
concat_axis
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
ctx
.
op
().
Output
(
"Out"
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
dt
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
multi_input
[
0
]
->
format
()));
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
concat_axis
));
platform
::
AppendKey
(
&
key
,
ctx
.
op
().
Output
(
"Out"
));
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
dt
));
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
multi_input
[
0
]
->
format
()));
if
(
platform
::
get_cur_mkldnn_session_id
()
==
platform
::
kMKLDNNSessionID_Default
)
{
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
"-t:"
);
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
platform
::
MKLDNNHandler
::
ThreadIDasStr
());
platform
::
AppendKey
(
&
key
,
"-t:"
);
platform
::
AppendKey
(
&
key
,
platform
::
ThreadIDasStr
());
}
return
key
;
}
...
...
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
浏览文件 @
428b2b9e
...
...
@@ -417,7 +417,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// Get unique name for storing MKLDNN primitives
std
::
string
key
;
key
.
reserve
(
MaxKeyLength
);
platform
::
ConvMKLDNNHandler
::
Append
Key
(
platform
::
ConvMKLDNNHandler
::
Create
Key
(
&
key
,
src_tz
,
weights_tz
,
strides
,
paddings
,
dilations
,
groups
,
src_dt
,
input
->
format
(),
fuse_activation
,
fuse_residual_conn
,
ctx
.
op
().
Input
(
"Input"
)
+
ctx
.
op
().
Input
(
"Filter"
));
...
...
@@ -439,7 +439,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std
::
string
key_tid
=
""
;
if
(
platform
::
get_cur_mkldnn_session_id
()
==
platform
::
kMKLDNNSessionID_Default
)
{
key_tid
=
"-t:"
+
platform
::
MKLDNNHandler
::
ThreadIDasStr
();
key_tid
=
"-t:"
+
platform
::
ThreadIDasStr
();
}
auto
prim_key
=
key
+
key_tid
+
"@conv_p"
;
...
...
paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc
浏览文件 @
428b2b9e
...
...
@@ -36,10 +36,10 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
const
std
::
vector
<
int
>&
src_tz
,
const
float
scale_data
)
{
std
::
string
key
;
key
.
reserve
(
platform
::
MKLDNNHandler
::
MaxKeyLength
);
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
src_dt
));
platform
::
MKLDNNHandler
::
AppendKeyDims
(
&
key
,
src_tz
);
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
scale_data
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
ctx
.
op
().
Output
(
"Output"
));
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
src_dt
));
platform
::
AppendKeyDims
(
&
key
,
src_tz
);
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
scale_data
));
platform
::
AppendKey
(
&
key
,
ctx
.
op
().
Output
(
"Output"
));
return
key
;
}
...
...
paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc
浏览文件 @
428b2b9e
...
...
@@ -35,10 +35,10 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
const
bool
is_negative
)
{
std
::
string
key
;
key
.
reserve
(
platform
::
MKLDNNHandler
::
MaxKeyLength
);
platform
::
MKLDNNHandler
::
AppendKeyDims
(
&
key
,
src_tz
);
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
scale_data
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
is_negative
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
ctx
.
op
().
Output
(
"Output"
));
platform
::
AppendKeyDims
(
&
key
,
src_tz
);
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
scale_data
));
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
is_negative
));
platform
::
AppendKey
(
&
key
,
ctx
.
op
().
Output
(
"Output"
));
return
key
;
}
...
...
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
浏览文件 @
428b2b9e
...
...
@@ -205,7 +205,7 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
memory
::
dims
softmax_tz
=
{
src_tz
[
0
],
src_tz
[
1
]};
// Generate keys for storing/retriving primitives for this operator
const
std
::
string
key
=
platform
::
MKLDNNHandler
::
GetHash
(
softmax_tz
,
ctx
.
op
().
Output
(
"Out"
));
platform
::
GetHash
(
softmax_tz
,
ctx
.
op
().
Output
(
"Out"
));
SoftmaxMKLDNNHandler
<
T
>
handler
(
softmax_tz
,
MKLDNNMemoryFormat
::
nc
,
dev_ctx
,
mkldnn_engine
,
key
);
...
...
@@ -276,7 +276,7 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
// Currently only supports NC data format
// retrieve eltwise primitive desc from device context
const
std
::
string
key
=
platform
::
MKLDNNHandler
::
GetHash
(
softmax_tz
,
ctx
.
op
().
Input
(
"Out"
));
platform
::
GetHash
(
softmax_tz
,
ctx
.
op
().
Input
(
"Out"
));
const
std
::
string
key_softmax_pd
=
key
+
"@softmax_pd"
;
auto
softmax_pd
=
...
...
paddle/fluid/platform/mkldnn_helper.h
浏览文件 @
428b2b9e
...
...
@@ -179,5 +179,33 @@ inline MKLDNNMemoryFormat StringToMKLDNNFormat(std::string* format) {
}
}
inline
std
::
string
ThreadIDasStr
(
void
)
{
return
std
::
to_string
(
std
::
hash
<
std
::
thread
::
id
>
()(
std
::
this_thread
::
get_id
()));
}
inline
std
::
string
dims2str
(
const
mkldnn
::
memory
::
dims
&
operand_dims
)
{
std
::
string
dstr
=
""
;
for
(
size_t
i
=
0
;
i
<
operand_dims
.
size
();
++
i
)
{
dstr
+=
std
::
to_string
(
operand_dims
[
i
])
+
"-"
;
}
return
dstr
;
}
inline
void
AppendKey
(
std
::
string
*
key
,
const
std
::
string
&
s
)
{
key
->
append
(
s
);
}
inline
std
::
string
GetHash
(
const
mkldnn
::
memory
::
dims
&
operand_dims
,
const
std
::
string
&
suffix
)
{
return
dims2str
(
operand_dims
)
+
suffix
;
}
inline
void
AppendKeyDims
(
std
::
string
*
key
,
const
mkldnn
::
memory
::
dims
&
dims
)
{
for
(
unsigned
int
i
=
0
;
i
<
dims
.
size
();
i
++
)
{
AppendKey
(
key
,
std
::
to_string
(
dims
[
i
]));
}
}
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
428b2b9e
...
...
@@ -38,7 +38,7 @@ class MKLDNNHandler {
platform
::
kMKLDNNSessionID_Default
)
{
key_
=
key_common_
;
}
else
{
key_
=
key_common_
+
"-t:"
+
MKLDNNHandler
::
ThreadIDasStr
();
key_
=
key_common_
+
"-t:"
+
ThreadIDasStr
();
}
}
...
...
@@ -47,35 +47,19 @@ class MKLDNNHandler {
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_src_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireSecondSrcMemory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
)
{
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_src2_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireWeightsMemory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
,
user_function
custom_func
=
{})
{
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_weights_mem_p"
,
custom_func
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireBiasMemory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
)
{
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_bias_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDstMemory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
)
{
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_dst_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiff
Dst
Memory
(
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiff
Src
Memory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
)
{
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_diff_
dst
_mem_p"
);
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_diff_
src
_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiff
Src
Memory
(
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiff
Dst
Memory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
)
{
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_diff_
src
_mem_p"
);
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_diff_
dst
_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMemoryFromPrimitive
(
...
...
@@ -138,18 +122,6 @@ class MKLDNNHandler {
return
mem_p
;
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMemory
(
const
mkldnn
::
memory
::
primitive_desc
&
mpd
,
const
std
::
string
&
suffix
)
{
auto
local_key
=
key_
+
suffix
;
auto
mem_p
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx_
.
GetBlob
(
local_key
));
if
(
mem_p
==
nullptr
)
{
mem_p
=
std
::
make_shared
<
mkldnn
::
memory
>
(
mpd
);
dev_ctx_
.
SetBlob
(
local_key
,
mem_p
);
}
return
mem_p
;
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMemory
(
const
std
::
shared_ptr
<
mkldnn
::
memory
>&
user_memory_p
,
const
std
::
shared_ptr
<
mkldnn
::
memory
>&
target_memory_p
,
...
...
@@ -221,67 +193,6 @@ class MKLDNNHandler {
return
target_memory_p
;
}
static
std
::
string
ThreadIDasStr
(
void
)
{
return
std
::
to_string
(
std
::
hash
<
std
::
thread
::
id
>
()(
std
::
this_thread
::
get_id
()));
}
static
std
::
string
GetHash
(
mkldnn
::
memory
::
dims
&
operand_dims
,
// NOLINT
const
std
::
string
&
suffix
)
{
return
dims2str
(
operand_dims
)
+
suffix
;
}
static
void
AppendKey
(
std
::
string
*
key
,
const
mkldnn
::
memory
::
dims
&
input_dims
,
const
mkldnn
::
memory
::
dims
&
weights_dims
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
int
&
groups
,
const
mkldnn
::
memory
::
data_type
&
srcdt
,
const
MKLDNNMemoryFormat
&
format
,
const
std
::
string
&
fuse_activation
,
const
bool
&
residual
,
const
std
::
string
&
suffix
)
{
AppendKeyDims
(
key
,
input_dims
);
AppendKeyDims
(
key
,
weights_dims
);
AppendKeyVec
(
key
,
strides
);
AppendKeyVec
(
key
,
paddings
);
AppendKeyVec
(
key
,
dilations
);
AppendKey
(
key
,
std
::
to_string
(
groups
));
AppendKey
(
key
,
std
::
to_string
(
srcdt
));
AppendKey
(
key
,
std
::
to_string
(
format
));
AppendKey
(
key
,
fuse_activation
);
AppendKey
(
key
,
std
::
to_string
(
residual
));
AppendKey
(
key
,
suffix
);
}
static
void
AppendKeyDims
(
std
::
string
*
key
,
const
mkldnn
::
memory
::
dims
&
dims
)
{
for
(
unsigned
int
i
=
0
;
i
<
dims
.
size
();
i
++
)
{
AppendKey
(
key
,
std
::
to_string
(
dims
[
i
]));
}
}
static
void
AppendKeyVec
(
std
::
string
*
key
,
const
std
::
vector
<
int
>&
dims
)
{
for
(
unsigned
int
i
=
0
;
i
<
dims
.
size
();
i
++
)
{
AppendKey
(
key
,
std
::
to_string
(
dims
[
i
]));
}
}
static
void
AppendKey
(
std
::
string
*
key
,
const
std
::
string
&
s
)
{
key
->
append
(
s
);
}
protected:
static
std
::
string
dims2str
(
const
mkldnn
::
memory
::
dims
&
operand_dims
)
{
std
::
string
dstr
=
""
;
for
(
size_t
i
=
0
;
i
<
operand_dims
.
size
();
++
i
)
{
dstr
+=
std
::
to_string
(
operand_dims
[
i
])
+
"-"
;
}
return
dstr
;
}
protected:
const
MKLDNNDeviceContext
&
dev_ctx_
;
mkldnn
::
engine
engine_
;
...
...
@@ -324,6 +235,11 @@ class SumMKLDNNHandler : public MKLDNNHandler {
"@dst_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireSecondSrcMemory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
)
{
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_src2_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
sum
>
AcquireSum
(
std
::
shared_ptr
<
mkldnn
::
memory
>
dst_memory
,
std
::
vector
<
mkldnn
::
primitive
::
at
>*
inputs
)
{
...
...
@@ -458,12 +374,12 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
const
float
beta
,
const
std
::
string
&
suffix
)
{
std
::
string
key
;
key
.
reserve
(
platform
::
MKLDNNHandler
::
MaxKeyLength
);
platform
::
MKLDNNHandler
::
AppendKeyDims
(
&
key
,
input_dims
);
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
algorithm
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
fmt
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
alpha
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
beta
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
suffix
);
platform
::
AppendKeyDims
(
&
key
,
input_dims
);
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
algorithm
));
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
fmt
));
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
alpha
));
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
beta
));
platform
::
AppendKey
(
&
key
,
suffix
);
return
key
;
}
...
...
@@ -609,13 +525,13 @@ class LRNMKLDNNHandler : public MKLDNNHandler {
const
std
::
string
&
suffix
)
{
std
::
string
key
;
key
.
reserve
(
platform
::
MKLDNNHandler
::
MaxKeyLength
);
platform
::
MKLDNNHandler
::
AppendKeyDims
(
&
key
,
input_dims
);
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
n
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
alpha
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
beta
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
k
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
fmt
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
suffix
);
platform
::
AppendKeyDims
(
&
key
,
input_dims
);
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
n
));
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
alpha
));
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
beta
));
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
k
));
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
fmt
));
platform
::
AppendKey
(
&
key
,
suffix
);
return
key
;
}
...
...
@@ -803,14 +719,14 @@ class PoolingMKLDNNHandler : public MKLDNNHandler {
const
MKLDNNMemoryFormat
&
fmt
,
const
std
::
string
&
suffix
)
{
std
::
string
key
;
key
.
reserve
(
platform
::
MKLDNNHandler
::
MaxKeyLength
);
platform
::
MKLDNNHandler
::
AppendKeyDims
(
&
key
,
input_dims
);
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
pooling_type
);
platform
::
MKLDNNHandler
::
AppendKeyVec
(
&
key
,
ksize
);
platform
::
MKLDNNHandler
::
AppendKeyVec
(
&
key
,
strides
);
platform
::
MKLDNNHandler
::
AppendKeyVec
(
&
key
,
paddings
);
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
dt
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
std
::
to_string
(
fmt
));
platform
::
MKLDNNHandler
::
AppendKey
(
&
key
,
suffix
);
platform
::
AppendKeyDims
(
&
key
,
input_dims
);
platform
::
AppendKey
(
&
key
,
pooling_type
);
platform
::
AppendKeyDims
(
&
key
,
ksize
);
platform
::
AppendKeyDims
(
&
key
,
strides
);
platform
::
AppendKeyDims
(
&
key
,
paddings
);
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
dt
));
platform
::
AppendKey
(
&
key
,
std
::
to_string
(
fmt
));
platform
::
AppendKey
(
&
key
,
suffix
);
return
key
;
}
...
...
@@ -1160,6 +1076,17 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
pipeline
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireWeightsMemory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
,
user_function
custom_func
=
{})
{
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_weights_mem_p"
,
custom_func
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireBiasMemory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
)
{
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_bias_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireWeightsMemoryFromPrimitive
(
const
std
::
shared_ptr
<
mkldnn
::
memory
>
user_weights_memory_p
,
std
::
vector
<
mkldnn
::
primitive
>&
pipeline
,
// NOLINT
...
...
@@ -1368,6 +1295,31 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
suffix
;
}
static
void
CreateKey
(
std
::
string
*
key
,
const
mkldnn
::
memory
::
dims
&
input_dims
,
const
mkldnn
::
memory
::
dims
&
weights_dims
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
int
&
groups
,
const
mkldnn
::
memory
::
data_type
&
srcdt
,
const
MKLDNNMemoryFormat
&
format
,
const
std
::
string
&
fuse_activation
,
const
bool
&
residual
,
const
std
::
string
&
suffix
)
{
AppendKeyDims
(
key
,
input_dims
);
AppendKeyDims
(
key
,
weights_dims
);
AppendKeyDims
(
key
,
strides
);
AppendKeyDims
(
key
,
paddings
);
AppendKeyDims
(
key
,
dilations
);
AppendKey
(
key
,
std
::
to_string
(
groups
));
AppendKey
(
key
,
std
::
to_string
(
srcdt
));
AppendKey
(
key
,
std
::
to_string
(
format
));
AppendKey
(
key
,
fuse_activation
);
AppendKey
(
key
,
std
::
to_string
(
residual
));
AppendKey
(
key
,
suffix
);
}
private:
std
::
shared_ptr
<
typename
forward_t
::
primitive_desc
>
conv_pd_
;
std
::
shared_ptr
<
typename
backward_weights_t
::
primitive_desc
>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录