Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c981222b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c981222b
编写于
6年前
作者:
J
Jacek Czaja
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
- Conv MKLDNN grad op reuse of mkldnn primitives
上级
f0cd493c
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
212 addition
and
133 deletion
+212
-133
paddle/fluid/operators/conv_mkldnn_op.cc
paddle/fluid/operators/conv_mkldnn_op.cc
+212
-133
未找到文件。
paddle/fluid/operators/conv_mkldnn_op.cc
浏览文件 @
c981222b
...
...
@@ -18,9 +18,6 @@
namespace
paddle
{
namespace
operators
{
using
conv_bwd_data
=
mkldnn
::
convolution_backward_data
;
using
conv_bwd_weights
=
mkldnn
::
convolution_backward_weights
;
using
conv_fwd
=
mkldnn
::
convolution_forward
;
using
framework
::
DataLayout
;
using
mkldnn
::
memory
;
using
mkldnn
::
primitive
;
...
...
@@ -39,6 +36,72 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
conv_pd_
=
conv_pd
;
}
ConvMKLDNNHandler
(
std
::
shared_ptr
<
mkldnn
::
convolution_forward
::
primitive_desc
>
conv_pd
,
std
::
shared_ptr
<
mkldnn
::
convolution_backward_data
::
primitive_desc
>
conv_bwd_data_pd
,
std
::
shared_ptr
<
mkldnn
::
convolution_backward_weights
::
primitive_desc
>
conv_bwd_weights_pd
,
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
const
std
::
string
&
base_key
)
:
platform
::
MKLDNNHandler
(
dev_ctx
,
engine
,
base_key
),
conv_pd_
(
conv_pd
),
conv_bwd_weights_pd_
(
conv_bwd_weights_pd
),
conv_bwd_data_pd_
(
conv_bwd_data_pd
)
{
// If we are in Grad operatgor then update a key with BWD suffix to
// distinguish from FWD memory primitives
key_
+=
"-BWD"
;
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireSrcMemoryFromWeightsPrimitive
(
const
std
::
shared_ptr
<
mkldnn
::
memory
>
user_memory_p
,
std
::
vector
<
mkldnn
::
primitive
>&
pipeline
)
{
auto
src_pd
=
conv_bwd_weights_pd_
->
src_primitive_desc
();
auto
user_pd
=
user_memory_p
->
get_primitive_desc
();
return
this
->
AcquireMemory
(
src_pd
,
user_pd
,
user_memory_p
,
"@weights-src_mem_p"
,
pipeline
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffDstMemoryFromWeightsPrimitive
(
const
std
::
shared_ptr
<
mkldnn
::
memory
>
user_memory_p
,
std
::
vector
<
mkldnn
::
primitive
>&
pipeline
)
{
auto
diff_dst_pd
=
conv_bwd_weights_pd_
->
diff_dst_primitive_desc
();
auto
user_pd
=
user_memory_p
->
get_primitive_desc
();
return
this
->
AcquireMemory
(
diff_dst_pd
,
user_pd
,
user_memory_p
,
"@weights-diff_dst_mem_p"
,
pipeline
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffWeightsMemoryFromWeightsPrimitive
(
void
*
ptr
)
{
return
this
->
AcquireMemoryFromPrimitive
(
conv_bwd_weights_pd_
->
diff_weights_primitive_desc
(),
ptr
,
"@diff_weights_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffDstMemoryFromDataPrimitive
(
const
std
::
shared_ptr
<
mkldnn
::
memory
>
user_memory_p
,
std
::
vector
<
mkldnn
::
primitive
>&
pipeline
)
{
auto
diff_dst_pd
=
conv_bwd_data_pd_
->
diff_dst_primitive_desc
();
auto
user_pd
=
user_memory_p
->
get_primitive_desc
();
return
this
->
AcquireMemory
(
diff_dst_pd
,
user_pd
,
user_memory_p
,
"@data-diff_dst_mem_p"
,
pipeline
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireWeightsMemoryFromDataPrimitive
(
const
std
::
shared_ptr
<
mkldnn
::
memory
>
user_weights_memory_p
,
std
::
vector
<
mkldnn
::
primitive
>&
pipeline
)
{
auto
weights_pd
=
conv_bwd_data_pd_
->
weights_primitive_desc
();
auto
user_pd
=
user_weights_memory_p
->
get_primitive_desc
();
return
this
->
AcquireMemory
(
weights_pd
,
user_pd
,
user_weights_memory_p
,
"@data-weights_mem_p"
,
pipeline
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffSrcMemoryFromDataPrimitive
(
void
*
ptr
)
{
return
this
->
AcquireMemoryFromPrimitive
(
conv_bwd_data_pd_
->
diff_src_primitive_desc
(),
ptr
,
"@diff_src_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDstMemoryFromPrimitive
(
void
*
ptr
)
{
return
this
->
AcquireMemoryFromPrimitive
(
conv_pd_
->
dst_primitive_desc
(),
ptr
,
"@dst_mem_p"
);
...
...
@@ -68,7 +131,6 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
std
::
shared_ptr
<
mkldnn
::
memory
>
weights_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
dst_memory_p
)
{
auto
prim_key
=
key_
+
"@conv_p"
;
auto
prim_desc_key
=
key_
+
"@conv_pd"
;
auto
conv_p
=
std
::
static_pointer_cast
<
mkldnn
::
convolution_forward
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
PADDLE_ENFORCE
((
conv_p
!=
nullptr
)
||
(
is_reusing_
==
false
),
...
...
@@ -85,6 +147,54 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
return
conv_p
;
}
std
::
shared_ptr
<
mkldnn
::
convolution_backward_weights
>
AcquireConvolutionBackwardWeights
(
std
::
shared_ptr
<
mkldnn
::
memory
>
src_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
diff_dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
diff_weights_memory_p
)
{
auto
prim_key
=
key_
+
"@conv_bwd_weights_p"
;
auto
conv_bwd_weights_p
=
std
::
static_pointer_cast
<
mkldnn
::
convolution_backward_weights
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
PADDLE_ENFORCE
(
(
conv_bwd_weights_p
!=
nullptr
)
||
(
is_reusing_
==
false
),
"Fail to find convolution bwd weights primitive in device context"
);
if
(
conv_bwd_weights_p
==
nullptr
)
{
// create backward conv primitive for weights
conv_bwd_weights_p
=
std
::
make_shared
<
mkldnn
::
convolution_backward_weights
>
(
*
conv_bwd_weights_pd_
,
*
src_memory_p
,
*
diff_dst_memory_p
,
*
diff_weights_memory_p
);
dev_ctx_
.
SetBlob
(
prim_key
,
conv_bwd_weights_p
);
}
else
{
is_reusing_
=
true
;
}
return
conv_bwd_weights_p
;
}
std
::
shared_ptr
<
mkldnn
::
convolution_backward_data
>
AcquireConvolutionBackwardData
(
std
::
shared_ptr
<
mkldnn
::
memory
>
diff_dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
weights_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
diff_src_memory_p
)
{
auto
prim_key
=
key_
+
"@conv_bwd_data_p"
;
auto
conv_bwd_data_p
=
std
::
static_pointer_cast
<
mkldnn
::
convolution_backward_data
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
PADDLE_ENFORCE
(
(
conv_bwd_data_p
!=
nullptr
)
||
(
is_reusing_
==
false
),
"Fail to find convolution bwd data primitive in device context"
);
if
(
conv_bwd_data_p
==
nullptr
)
{
conv_bwd_data_p
=
std
::
make_shared
<
mkldnn
::
convolution_backward_data
>
(
*
conv_bwd_data_pd_
,
*
diff_dst_memory_p
,
*
weights_memory_p
,
*
diff_src_memory_p
);
dev_ctx_
.
SetBlob
(
prim_key
,
conv_bwd_data_p
);
}
else
{
is_reusing_
=
true
;
}
return
conv_bwd_data_p
;
}
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial
static
std
::
string
GetHash
(
memory
::
dims
&
input_dims
,
...
...
@@ -100,6 +210,10 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
private:
std
::
shared_ptr
<
mkldnn
::
convolution_forward
::
primitive_desc
>
conv_pd_
;
std
::
shared_ptr
<
mkldnn
::
convolution_backward_weights
::
primitive_desc
>
conv_bwd_weights_pd_
;
std
::
shared_ptr
<
mkldnn
::
convolution_backward_data
::
primitive_desc
>
conv_bwd_data_pd_
;
};
template
<
typename
T
>
...
...
@@ -174,8 +288,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
memory
::
format
::
any
);
// create a conv primitive descriptor and save it for usage in backward
std
::
shared_ptr
<
conv_fwd
::
primitive_desc
>
conv_pd
=
ConvFwdPrimitiveDesc
(
src_md
,
weights_md
,
dst_md
,
strides
,
paddings
,
mkldnn_engine
);
std
::
shared_ptr
<
mkldnn
::
convolution_forward
::
primitive_desc
>
conv_pd
=
ConvFwdPrimitiveDesc
(
src_md
,
weights_md
,
dst_md
,
strides
,
paddings
,
mkldnn_engine
);
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx
.
SetBlob
(
key_conv_pd
,
conv_pd
);
...
...
@@ -208,21 +323,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
private:
std
::
unique_ptr
<
conv_fwd
::
primitive_desc
>
ConvFwdPrimitiveDesc
(
const
memory
::
desc
&
src
,
const
memory
::
desc
&
weights
,
std
::
unique_ptr
<
mkldnn
::
convolution_forward
::
primitive_desc
>
ConvFwdPrimitiveDesc
(
const
memory
::
desc
&
src
,
const
memory
::
desc
&
weights
,
const
memory
::
desc
&
dst
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
mkldnn
::
engine
&
engine
)
const
{
const
std
::
vector
<
int
>&
paddings
,
const
mkldnn
::
engine
&
engine
)
const
{
memory
::
dims
stride_dims
=
{
strides
[
0
],
strides
[
1
]};
memory
::
dims
padding_dims
=
{
paddings
[
0
],
paddings
[
1
]};
auto
conv_desc
=
conv_fwd
::
desc
(
mkldnn
::
prop_kind
::
forward
,
mkldnn
::
convolution_direct
,
src
,
weights
,
dst
,
stride
_dims
,
padding_dims
,
padding_dims
,
mkldnn
::
padding_kind
::
zero
);
auto
conv_desc
=
mkldnn
::
convolution_forward
::
desc
(
mkldnn
::
prop_kind
::
forward
,
mkldnn
::
convolution_direct
,
src
,
weights
,
dst
,
stride_dims
,
padding
_dims
,
padding_dims
,
mkldnn
::
padding_kind
::
zero
);
auto
p_conv_pd
=
new
conv_fwd
::
primitive_desc
(
conv_desc
,
engine
);
auto
p_conv_pd
=
new
mkldnn
::
convolution_forward
::
primitive_desc
(
conv_desc
,
engine
);
return
std
::
unique_ptr
<
conv_fwd
::
primitive_desc
>
(
p_conv_pd
);
return
std
::
unique_ptr
<
mkldnn
::
convolution_forward
::
primitive_desc
>
(
p_conv_pd
);
}
};
...
...
@@ -290,147 +408,108 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
dilations
,
groups
,
ctx
.
op
().
Input
(
"Output"
));
const
std
::
string
key_conv_pd
=
key
+
"@conv_pd"
;
std
::
vector
<
primitive
>
pipeline
;
// create mkldnn memory from input tensors (input/weights/output_grad)
auto
user_src_memory
=
memory
(
{{{
src_tz
},
memory
::
data_type
::
f32
,
input
->
format
()},
mkldnn_engine
},
to_void_cast
(
input_data
));
auto
user_weights_memory
=
memory
({{{
weights_tz
},
memory
::
data_type
::
f32
,
filter
->
format
()},
mkldnn_engine
},
to_void_cast
(
filter_data
));
auto
user_diff_dst_memory
=
memory
({{{
dst_tz
},
memory
::
data_type
::
f32
,
output_grad
->
format
()},
mkldnn_engine
},
to_void_cast
(
output_grad_data
));
// Create user memory descriptors
auto
user_src_md
=
platform
::
MKLDNNMemDesc
(
{
src_tz
},
platform
::
MKLDNNGetDataType
<
T
>
(),
input
->
format
());
auto
user_weights_md
=
platform
::
MKLDNNMemDesc
(
{
weights_tz
},
platform
::
MKLDNNGetDataType
<
T
>
(),
filter
->
format
());
auto
user_diff_dst_md
=
platform
::
MKLDNNMemDesc
(
{
dst_tz
},
platform
::
MKLDNNGetDataType
<
T
>
(),
output_grad
->
format
());
/* create memory descriptor for conv backward without specified format
* ('any') which lets a primitive (conv backward in this case) choose
* the memory format preferred for best performance
*/
auto
src_md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
memory
::
data_type
::
f32
,
memory
::
format
::
any
);
auto
diff_src_md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
memory
::
data_type
::
f32
,
memory
::
format
::
any
);
auto
src_md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
memory
::
format
::
any
);
auto
diff_src_md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
memory
::
format
::
any
);
auto
weights_md
=
platform
::
MKLDNNMemDesc
(
weights_tz
,
memory
::
data_type
::
f32
,
memory
::
format
::
any
);
weights_tz
,
platform
::
MKLDNNGetDataType
<
T
>
()
,
memory
::
format
::
any
);
auto
diff_weights_md
=
platform
::
MKLDNNMemDesc
(
weights_tz
,
memory
::
data_type
::
f32
,
memory
::
format
::
any
);
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
dst_tz
,
memory
::
data_type
::
f32
,
memory
::
format
::
any
);
weights_tz
,
platform
::
MKLDNNGetDataType
<
T
>
()
,
memory
::
format
::
any
);
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
memory
::
format
::
any
);
// Retrieve conv_pd from device context
auto
conv_pd
=
std
::
static_pointer_cast
<
conv_fwd
::
primitive_desc
>
(
auto
conv_pd
=
std
::
static_pointer_cast
<
mkldnn
::
convolution_forward
::
primitive_desc
>
(
dev_ctx
.
GetBlob
(
key_conv_pd
));
PADDLE_ENFORCE
(
conv_pd
!=
nullptr
,
"Fail to find conv_pd in device context"
);
// create backward conv primitive for weights
if
(
filter_grad
)
{
// create backward convolution primitive descriptor
auto
conv_bwd_weights_desc
=
conv_bwd_weights
::
desc
(
// create backward convolution weights primitive descriptor
auto
conv_bwd_weights_desc
=
mkldnn
::
convolution_backward_weights
::
desc
(
mkldnn
::
convolution_direct
,
src_md
,
diff_weights_md
,
diff_dst_md
,
strides
,
paddings
,
paddings
,
mkldnn
::
padding_kind
::
zero
);
auto
conv_bwd_weights_pd
=
conv_bwd_weights
::
primitive_desc
(
auto
conv_bwd_weights_pd
=
std
::
make_shared
<
mkldnn
::
convolution_backward_weights
::
primitive_desc
>
(
conv_bwd_weights_desc
,
mkldnn_engine
,
*
conv_pd
);
// create reorder primitive if the input format is not the preferred one
auto
src_memory
=
user_src_memory
;
primitive
reorder_src
;
bool
is_src_reordered
=
false
;
if
(
memory
::
primitive_desc
(
conv_bwd_weights_pd
.
src_primitive_desc
())
!=
user_src_memory
.
get_primitive_desc
())
{
src_memory
=
memory
(
conv_bwd_weights_pd
.
src_primitive_desc
());
reorder_src
=
reorder
(
user_src_memory
,
src_memory
);
is_src_reordered
=
true
;
}
// create backward convolution data primitive descriptor
auto
conv_bwd_data_desc
=
mkldnn
::
convolution_backward_data
::
desc
(
mkldnn
::
convolution_direct
,
diff_src_md
,
weights_md
,
diff_dst_md
,
strides
,
paddings
,
paddings
,
mkldnn
::
padding_kind
::
zero
);
auto
conv_bwd_data_pd
=
std
::
make_shared
<
mkldnn
::
convolution_backward_data
::
primitive_desc
>
(
conv_bwd_data_desc
,
mkldnn_engine
,
*
conv_pd
);
auto
diff_dst_memory_4filter
=
user_diff_dst_memory
;
primitive
reorder_diff_dst_4filter
;
bool
is_diff_dst_reordered_4filter
=
false
;
if
(
memory
::
primitive_desc
(
conv_bwd_weights_pd
.
diff_dst_primitive_desc
())
!=
user_diff_dst_memory
.
get_primitive_desc
())
{
diff_dst_memory_4filter
=
memory
(
conv_bwd_weights_pd
.
diff_dst_primitive_desc
());
reorder_diff_dst_4filter
=
reorder
(
user_diff_dst_memory
,
diff_dst_memory_4filter
);
is_diff_dst_reordered_4filter
=
true
;
}
ConvMKLDNNHandler
handler
(
conv_pd
,
conv_bwd_data_pd
,
conv_bwd_weights_pd
,
dev_ctx
,
mkldnn_engine
,
key
);
// create mkldnn memory for output (i.e. diff weights)
auto
diff_weights_memory
=
memory
(
conv_bwd_weights_pd
.
diff_weights_primitive_desc
(),
reinterpret_cast
<
void
*>
(
filter_grad_data
));
// create mkldnn memory from input tensors (data/weights)
auto
user_src_memory_p
=
handler
.
AcquireSrcMemory
(
user_src_md
,
to_void_cast
<
T
>
(
input_data
));
auto
user_weights_memory_p
=
handler
.
AcquireWeightsMemory
(
user_weights_md
,
to_void_cast
<
T
>
(
filter_data
));
auto
user_diff_dst_memory_p
=
handler
.
AcquireDiffDstMemory
(
user_diff_dst_md
,
to_void_cast
<
T
>
(
output_grad_data
));
// create backward conv primitive for weights
auto
conv_bwd_weights_prim
=
conv_bwd_weights
(
conv_bwd_weights_pd
,
src_memory
,
diff_dst_memory_4filter
,
diff_weights_memory
);
if
(
filter_grad
)
{
auto
src_memory_p
=
handler
.
AcquireSrcMemoryFromWeightsPrimitive
(
user_src_memory_p
,
pipeline
);
// push primitive and execute it
std
::
vector
<
primitive
>
pipeline
;
if
(
is_src_reordered
)
pipeline
.
push_back
(
reorder_src
);
if
(
is_diff_dst_reordered_4filter
)
pipeline
.
push_back
(
reorder_diff_dst_4filter
);
pipeline
.
push_back
(
conv_bwd_weights_prim
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
auto
diff_dst_memory_4filter_p
=
handler
.
AcquireDiffDstMemoryFromWeightsPrimitive
(
user_diff_dst_memory_p
,
pipeline
);
auto
diff_weights_memory_p
=
handler
.
AcquireDiffWeightsMemoryFromWeightsPrimitive
(
reinterpret_cast
<
void
*>
(
filter_grad_data
));
auto
conv_bwd_weights_p
=
handler
.
AcquireConvolutionBackwardWeights
(
src_memory_p
,
diff_dst_memory_4filter_p
,
diff_weights_memory_p
);
// push primitive to stream and wait until it's executed
pipeline
.
push_back
(
*
conv_bwd_weights_p
);
filter_grad
->
set_layout
(
DataLayout
::
kMKLDNN
);
filter_grad
->
set_format
(
GetMKLDNNFormat
(
diff_weights_memory
));
filter_grad
->
set_format
(
GetMKLDNNFormat
(
*
diff_weights_memory_p
));
}
if
(
input_grad
)
{
// create backward convolution primitive descriptor
auto
conv_bwd_data_desc
=
conv_bwd_data
::
desc
(
mkldnn
::
convolution_direct
,
diff_src_md
,
weights_md
,
diff_dst_md
,
strides
,
paddings
,
paddings
,
mkldnn
::
padding_kind
::
zero
);
auto
conv_bwd_data_pd
=
conv_bwd_data
::
primitive_desc
(
conv_bwd_data_desc
,
mkldnn_engine
,
*
conv_pd
);
// create reorder primitive if the input format is not the preferred one
auto
weights_memory
=
user_weights_memory
;
primitive
reorder_weights
;
bool
is_weights_reordered
=
false
;
if
(
memory
::
primitive_desc
(
conv_bwd_data_pd
.
weights_primitive_desc
())
!=
user_weights_memory
.
get_primitive_desc
())
{
weights_memory
=
memory
(
conv_bwd_data_pd
.
weights_primitive_desc
());
reorder_weights
=
reorder
(
user_weights_memory
,
weights_memory
);
is_weights_reordered
=
true
;
}
auto
weights_memory_p
=
handler
.
AcquireWeightsMemoryFromDataPrimitive
(
user_weights_memory_p
,
pipeline
);
auto
diff_dst_memory_4data
=
user_diff_dst_memory
;
primitive
reorder_diff_dst_4data
;
bool
is_diff_dst_reordered_4data
=
false
;
if
(
memory
::
primitive_desc
(
conv_bwd_data_pd
.
diff_dst_primitive_desc
())
!=
user_diff_dst_memory
.
get_primitive_desc
())
{
diff_dst_memory_4data
=
memory
(
conv_bwd_data_pd
.
diff_dst_primitive_desc
());
reorder_diff_dst_4data
=
reorder
(
user_diff_dst_memory
,
diff_dst_memory_4data
);
is_diff_dst_reordered_4data
=
true
;
}
auto
diff_dst_memory_4data_p
=
handler
.
AcquireDiffDstMemoryFromDataPrimitive
(
user_diff_dst_memory_p
,
pipeline
);
// create mkldnn memory for output (i.e. diff src)
auto
diff_src_memory
=
memory
(
conv_bwd_data_pd
.
diff_src_primitive_desc
(),
auto
diff_src_memory_p
=
handler
.
AcquireDiffSrcMemoryFromDataPrimitive
(
reinterpret_cast
<
void
*>
(
input_grad_data
));
// create backward conv primitive for data
auto
conv_bwd_data_prim
=
conv_bwd_data
(
conv_bwd_data_pd
,
diff_dst_memory_4data
,
weights_memory
,
diff_src_memory
);
auto
conv_bwd_data_p
=
handler
.
AcquireConvolutionBackwardData
(
diff_dst_memory_4data_p
,
weights_memory_p
,
diff_src_memory_p
);
// push primitive and execute it
std
::
vector
<
primitive
>
pipeline
;
if
(
is_weights_reordered
)
pipeline
.
push_back
(
reorder_weights
);
if
(
is_diff_dst_reordered_4data
)
pipeline
.
push_back
(
reorder_diff_dst_4data
);
pipeline
.
push_back
(
conv_bwd_data_prim
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
pipeline
.
push_back
(
*
conv_bwd_data_p
);
input_grad
->
set_layout
(
DataLayout
::
kMKLDNN
);
input_grad
->
set_format
(
GetMKLDNNFormat
(
diff_src_memory
));
input_grad
->
set_format
(
GetMKLDNNFormat
(
*
diff_src_memory_p
));
}
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
}
// Compute()
};
...
...
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
新手
引导
客服
返回
顶部