Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
b2727020
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b2727020
编写于
8月 04, 2022
作者:
J
jakpiase
提交者:
GitHub
8月 04, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
added conv and conv_tranpose support for md (#44677)
上级
6506668e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
29 addition
and
84 deletion
+29
-84
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+27
-65
paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
+2
-19
未找到文件。
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
浏览文件 @
b2727020
...
...
@@ -24,13 +24,13 @@ namespace paddle {
namespace
operators
{
namespace
{
inline
MKLDNNMemoryFormat
GetWeightsFormat
(
const
MKLDNNMemoryFormat
format
,
const
int
groups
,
inline
MKLDNNMemoryFormat
GetWeightsFormat
(
const
int
groups
,
const
bool
is_conv3d
)
{
if
(
is_conv3d
)
{
return
(
groups
==
1
)
?
format
:
MKLDNNMemoryFormat
::
goidhw
;
return
(
groups
==
1
)
?
MKLDNNMemoryFormat
::
oidhw
:
MKLDNNMemoryFormat
::
goidhw
;
}
else
{
return
(
groups
==
1
)
?
format
:
MKLDNNMemoryFormat
::
goihw
;
return
(
groups
==
1
)
?
MKLDNNMemoryFormat
::
oihw
:
MKLDNNMemoryFormat
::
goihw
;
}
}
...
...
@@ -98,10 +98,6 @@ class ConvMKLDNNHandlerT
"The input tensor's layout should be %d, but got %d."
,
framework
::
DataLayout
::
kMKLDNN
,
input
->
layout
()));
PADDLE_ENFORCE_NE
(
input
->
format
(),
MKLDNNMemoryFormat
::
undef
,
platform
::
errors
::
InvalidArgument
(
"Wrong format set for Input tensor"
));
PADDLE_ENFORCE_EQ
(
filter
->
layout
(),
...
...
@@ -110,10 +106,6 @@ class ConvMKLDNNHandlerT
"The Filter tensor's layout should be %d, but got %d."
,
framework
::
DataLayout
::
kMKLDNN
,
filter
->
layout
()));
PADDLE_ENFORCE_NE
(
filter
->
format
(),
MKLDNNMemoryFormat
::
undef
,
platform
::
errors
::
InvalidArgument
(
"Wrong format set for Filter tensor"
));
PADDLE_ENFORCE_GE
(
input
->
dims
().
size
(),
...
...
@@ -153,10 +145,6 @@ class ConvMKLDNNHandlerT
"The Bias tensor's layout should be %d, but got %d."
,
framework
::
DataLayout
::
kMKLDNN
,
bias
->
layout
()));
PADDLE_ENFORCE_NE
(
bias
->
format
(),
MKLDNNMemoryFormat
::
undef
,
platform
::
errors
::
InvalidArgument
(
"Got wrong format for Bias tensor."
));
PADDLE_ENFORCE_EQ
(
bias
->
dims
().
size
(),
1
,
...
...
@@ -307,10 +295,6 @@ class ConvMKLDNNHandlerT
"The input tensor's layout should be %d, but got %d."
,
framework
::
DataLayout
::
kMKLDNN
,
in
->
layout
()));
PADDLE_ENFORCE_NE
(
in
->
format
(),
MKLDNNMemoryFormat
::
undef
,
platform
::
errors
::
InvalidArgument
(
"Got wrong format for Input tensor."
));
PADDLE_ENFORCE_EQ
(
filter
->
layout
(),
...
...
@@ -319,10 +303,6 @@ class ConvMKLDNNHandlerT
"The filter tensor's layout should be %d, but got %d."
,
framework
::
DataLayout
::
kMKLDNN
,
filter
->
layout
()));
PADDLE_ENFORCE_NE
(
filter
->
format
(),
MKLDNNMemoryFormat
::
undef
,
platform
::
errors
::
InvalidArgument
(
"Got wrong format for Filter tensor."
));
PADDLE_ENFORCE_EQ
(
out_grad
->
layout
(),
...
...
@@ -331,10 +311,6 @@ class ConvMKLDNNHandlerT
"The output_grad tensor's layout should be %d, but got %d."
,
framework
::
DataLayout
::
kMKLDNN
,
out_grad
->
layout
()));
PADDLE_ENFORCE_NE
(
out_grad
->
format
(),
MKLDNNMemoryFormat
::
undef
,
platform
::
errors
::
InvalidArgument
(
"Wrong format set for output_grad tensor"
));
PADDLE_ENFORCE_EQ
(
ctx
.
Attr
<
bool
>
(
"is_test"
),
...
...
@@ -596,10 +572,10 @@ class ConvMKLDNNHandlerT
auto
weights_tz
=
phi
::
vectorize
(
filter
->
dims
());
platform
::
GetGroupConvWeightsTz
(
weights_tz
,
groups
);
auto
user_src_md
=
platform
::
MKLDNNMemDesc
(
weights_tz
,
platform
::
MKLDNNGetDataType
<
K
>
(),
GetWeightsFormat
(
filter
->
format
(),
groups
,
is_conv3d
));
auto
user_src_md
=
platform
::
MKLDNNMemDesc
(
weights_tz
,
platform
::
MKLDNNGetDataType
<
K
>
(),
GetWeightsFormat
(
groups
,
is_conv3d
));
return
this
->
AcquireMemoryWithReorder
(
user_src_md
,
...
...
@@ -660,12 +636,11 @@ class ConvMKLDNNHandlerT
auto
user_mem_p
=
this
->
AcquireMemory
(
user_key_suffix
);
if
(
!
user_mem_p
)
{
auto
user_mem_md
=
platform
::
MKLDNNMemDesc
(
phi
::
vectorize
(
in_mem
->
dims
()),
platform
::
MKLDNNGetDataType
<
T
>
(),
in_mem
->
format
());
return
this
->
AcquireMemoryWithReorder
(
user_mem_md
,
mem_md
,
platform
::
to_void_cast
<
T
>
(
in_mem_data
),
key_mem
);
in_mem
->
mem_desc
(),
mem_md
,
platform
::
to_void_cast
<
T
>
(
in_mem_data
),
key_mem
);
}
else
{
const
std
::
string
target_key_suffix
{
key_mem_target
};
const
auto
target_mem_p
=
this
->
AcquireMemory
(
target_key_suffix
);
...
...
@@ -694,10 +669,10 @@ class ConvMKLDNNHandlerT
auto
weights_tz
=
phi
::
vectorize
(
filter
->
dims
());
platform
::
GetGroupConvWeightsTz
(
weights_tz
,
groups
);
auto
user_src_md
=
platform
::
MKLDNNMemDesc
(
weights_tz
,
platform
::
MKLDNNGetDataType
<
K
>
(),
GetWeightsFormat
(
filter
->
format
(),
groups
,
is_conv3d
));
auto
user_src_md
=
platform
::
MKLDNNMemDesc
(
weights_tz
,
platform
::
MKLDNNGetDataType
<
K
>
(),
GetWeightsFormat
(
groups
,
is_conv3d
));
return
this
->
AcquireMemoryWithReorder
(
user_src_md
,
...
...
@@ -713,10 +688,10 @@ class ConvMKLDNNHandlerT
auto
weights_tz
=
phi
::
vectorize
(
filter
->
dims
());
platform
::
GetGroupConvWeightsTz
(
weights_tz
,
groups
);
auto
user_src_md
=
platform
::
MKLDNNMemDesc
(
weights_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
GetWeightsFormat
(
filter
->
format
(),
groups
,
is_conv3d
));
auto
user_src_md
=
platform
::
MKLDNNMemDesc
(
weights_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
GetWeightsFormat
(
groups
,
is_conv3d
));
return
this
->
AcquireMemoryWithReorder
(
user_src_md
,
...
...
@@ -747,13 +722,9 @@ class ConvMKLDNNHandlerT
LOG
(
ERROR
)
<<
"Bias should be of type int32 but is "
<<
bias
->
dtype
();
}
const
K_Bias
*
bias_data
=
bias
->
data
<
K_Bias
>
();
auto
user_bias_md
=
platform
::
MKLDNNMemDesc
(
phi
::
vectorize
(
bias
->
dims
()),
platform
::
MKLDNNGetDataType
<
K_Bias
>
(),
MKLDNNMemoryFormat
::
x
);
return
this
->
AcquireMemoryWithReorder
(
user_bias_md
,
bias
->
mem_desc
()
,
this
->
fwd_pd_
->
bias_desc
(),
platform
::
to_void_cast
<
K_Bias
>
(
bias_data
),
"@bias_mem_p"
,
...
...
@@ -776,22 +747,16 @@ class ConvMKLDNNHandlerT
residual_mem_p
->
set_data_handle
(
residual_data
);
return
residual_mem_p
;
}
else
{
auto
user_residual_md
=
platform
::
MKLDNNMemDesc
(
phi
::
vectorize
(
residual_param
->
dims
()),
framework
::
ToMKLDNNDataType
(
framework
::
TransToProtoVarType
(
residual_param
->
dtype
())),
residual_param
->
format
());
return
this
->
AcquireMemoryFromPrimitive
(
user_residual_md
,
residual_data
,
"@user_residual_data_mem_p"
);
return
this
->
AcquireMemoryFromPrimitive
(
residual_param
->
mem_desc
(),
residual_data
,
"@user_residual_data_mem_p"
);
}
}
std
::
shared_ptr
<
dnnl
::
memory
>
AcquireDstMemoryWithResidual
(
framework
::
Tensor
*
output
,
const
framework
::
Tensor
*
residual_param
)
{
std
::
shared_ptr
<
dnnl
::
memory
>
dst_memory_p
;
if
(
residual_param
->
format
()
!=
platform
::
GetMKLDNNFormat
(
this
->
fwd_pd_
->
dst_desc
()))
{
if
(
residual_param
->
mem_desc
()
!=
this
->
fwd_pd_
->
dst_desc
())
{
auto
residual_memory_p
=
this
->
AcquireResidualMemory
(
residual_param
);
dst_memory_p
=
this
->
template
AcquireDstMemory
<
T_out
>(
output
);
this
->
AcquireReorder
(
residual_memory_p
,
dst_memory_p
);
...
...
@@ -903,8 +868,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
conv_p
->
execute
(
astream
,
args
);
astream
.
wait
();
output
->
set_layout
(
framework
::
DataLayout
::
kMKLDNN
);
output
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
dst_memory_p
));
output
->
set_mem_desc
(
dst_memory_p
->
get_desc
());
}
template
<
typename
T_out
>
...
...
@@ -1018,8 +982,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
output
->
mutable_data
<
uint8_t
>
(
ctx
.
GetPlace
());
}
output
->
set_layout
(
framework
::
DataLayout
::
kMKLDNN
);
output
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
dst_memory_p
));
output
->
set_mem_desc
(
dst_memory_p
->
get_desc
());
}
};
...
...
@@ -1078,7 +1041,6 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
auto
conv_bwd_weights_p
=
handler
.
AcquireBackwardWeightsPrimitive
();
// TODO(grygielski) why no bias_diff?
conv_bwd_weights_p
->
execute
(
astream
,
{{
DNNL_ARG_SRC
,
*
src_memory_p
},
...
...
paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
浏览文件 @
b2727020
...
...
@@ -59,11 +59,6 @@ class ConvTransposeMKLDNNHandlerT
DataLayout
::
kMKLDNN
,
platform
::
errors
::
InvalidArgument
(
"Got wrong layout = %d for Input tensor."
,
input
->
layout
()));
PADDLE_ENFORCE_NE
(
input
->
format
(),
MKLDNNMemoryFormat
::
undef
,
platform
::
errors
::
InvalidArgument
(
"Got wrong format for Input tensor. The input "
"format is undefined."
));
PADDLE_ENFORCE_EQ
(
filter
->
layout
(),
...
...
@@ -72,10 +67,6 @@ class ConvTransposeMKLDNNHandlerT
"The filter tensor's layout should be %d, but got %d."
,
DataLayout
::
kMKLDNN
,
filter
->
layout
()));
PADDLE_ENFORCE_NE
(
filter
->
format
(),
MKLDNNMemoryFormat
::
undef
,
platform
::
errors
::
InvalidArgument
(
"Got wrong formats for Filter tensor."
));
PADDLE_ENFORCE_EQ
(
input
->
dims
().
size
(),
...
...
@@ -98,10 +89,6 @@ class ConvTransposeMKLDNNHandlerT
"The bias tensor's laytout should be %d, but got %d."
,
DataLayout
::
kMKLDNN
,
bias
->
layout
()));
PADDLE_ENFORCE_NE
(
bias
->
format
(),
MKLDNNMemoryFormat
::
undef
,
platform
::
errors
::
InvalidArgument
(
"Got wrong format for Bias tensor."
));
PADDLE_ENFORCE_EQ
(
bias
->
dims
().
size
(),
...
...
@@ -233,11 +220,8 @@ class ConvTransposeMKLDNNHandlerT
std
::
shared_ptr
<
dnnl
::
memory
>
AcquireSrcMemoryWithReorder
(
const
framework
::
Tensor
*
input
)
{
const
T
*
input_data
=
input
->
data
<
T
>
();
auto
user_src_md
=
platform
::
MKLDNNMemDesc
(
phi
::
vectorize
(
input
->
dims
()),
platform
::
MKLDNNGetDataType
<
T
>
(),
input
->
format
());
return
platform
::
MKLDNNHandlerNoCachingT
<
T
,
dnnl
::
deconvolution_forward
>::
AcquireMemoryWithReorder
(
user_src_md
,
AcquireMemoryWithReorder
(
input
->
mem_desc
()
,
this
->
fwd_pd_
->
src_desc
(),
platform
::
to_void_cast
<
T
>
(
input_data
));
}
...
...
@@ -427,8 +411,7 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
auto
&
astream
=
platform
::
MKLDNNDeviceContext
::
tls
().
get_stream
();
conv_p
->
execute
(
astream
,
args
);
astream
.
wait
();
output
->
set_layout
(
DataLayout
::
kMKLDNN
);
output
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
dst_memory_p
));
output
->
set_mem_desc
(
dst_memory_p
->
get_desc
());
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录