Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
155328a4
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
155328a4
编写于
12月 07, 2018
作者:
Y
Yihua Xu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Clean Code
test=develop
上级
65dbc7cc
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
53 addition
and
83 deletion
+53
-83
paddle/fluid/operators/conv_mkldnn_op.cc
paddle/fluid/operators/conv_mkldnn_op.cc
+47
-77
paddle/fluid/operators/conv_op.cc
paddle/fluid/operators/conv_op.cc
+6
-6
未找到文件。
paddle/fluid/operators/conv_mkldnn_op.cc
浏览文件 @
155328a4
...
@@ -28,6 +28,46 @@ using mkldnn::stream;
...
@@ -28,6 +28,46 @@ using mkldnn::stream;
using
platform
::
to_void_cast
;
using
platform
::
to_void_cast
;
using
platform
::
GetMKLDNNFormat
;
using
platform
::
GetMKLDNNFormat
;
inline
void
GetWeightsTz
(
std
::
vector
<
int
>&
weights_tz
,
int
groups
,
// NOLINT
bool
is_conv3d
)
{
if
(
groups
>
1
)
{
if
(
is_conv3d
)
{
int
output
=
weights_tz
[
0
];
int
input
=
weights_tz
[
1
];
int
dimension
=
weights_tz
[
2
];
int
height
=
weights_tz
[
3
];
int
width
=
weights_tz
[
4
];
weights_tz
.
resize
(
6
);
weights_tz
[
0
]
=
groups
;
weights_tz
[
1
]
=
output
/
groups
;
weights_tz
[
2
]
=
input
;
weights_tz
[
3
]
=
dimension
;
weights_tz
[
4
]
=
height
;
weights_tz
[
5
]
=
width
;
}
else
{
int
output
=
weights_tz
[
0
];
int
input
=
weights_tz
[
1
];
int
height
=
weights_tz
[
2
];
int
width
=
weights_tz
[
3
];
weights_tz
.
resize
(
5
);
weights_tz
[
0
]
=
groups
;
weights_tz
[
1
]
=
output
/
groups
;
weights_tz
[
2
]
=
input
;
weights_tz
[
3
]
=
height
;
weights_tz
[
4
]
=
width
;
}
}
}
inline
mkldnn
::
memory
::
format
GetWeightsFormat
(
mkldnn
::
memory
::
format
format
,
int
groups
,
bool
is_conv3d
)
{
if
(
is_conv3d
)
{
return
(
groups
==
1
)
?
format
:
mkldnn
::
memory
::
format
::
goidhw
;
}
else
{
return
(
groups
==
1
)
?
format
:
mkldnn
::
memory
::
format
::
goihw
;
}
}
template
<
typename
T
>
template
<
typename
T
>
class
ConvMKLDNNOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
class
ConvMKLDNNOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -53,7 +93,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -53,7 +93,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
filter
->
format
()
!=
memory
::
format
::
format_undef
,
filter
->
format
()
!=
memory
::
format
::
format_undef
,
"Wrong layout/format set for Filter tensor"
);
"Wrong layout/format set for Filter tensor"
);
PADDLE_ENFORCE
(
input
->
dims
().
size
()
==
4
||
input
->
dims
().
size
()
==
5
,
PADDLE_ENFORCE
(
input
->
dims
().
size
()
==
4
||
input
->
dims
().
size
()
==
5
,
"Input must be with 4 or 5dimensions, i.e. NCHW or NCDHW"
);
"Input must be with 4 or 5
dimensions, i.e. NCHW or NCDHW"
);
PADDLE_ENFORCE
(
filter
->
dims
().
size
()
==
4
||
filter
->
dims
().
size
()
==
5
,
PADDLE_ENFORCE
(
filter
->
dims
().
size
()
==
4
||
filter
->
dims
().
size
()
==
5
,
"Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW"
);
"Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW"
);
if
(
bias
)
{
if
(
bias
)
{
...
@@ -87,33 +127,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -87,33 +127,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std
::
vector
<
int
>
weights_tz
=
std
::
vector
<
int
>
weights_tz
=
paddle
::
framework
::
vectorize2int
(
filter
->
dims
());
paddle
::
framework
::
vectorize2int
(
filter
->
dims
());
int
g
=
std
::
max
(
groups
,
1
);
int
g
=
std
::
max
(
groups
,
1
);
if
(
g
>
1
)
{
GetWeightsTz
(
weights_tz
,
g
,
is_conv3d
);
if
(
is_conv3d
)
{
int
o
=
weights_tz
[
0
];
int
i
=
weights_tz
[
1
];
int
d
=
weights_tz
[
2
];
int
h
=
weights_tz
[
3
];
int
w
=
weights_tz
[
4
];
weights_tz
.
resize
(
6
);
weights_tz
[
0
]
=
g
;
weights_tz
[
1
]
=
o
/
g
;
weights_tz
[
2
]
=
i
;
weights_tz
[
3
]
=
d
;
weights_tz
[
4
]
=
h
;
weights_tz
[
5
]
=
w
;
}
else
{
int
o
=
weights_tz
[
0
];
int
i
=
weights_tz
[
1
];
int
h
=
weights_tz
[
2
];
int
w
=
weights_tz
[
3
];
weights_tz
.
resize
(
5
);
weights_tz
[
0
]
=
g
;
weights_tz
[
1
]
=
o
/
g
;
weights_tz
[
2
]
=
i
;
weights_tz
[
3
]
=
h
;
weights_tz
[
4
]
=
w
;
}
}
std
::
vector
<
int
>
dst_tz
=
paddle
::
framework
::
vectorize2int
(
output
->
dims
());
std
::
vector
<
int
>
dst_tz
=
paddle
::
framework
::
vectorize2int
(
output
->
dims
());
// Get unique name for storing MKLDNN primitives
// Get unique name for storing MKLDNN primitives
...
@@ -126,12 +140,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -126,12 +140,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
src_format
=
input
->
format
();
auto
src_format
=
input
->
format
();
mkldnn
::
memory
::
format
weights_format
=
mkldnn
::
memory
::
format
weights_format
=
(
g
==
1
)
?
filter
->
format
()
:
mkldnn
::
memory
::
format
::
goihw
;
GetWeightsFormat
(
filter
->
format
(),
g
,
is_conv3d
);
if
(
is_conv3d
)
{
weights_format
=
(
g
==
1
)
?
filter
->
format
()
:
mkldnn
::
memory
::
format
::
goidhw
;
}
auto
user_src_md
=
platform
::
MKLDNNMemDesc
(
auto
user_src_md
=
platform
::
MKLDNNMemDesc
(
{
src_tz
},
platform
::
MKLDNNGetDataType
<
T
>
(),
src_format
);
{
src_tz
},
platform
::
MKLDNNGetDataType
<
T
>
(),
src_format
);
...
@@ -146,15 +155,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -146,15 +155,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
chosen_memory_format
=
auto
chosen_memory_format
=
platform
::
data_format_to_memory_format
(
data_format
);
platform
::
data_format_to_memory_format
(
data_format
);
weights_format
=
(
g
==
1
)
?
chosen_memory_format
:
mkldnn
::
memory
::
format
::
goihw
;
if
(
is_conv3d
)
{
if
(
is_conv3d
)
{
chosen_memory_format
=
chosen_memory_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
chosen_memory_format
);
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
chosen_memory_format
);
weights_format
=
(
g
==
1
)
?
chosen_memory_format
:
mkldnn
::
memory
::
format
::
goidhw
;
}
}
weights_format
=
GetWeightsFormat
(
chosen_memory_format
,
g
,
is_conv3d
);
auto
src_md
=
platform
::
MKLDNNMemDesc
(
auto
src_md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
chosen_memory_format
);
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
chosen_memory_format
);
...
@@ -397,43 +402,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -397,43 +402,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std
::
vector
<
int
>
weights_tz
=
std
::
vector
<
int
>
weights_tz
=
paddle
::
framework
::
vectorize2int
(
filter
->
dims
());
paddle
::
framework
::
vectorize2int
(
filter
->
dims
());
int
g
=
std
::
max
(
groups
,
1
);
int
g
=
std
::
max
(
groups
,
1
);
if
(
g
>
1
)
{
GetWeightsTz
(
weights_tz
,
g
,
is_conv3d
);
if
(
is_conv3d
)
{
int
o
=
weights_tz
[
0
];
int
i
=
weights_tz
[
1
];
int
d
=
weights_tz
[
2
];
int
h
=
weights_tz
[
3
];
int
w
=
weights_tz
[
4
];
weights_tz
.
resize
(
6
);
weights_tz
[
0
]
=
g
;
weights_tz
[
1
]
=
o
/
g
;
weights_tz
[
2
]
=
i
;
weights_tz
[
3
]
=
d
;
weights_tz
[
4
]
=
h
;
weights_tz
[
5
]
=
w
;
}
else
{
int
o
=
weights_tz
[
0
];
int
i
=
weights_tz
[
1
];
int
h
=
weights_tz
[
2
];
int
w
=
weights_tz
[
3
];
weights_tz
.
resize
(
5
);
weights_tz
[
0
]
=
g
;
weights_tz
[
1
]
=
o
/
g
;
weights_tz
[
2
]
=
i
;
weights_tz
[
3
]
=
h
;
weights_tz
[
4
]
=
w
;
}
}
std
::
vector
<
int
>
dst_tz
=
paddle
::
framework
::
vectorize2int
(
output
->
dims
());
std
::
vector
<
int
>
dst_tz
=
paddle
::
framework
::
vectorize2int
(
output
->
dims
());
auto
src_format
=
input
->
format
();
auto
src_format
=
input
->
format
();
mkldnn
::
memory
::
format
weights_format
=
mkldnn
::
memory
::
format
weights_format
=
(
g
==
1
)
?
filter
->
format
()
:
mkldnn
::
memory
::
format
::
goihw
;
GetWeightsFormat
(
filter
->
format
(),
g
,
is_conv3d
);
if
(
is_conv3d
)
{
weights_format
=
(
g
==
1
)
?
filter
->
format
()
:
mkldnn
::
memory
::
format
::
goidhw
;
}
// Get an unique name from "argument" name of "Output" variable
// Get an unique name from "argument" name of "Output" variable
// as well as attributes of primitive to be created
// as well as attributes of primitive to be created
...
@@ -461,15 +435,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -461,15 +435,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto
chosen_memory_format
=
auto
chosen_memory_format
=
platform
::
data_format_to_memory_format
(
data_format
);
platform
::
data_format_to_memory_format
(
data_format
);
weights_format
=
(
g
==
1
)
?
chosen_memory_format
:
mkldnn
::
memory
::
format
::
goihw
;
if
(
is_conv3d
)
{
if
(
is_conv3d
)
{
chosen_memory_format
=
chosen_memory_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
chosen_memory_format
);
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
chosen_memory_format
);
weights_format
=
(
g
==
1
)
?
chosen_memory_format
:
mkldnn
::
memory
::
format
::
goidhw
;
}
}
weights_format
=
GetWeightsFormat
(
chosen_memory_format
,
g
,
is_conv3d
);
auto
src_md
=
platform
::
MKLDNNMemDesc
(
auto
src_md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
chosen_memory_format
);
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
chosen_memory_format
);
...
...
paddle/fluid/operators/conv_op.cc
浏览文件 @
155328a4
...
@@ -134,14 +134,14 @@ void Conv2DOpMaker::Make() {
...
@@ -134,14 +134,14 @@ void Conv2DOpMaker::Make() {
"The format of output tensor is X (one-dimensional) of size equal"
"The format of output tensor is X (one-dimensional) of size equal"
"to the number of output channels. Only used with MKL-DNN."
)
"to the number of output channels. Only used with MKL-DNN."
)
.
AsDispensable
();
.
AsDispensable
();
AddOutput
(
"Output"
,
"(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW."
);
AddInput
(
"ResidualData"
,
AddInput
(
"ResidualData"
,
"(Tensor) Tensor with residual data "
"(Tensor) Tensor with residual data "
"to which convolution output will be added."
"to which convolution output will be added."
"Used with fuse_residual_connection fusion."
)
"Used with fuse_residual_connection fusion."
)
.
AsDispensable
();
.
AsDispensable
();
AddOutput
(
"Output"
,
"(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW."
);
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"(vector<int> default:{1, 1}), the "
"(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of "
"strides(h_stride, w_stride) of "
...
@@ -251,14 +251,14 @@ void Conv3DOpMaker::Make() {
...
@@ -251,14 +251,14 @@ void Conv3DOpMaker::Make() {
"is the width of the filter."
"is the width of the filter."
"If the groups attribute is greater than 1, C equals the number of "
"If the groups attribute is greater than 1, C equals the number of "
"input image channels divided by the groups."
);
"input image channels divided by the groups."
);
AddOutput
(
"Output"
,
"(Tensor) The output tensor of convolution operator."
"The format of output tensor is also NCDHW."
);
AddInput
(
"ResidualData"
,
AddInput
(
"ResidualData"
,
"(Tensor) Tensor with residual data "
"(Tensor) Tensor with residual data "
"to which convolution output will be added."
"to which convolution output will be added."
"Used with fuse_residual_connection fusion."
)
"Used with fuse_residual_connection fusion."
)
.
AsDispensable
();
.
AsDispensable
();
AddOutput
(
"Output"
,
"(Tensor) The output tensor of convolution operator."
"The format of output tensor is also NCDHW."
);
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"(vector<int>, default:{1, 1, 1}), the "
"(vector<int>, default:{1, 1, 1}), the "
"strides(d_stride, h_stride, w_stride) of "
"strides(d_stride, h_stride, w_stride) of "
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录