Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3292f0ef
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3292f0ef
编写于
5月 18, 2020
作者:
J
Jacek Czaja
提交者:
GitHub
5月 18, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[onednn] elementwise add broadcasting support (#24594)
上级
560c8153
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
37 addition
and
9 deletion
+37
-9
cmake/external/mkldnn.cmake
cmake/external/mkldnn.cmake
+1
-1
paddle/fluid/operators/elementwise/elementwise_op.h
paddle/fluid/operators/elementwise/elementwise_op.h
+4
-2
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+15
-6
python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_add_mkldnn_op.py
.../tests/unittests/mkldnn/test_elementwise_add_mkldnn_op.py
+17
-0
未找到文件。
cmake/external/mkldnn.cmake
浏览文件 @
3292f0ef
...
@@ -20,7 +20,7 @@ SET(MKLDNN_SOURCE_DIR ${THIRD_PARTY_PATH}/mkldnn/src/extern_mkldnn)
...
@@ -20,7 +20,7 @@ SET(MKLDNN_SOURCE_DIR ${THIRD_PARTY_PATH}/mkldnn/src/extern_mkldnn)
SET
(
MKLDNN_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/mkldnn
)
SET
(
MKLDNN_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/mkldnn
)
SET
(
MKLDNN_INC_DIR
"
${
MKLDNN_INSTALL_DIR
}
/include"
CACHE PATH
"mkldnn include directory."
FORCE
)
SET
(
MKLDNN_INC_DIR
"
${
MKLDNN_INSTALL_DIR
}
/include"
CACHE PATH
"mkldnn include directory."
FORCE
)
SET
(
MKLDNN_REPOSITORY https://github.com/intel/mkl-dnn.git
)
SET
(
MKLDNN_REPOSITORY https://github.com/intel/mkl-dnn.git
)
SET
(
MKLDNN_TAG
589c09728e34d09d79106cba0211e93caf142d54
)
SET
(
MKLDNN_TAG
fb95345126ade4c54f5507e580a5f5da8d30a515
)
# Introduce variables:
# Introduce variables:
# * CMAKE_INSTALL_LIBDIR
# * CMAKE_INSTALL_LIBDIR
...
...
paddle/fluid/operators/elementwise/elementwise_op.h
浏览文件 @
3292f0ef
...
@@ -100,9 +100,11 @@ class ElementwiseOp : public framework::OperatorWithKernel {
...
@@ -100,9 +100,11 @@ class ElementwiseOp : public framework::OperatorWithKernel {
auto
input_data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
auto
input_data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
// If broadcasting is needed, use native implementation
auto
CanMKLDNNElementwiseAddBeUsed
=
[
&
]()
{
auto
CanMKLDNNElementwiseAddBeUsed
=
[
&
]()
{
return
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
()
==
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
dims
();
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
int
rankdiff
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
().
size
()
-
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
dims
().
size
();
return
(
axis
==
-
1
)
||
(
axis
==
rankdiff
);
};
};
if
(
platform
::
CanMKLDNNBeUsed
(
ctx
)
&&
if
(
platform
::
CanMKLDNNBeUsed
(
ctx
)
&&
...
...
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
3292f0ef
...
@@ -371,6 +371,13 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
...
@@ -371,6 +371,13 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
:
platform
::
MKLDNNHandlerT
<
T
,
dnnl
::
binary
>
(
:
platform
::
MKLDNNHandlerT
<
T
,
dnnl
::
binary
>
(
dev_ctx
,
engine
,
cpu_place
,
dev_ctx
,
engine
,
cpu_place
,
platform
::
CreateKey
(
framework
::
vectorize
(
x
->
dims
()),
uniq_name
))
{
platform
::
CreateKey
(
framework
::
vectorize
(
x
->
dims
()),
uniq_name
))
{
// bradcasting combined with in-place may require longer key
auto
rankdiff
=
x
->
dims
().
size
()
-
y
->
dims
().
size
();
if
(
rankdiff
>
0
)
{
this
->
key_
+=
std
::
to_string
(
rankdiff
);
this
->
key_common_
+=
std
::
to_string
(
rankdiff
);
}
if
(
!
this
->
isCached
())
{
if
(
!
this
->
isCached
())
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
x
->
layout
(),
DataLayout
::
kMKLDNN
,
x
->
layout
(),
DataLayout
::
kMKLDNN
,
...
@@ -390,17 +397,19 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
...
@@ -390,17 +397,19 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
const
auto
src_y_tz
=
framework
::
vectorize
(
y
->
dims
());
const
auto
src_y_tz
=
framework
::
vectorize
(
y
->
dims
());
const
auto
dst_tz
=
framework
::
vectorize
(
z
->
dims
());
const
auto
dst_tz
=
framework
::
vectorize
(
z
->
dims
());
// TODO(jczaja): Add function checking if data already exists
const
auto
src0_md
=
dnnl
::
memory
::
desc
(
const
auto
src0_md
=
dnnl
::
memory
::
desc
(
src_x_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
x
->
format
());
src_x_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
x
->
format
());
const
auto
src1_md
=
dnnl
::
memory
::
desc
(
auto
src1_md
=
dnnl
::
memory
::
desc
(
src_y_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
y
->
format
());
src_y_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
y
->
format
());
if
(
rankdiff
>
0
)
{
std
::
vector
<
int64_t
>
ones
(
rankdiff
,
1
);
std
::
vector
<
int64_t
>
dims1_ex
(
src_y_tz
);
dims1_ex
.
insert
(
dims1_ex
.
begin
(),
ones
.
begin
(),
ones
.
end
());
src1_md
=
src1_md
.
reshape
(
dims1_ex
);
}
const
auto
dst_md
=
memory
::
desc
(
dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
const
auto
dst_md
=
memory
::
desc
(
dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
MKLDNNMemoryFormat
::
any
);
MKLDNNMemoryFormat
::
any
);
// Currently MKL-DNN kernel supports only Z <- X + Y, shape(X) == shape(Y)
// TODO(jczaja): Binary primitive support broadcasting, so we can support
// this in kernel
this
->
AcquireForwardPrimitiveDescriptor
(
dnnl
::
algorithm
::
binary_add
,
this
->
AcquireForwardPrimitiveDescriptor
(
dnnl
::
algorithm
::
binary_add
,
src0_md
,
src1_md
,
dst_md
);
src0_md
,
src1_md
,
dst_md
);
}
}
...
@@ -410,7 +419,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
...
@@ -410,7 +419,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
const
framework
::
Tensor
*
input
)
{
const
framework
::
Tensor
*
input
)
{
const
T
*
input_data
=
input
->
data
<
T
>
();
const
T
*
input_data
=
input
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
return
this
->
AcquireMemoryFromPrimitive
(
this
->
fwd_pd_
->
src_desc
(),
to_void_cast
<
T
>
(
input_data
),
"@src1_mem_p"
);
this
->
fwd_pd_
->
src
1
_desc
(),
to_void_cast
<
T
>
(
input_data
),
"@src1_mem_p"
);
}
}
};
};
...
...
python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_add_mkldnn_op.py
浏览文件 @
3292f0ef
...
@@ -49,5 +49,22 @@ class TestMKLDNNElementwiseAddOp3(TestMKLDNNElementwiseAddOp):
...
@@ -49,5 +49,22 @@ class TestMKLDNNElementwiseAddOp3(TestMKLDNNElementwiseAddOp):
self
.
out
=
np
.
add
(
self
.
x
,
self
.
y
)
self
.
out
=
np
.
add
(
self
.
x
,
self
.
y
)
class
TestMKLDNNElementwiseAddOp4
(
TestMKLDNNElementwiseAddOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
uniform
(
1
,
2
,
[
2
,
3
,
4
,
32
]).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
uniform
(
1
,
2
,
[
4
,
32
]).
astype
(
self
.
dtype
)
self
.
out
=
np
.
add
(
self
.
x
,
self
.
y
)
# TODO(jczaja): Enable when grad is ready
def
test_check_grad_normal
(
self
):
pass
def
test_check_grad_ingore_x
(
self
):
pass
def
test_check_grad_ingore_y
(
self
):
pass
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录