BaiXuePrincess / Paddle
Forked from PaddlePaddle / Paddle
Commit f2a88042
Authored on Dec 04, 2018 by Michal Gallus

Fix style @ concat integration and tests

test=develop

Parent: 738069e4
Showing 3 changed files with 35 additions and 33 deletions (+35 -33)

paddle/fluid/operators/concat_mkldnn_op.cc                      +16 -17
paddle/fluid/operators/concat_op.cc                             +17 -16
python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py    +2  -0
paddle/fluid/operators/concat_mkldnn_op.cc
@@ -30,15 +30,15 @@ using platform::to_void_cast;
static void EnforceLayouts(const std::vector<const Tensor*> inputs) {
  for (auto* input : inputs) {
    const bool is_layout_correct = input->layout() == DataLayout::kMKLDNN;
    const bool is_format_defined =
        input->format() != memory::format::format_undef;
    PADDLE_ENFORCE(is_layout_correct && is_format_defined,
                   "Wrong layout/format set for Input tensor");
  }
}

static memory::primitive_desc CreateMemPrimDesc(const Tensor& input,
                                                const mkldnn::engine& engine) {
  constexpr auto data_type = mkldnn::memory::f32;
  const auto dims = paddle::framework::vectorize2int(input.dims());
  const auto format = input.format();
@@ -48,8 +48,8 @@ static memory::primitive_desc CreateMemPrimDesc(
}

static mkldnn::memory::format GetDstMemFormat(
    const concat::primitive_desc& concat_pd) {
  return (memory::format)concat_pd.dst_primitive_desc().desc().data.format;
}

static platform::CPUPlace GetCpuPlace(
@@ -61,10 +61,9 @@ static platform::CPUPlace GetCpuPlace(
}

static const mkldnn::engine& GetMKLDNNEngine(
    const paddle::framework::ExecutionContext& ctx) {
  auto& dev_ctx = ctx.template device_context<platform::MKLDNNDeviceContext>();
  return dev_ctx.GetEngine();
}

template <typename T>
@@ -89,7 +88,7 @@ class ConcatPrimitiveFactory {
  memory::desc CreateDstMemDescriptor(Tensor* output) {
    auto dst_dims = paddle::framework::vectorize2int(output->dims());
    return memory::desc(dst_dims, platform::MKLDNNGetDataType<T>(),
                        memory::format::any);
  }

  mkldnn::memory CreateDstMemory(const concat::primitive_desc& concat_pd,
@@ -101,10 +100,10 @@ class ConcatPrimitiveFactory {
  void CreateSourcesDescriptors(const std::vector<const Tensor*> multi_input,
                                const mkldnn::engine& mkldnn_engine) {
    for (size_t i = 0; i < multi_input.size(); i++) {
      auto mem_prim_desc = CreateMemPrimDesc(*multi_input[i], mkldnn_engine);
      srcs_pd.push_back(mem_prim_desc);
      srcs.push_back(
          memory(mem_prim_desc, to_void_cast(multi_input[i]->data<T>())));
    }
  }
@@ -134,8 +133,8 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    int64_t concat_axis = static_cast<int64_t>(ctx.Attr<int>("axis"));

    ConcatPrimitiveFactory<T> prim_creator;
    auto concat_pd = prim_creator.CreateConcatPrimDescriptor(
        multi_input, output, static_cast<int>(concat_axis), mkldnn_engine);
    auto concat = prim_creator.CreateConcatPrimitive(concat_pd, output, place);
    stream(stream::kind::eager).submit({concat}).wait();
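Read together, the hunks above cover the whole MKL-DNN concat path: validate input layouts and formats, wrap each input in a memory primitive descriptor, build the concat primitive descriptor, and submit the primitive on an eager stream. Below is a minimal sketch of how these helpers plausibly compose inside ConcatMKLDNNOpKernel<T>::Compute(); the input/output fetching lines and the exact GetCpuPlace signature are assumptions, since they fall outside the hunks shown here.

// Condensed sketch only, not the actual file; lines marked "assumed" are not
// part of the hunks shown above.
template <typename T>
class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    auto multi_input = ctx.MultiInput<Tensor>("X");        // assumed
    EnforceLayouts(multi_input);
    Tensor* output = ctx.Output<Tensor>("Out");            // assumed
    auto place = GetCpuPlace(ctx);                         // assumed signature
    const mkldnn::engine& mkldnn_engine = GetMKLDNNEngine(ctx);

    int64_t concat_axis = static_cast<int64_t>(ctx.Attr<int>("axis"));

    ConcatPrimitiveFactory<T> prim_creator;
    auto concat_pd = prim_creator.CreateConcatPrimDescriptor(
        multi_input, output, static_cast<int>(concat_axis), mkldnn_engine);
    auto concat = prim_creator.CreateConcatPrimitive(concat_pd, output, place);
    stream(stream::kind::eager).submit({concat}).wait();
  }
};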
paddle/fluid/operators/concat_op.cc
@@ -14,9 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/concat_op.h"
#include <paddle/fluid/platform/mkldnn_helper.h>
#include <string>
#include <vector>
#include <paddle/fluid/platform/mkldnn_helper.h>
namespace
paddle
{
namespace
operators
{
...
...
@@ -63,18 +63,19 @@ class ConcatOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto input_data_type =
        framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]);

#ifdef PADDLE_WITH_MKLDNN
    if (platform::CanMKLDNNBeUsed(ctx)) {
      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -82,9 +83,10 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
  void Make() override {
    AddInput("X", "Input tensors of concat operator.").AsDuplicable();
    AddOutput("Out", "Output tensor of concat operator.");
    AddAttr<bool>(
        "use_mkldnn",
        "(bool, default false) Indicates if MKL-DNN kernel will be used")
        .SetDefault(false);
    AddAttr<int>("axis",
                 "The axis along which the input tensors will be concatenated.")
        .SetDefault(0);
@@ -101,7 +103,6 @@ Examples:
        [5,6]]
)DOC");
  }
};
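The axis and use_mkldnn attributes registered above in ConcatOpMaker, together with GetExpectedKernelType, decide whether the MKL-DNN kernel from the first file is dispatched. As a hedged illustration only (none of this code is part of the commit), the operator can be exercised from the Fluid 1.x Python API roughly as follows; kernel selection still depends on the use_mkldnn attribute and on MKL-DNN being available at build/run time, which this sketch does not force.

import numpy as np
import paddle.fluid as fluid

# Build a tiny program that appends a concat op; fluid.layers.concat sets the
# 'axis' attribute declared in ConcatOpMaker.
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    x0 = fluid.layers.data(name='x0', shape=[2, 3], dtype='float32',
                           append_batch_size=False)
    x1 = fluid.layers.data(name='x1', shape=[2, 3], dtype='float32',
                           append_batch_size=False)
    out = fluid.layers.concat([x0, x1], axis=0)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)
res, = exe.run(main,
               feed={'x0': np.ones((2, 3), dtype='float32'),
                     'x1': np.zeros((2, 3), dtype='float32')},
               fetch_list=[out])
print(res.shape)  # (4, 3): the two inputs concatenated along axis 0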
python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py
@@ -29,6 +29,7 @@ class TestMKLDNNConcatOp(TestConcatOp):
    def init_kernel_type(self):
        self.use_mkldnn = True


class TestMKLDNNConcatOp2(TestConcatOp2):
    def setUp(self):
        super(TestMKLDNNConcatOp2, self).setUp()
@@ -40,6 +41,7 @@ class TestMKLDNNConcatOp2(TestConcatOp2):
    def init_kernel_type(self):
        self.use_mkldnn = True


class TestMKLDNNConcatOp3(TestConcatOp3):
    def setUp(self):
        super(TestMKLDNNConcatOp3, self).setUp()
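The test diff only inserts the blank lines PEP8 expects between top-level classes; the pattern itself is unchanged: each MKL-DNN variant reuses a plain CPU concat test case and switches the kernel via init_kernel_type, which the base test is assumed to call from its setUp. A hypothetical extra variant (not part of this commit) would follow the same shape:

# Hypothetical illustration only, not part of this commit. The import style
# mirrors how the unittests are typically run from their own directory
# (assumption); TestConcatOp is the plain CPU test case being reused.
from test_concat_op import TestConcatOp


class TestMKLDNNConcatOpExample(TestConcatOp):
    def init_kernel_type(self):
        self.use_mkldnn = True


if __name__ == '__main__':
    import unittest
    unittest.main()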