Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
84cc61b2
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
84cc61b2
编写于
11月 03, 2020
作者:
J
Jacek Czaja
提交者:
GitHub
11月 03, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[oneDNN] sum op refactor (#28318)
上级
6f0f45f6
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
126 addition
and
106 deletion
+126
-106
paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
+124
-53
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+0
-53
python/paddle/fluid/tests/unittests/mkldnn/test_sum_mkldnn_op.py
...paddle/fluid/tests/unittests/mkldnn/test_sum_mkldnn_op.py
+2
-0
未找到文件。
paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
浏览文件 @
84cc61b2
...
@@ -25,7 +25,7 @@
...
@@ -25,7 +25,7 @@
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/platform/mkldnn_
helper
.h"
#include "paddle/fluid/platform/mkldnn_
reuse
.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -51,6 +51,95 @@ using paddle::platform::CPUDeviceContext;
...
@@ -51,6 +51,95 @@ using paddle::platform::CPUDeviceContext;
using
paddle
::
platform
::
MKLDNNDeviceContext
;
using
paddle
::
platform
::
MKLDNNDeviceContext
;
using
platform
::
to_void_cast
;
using
platform
::
to_void_cast
;
template
<
typename
T
>
class
SumMKLDNNHandler
:
public
platform
::
MKLDNNHandlerT
<
T
,
dnnl
::
sum
>
{
public:
SumMKLDNNHandler
(
const
MKLDNNDeviceContext
&
dev_ctx
,
platform
::
Place
cpu_place
,
const
std
::
vector
<
framework
::
Variable
*>&
in_vars
,
framework
::
LoDTensor
*
z
,
const
std
::
string
&
uniq_name
)
:
platform
::
MKLDNNHandlerT
<
T
,
dnnl
::
sum
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
framework
::
vectorize
(
z
->
dims
()),
uniq_name
)),
num_inputs_
(
0
)
{
for
(
size_t
i
=
0
;
i
<
in_vars
.
size
();
i
++
)
{
srcs_suffix_
.
push_back
(
std
::
string
(
"-"
)
+
std
::
to_string
(
i
));
}
if
(
!
this
->
isCached
())
{
auto
dst_tz
=
framework
::
vectorize
<
int64_t
>
(
z
->
dims
());
auto
src_tz
=
dst_tz
;
std
::
vector
<
memory
::
desc
>
srcs_md
;
for
(
size_t
i
=
0
;
i
<
in_vars
.
size
();
i
++
)
{
auto
&
input_it
=
in_vars
[
i
]
->
Get
<
framework
::
LoDTensor
>
();
if
(
input_it
.
numel
()
==
0
)
{
continue
;
}
MKLDNNMemoryFormat
input_format
=
input_it
.
format
();
srcs_md
.
push_back
(
memory
::
desc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
input_format
));
++
num_inputs_
;
}
std
::
vector
<
float
>
scales
(
num_inputs_
,
1.0
);
auto
dst_md
=
memory
::
desc
(
dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
MKLDNNMemoryFormat
::
any
);
this
->
AcquireForwardPrimitiveDescriptor
(
dst_md
,
scales
,
srcs_md
);
}
}
// (jczaja) sum oneDNN prim is not having .desc attribute so
// we cannot use base AcquireForwardPrimitiveDescriptor
void
AcquireForwardPrimitiveDescriptor
(
const
memory
::
desc
&
dst_md
,
const
std
::
vector
<
float
>&
scales
,
const
std
::
vector
<
memory
::
desc
>&
srcs_md
)
{
// Sum op does not have backward so no passing from FWD to BWD is needed
const
std
::
string
key_pd
=
this
->
key_
+
"@fwd_pd"
;
this
->
fwd_pd_
=
std
::
static_pointer_cast
<
dnnl
::
sum
::
primitive_desc
>
(
this
->
dev_ctx_
.
GetBlob
(
key_pd
));
if
(
this
->
fwd_pd_
==
nullptr
)
{
this
->
fwd_pd_
.
reset
(
new
mkldnn
::
sum
::
primitive_desc
(
dst_md
,
scales
,
srcs_md
,
this
->
engine_
));
this
->
dev_ctx_
.
SetBlob
(
key_pd
,
this
->
fwd_pd_
);
}
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireSrcMemory
(
const
framework
::
Tensor
&
input
,
int
i
)
{
const
T
*
input_data
=
input
.
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
this
->
fwd_pd_
->
src_desc
(
i
),
to_void_cast
<
T
>
(
input_data
),
"@src_mem_p"
+
srcs_suffix_
[
i
]);
}
using
platform
::
MKLDNNHandlerT
<
T
,
dnnl
::
sum
>::
AcquireDstMemory
;
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDstMemory
(
void
)
{
return
this
->
AcquireMemoryFromPrimitive
(
this
->
fwd_pd_
->
dst_desc
(),
"@dst_mem_p"
);
}
inline
int
GetNumInputs
(
void
)
{
return
num_inputs_
;
}
protected:
// isCached need to be overloaded as base one works on key_common
bool
isCached
()
{
const
std
::
string
key_pd
=
this
->
key_
+
"@fwd_pd"
;
this
->
fwd_pd_
=
std
::
static_pointer_cast
<
dnnl
::
sum
::
primitive_desc
>
(
this
->
dev_ctx_
.
GetBlob
(
key_pd
));
const
std
::
string
key_p
=
this
->
key_
+
"@fwd_p"
;
return
(
this
->
dev_ctx_
.
GetBlob
(
key_p
)
!=
nullptr
);
}
private:
int
num_inputs_
;
std
::
vector
<
std
::
string
>
srcs_suffix_
;
};
template
<
typename
T
>
template
<
typename
T
>
class
SumMKLDNNOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
class
SumMKLDNNOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -59,85 +148,67 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -59,85 +148,67 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
paddle
::
platform
::
errors
::
PreconditionNotMet
(
paddle
::
platform
::
errors
::
PreconditionNotMet
(
"Operator DNNL Sum must use CPUPlace"
));
"Operator DNNL Sum must use CPUPlace"
));
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
const
auto
&
mkldnn_engine
=
dev_ctx
.
GetEngine
();
auto
in_vars
=
ctx
.
MultiInputVar
(
"X"
);
auto
in_vars
=
ctx
.
MultiInputVar
(
"X"
);
auto
out_var
=
ctx
.
OutputVar
(
"Out"
);
PADDLE_ENFORCE_NE
(
in_vars
.
empty
(),
true
,
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_NE
(
in_vars
.
empty
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Input variable is empty."
));
"Input variable is empty."
));
bool
in_place
=
out_var
==
in_vars
[
0
];
auto
&
input0
=
in_vars
[
0
]
->
Get
<
LoDTensor
>
();
LoDTensor
*
output
=
ctx
.
Output
<
LoDTensor
>
(
"Out"
);
LoDTensor
*
output
=
ctx
.
Output
<
LoDTensor
>
(
"Out"
);
T
*
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dst_tz
=
framework
::
vectorize
<
int64_t
>
(
output
->
dims
());
bool
in_place
=
(
input0
.
numel
()
>
0
)
&&
input0
.
IsSharedBufferWith
(
*
output
);
auto
src_tz
=
dst_tz
;
MKLDNNMemoryFormat
output_format
{
MKLDNNMemoryFormat
::
undef
};
std
::
vector
<
float
>
scales
;
std
::
vector
<
memory
::
desc
>
srcs_md
;
std
::
vector
<
mkldnn
::
memory
>
srcs_mem
;
auto
&
input0
=
in_vars
[
0
]
->
Get
<
LoDTensor
>
();
SumMKLDNNHandler
<
T
>
handler
(
dev_ctx
,
ctx
.
GetPlace
(),
in_vars
,
output
,
in_place
=
(
input0
.
numel
()
>
0
)
&&
(
input0
.
data
<
T
>
()
==
output_data
);
ctx
.
OutputName
(
"Out"
)
);
// Create list of SRC MEMs
std
::
vector
<
std
::
shared_ptr
<
mkldnn
::
memory
>>
srcs_mem
;
srcs_mem
.
reserve
(
handler
.
GetNumInputs
());
int
input_index
=
0
;
for
(
size_t
i
=
0
;
i
<
in_vars
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
in_vars
.
size
();
i
++
)
{
auto
&
input_it
=
in_vars
[
i
]
->
Get
<
LoDTensor
>
();
auto
&
input_it
=
in_vars
[
i
]
->
Get
<
framework
::
LoDTensor
>
();
if
(
input_it
.
numel
()
==
0
)
{
if
(
input_it
.
numel
()
==
0
)
{
continue
;
continue
;
}
}
srcs_mem
.
push_back
(
handler
.
AcquireSrcMemory
(
input_it
,
input_index
));
const
T
*
input_data
=
input_it
.
data
<
T
>
();
++
input_index
;
MKLDNNMemoryFormat
input_format
=
input_it
.
format
();
auto
src_md
=
memory
::
desc
(
src_tz
,
memory
::
data_type
::
f32
,
input_format
);
auto
src_mem
=
memory
(
src_md
,
mkldnn_engine
,
to_void_cast
(
input_data
));
srcs_md
.
push_back
(
src_md
);
srcs_mem
.
push_back
(
src_mem
);
scales
.
push_back
(
1.0
);
}
auto
dst_md
=
memory
::
desc
(
dst_tz
,
memory
::
data_type
::
f32
,
MKLDNNMemoryFormat
::
any
);
auto
sum_pd
=
sum
::
primitive_desc
(
dst_md
,
scales
,
srcs_md
,
mkldnn_engine
);
std
::
shared_ptr
<
memory
>
dst_mem
;
if
(
in_place
)
{
dst_mem
.
reset
(
new
memory
(
sum_pd
.
dst_desc
(),
mkldnn_engine
));
}
else
{
dst_mem
.
reset
(
new
memory
(
sum_pd
.
dst_desc
(),
mkldnn_engine
,
output_data
));
}
}
auto
sum_prim
=
mkldnn
::
sum
(
sum_pd
);
auto
dst_mem
=
in_place
?
handler
.
AcquireDstMemory
()
output_format
=
platform
::
GetMKLDNNFormat
(
sum_pd
.
dst_desc
()
);
:
handler
.
AcquireDstMemory
(
output
);
std
::
shared_ptr
<
mkldnn
::
reorder
>
reorder_p
;
auto
sum_p
=
handler
.
AcquireForwardPrimitive
();
std
::
shared_ptr
<
memory
>
target_mem
;
if
(
in_place
)
{
output_format
=
input0
.
format
();
target_mem
.
reset
(
new
memory
({{
src_tz
},
memory
::
data_type
::
f32
,
output_format
},
mkldnn_engine
,
output_data
));
reorder_p
=
std
::
make_shared
<
reorder
>
(
*
dst_mem
,
*
target_mem
);
}
mkldnn
::
stream
astream
(
mkldnn_engine
);
std
::
unordered_map
<
int
,
memory
>
args
;
std
::
unordered_map
<
int
,
memory
>
args
;
for
(
size_t
i
=
0
;
i
<
srcs_mem
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
srcs_mem
.
size
();
++
i
)
{
args
.
insert
({
MKLDNN_ARG_MULTIPLE_SRC
+
i
,
srcs_mem
.
at
(
i
)});
args
.
insert
({
MKLDNN_ARG_MULTIPLE_SRC
+
i
,
*
(
srcs_mem
[
i
]
)});
}
}
args
.
insert
({
MKLDNN_ARG_DST
,
*
dst_mem
});
args
.
insert
({
MKLDNN_ARG_DST
,
*
dst_mem
});
sum_prim
.
execute
(
astream
,
args
);
mkldnn
::
stream
astream
(
dev_ctx
.
GetEngine
());
sum_p
->
execute
(
astream
,
args
);
astream
.
wait
();
astream
.
wait
();
// For in-place execution which sum does not have we need to fake it
// so from oneDNN dst memory we reorder data into input
if
(
in_place
)
{
if
(
in_place
)
{
const
std
::
string
reorder_key
=
platform
::
CreateKey
(
framework
::
vectorize
(
output
->
dims
()),
ctx
.
OutputName
(
"Out"
)
+
"-I"
);
auto
&
in_out
=
in_vars
[
0
]
->
Get
<
framework
::
LoDTensor
>
();
auto
output_tz
=
framework
::
vectorize
<
int64_t
>
(
output
->
dims
());
platform
::
ReorderMKLDNNHandler
reorder_handler
(
output_tz
,
output
->
type
(),
framework
::
ToMKLDNNDataType
(
in_out
.
type
()),
dev_ctx
,
dev_ctx
.
GetEngine
(),
reorder_key
);
auto
target_mem
=
reorder_handler
.
AcquireDstMemory
(
output
,
in_out
.
format
(),
ctx
.
GetPlace
());
auto
reorder_p
=
reorder_handler
.
AcquireReorder
(
target_mem
,
dst_mem
);
reorder_p
->
execute
(
astream
,
*
dst_mem
,
*
target_mem
);
reorder_p
->
execute
(
astream
,
*
dst_mem
,
*
target_mem
);
astream
.
wait
();
astream
.
wait
();
}
}
output
->
set_layout
(
framework
::
DataLayout
::
kMKLDNN
);
output
->
set_layout
(
DataLayout
::
kMKLDNN
);
output
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
dst_mem
));
output
->
set_format
(
output_format
);
}
}
};
};
...
...
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
84cc61b2
...
@@ -591,59 +591,6 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
...
@@ -591,59 +591,6 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
}
}
};
};
class
SumMKLDNNHandler
:
public
MKLDNNHandler
{
public:
SumMKLDNNHandler
(
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
const
std
::
string
&
base_key
)
:
platform
::
MKLDNNHandler
(
dev_ctx
,
engine
,
base_key
)
{}
std
::
shared_ptr
<
mkldnn
::
sum
::
primitive_desc
>
AcquireSumPrimitiveDescriptor
(
const
std
::
vector
<
std
::
shared_ptr
<
mkldnn
::
memory
>>&
src_mems
,
const
std
::
vector
<
float
>&
scales
,
const
mkldnn
::
memory
::
desc
&
dst_md
)
{
const
std
::
string
key_sum_pd
=
key_
+
"@sum_pd"
;
sum_pd_
=
std
::
static_pointer_cast
<
mkldnn
::
sum
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_sum_pd
));
if
(
sum_pd_
==
nullptr
)
{
// Get vector of inputs primitive descriptors
std
::
vector
<
mkldnn
::
memory
::
desc
>
src_ds
;
for
(
auto
&
input_mem
:
src_mems
)
{
src_ds
.
push_back
(
input_mem
->
get_desc
());
}
sum_pd_
.
reset
(
new
mkldnn
::
sum
::
primitive_desc
(
dst_md
,
scales
,
src_ds
,
engine_
));
dev_ctx_
.
SetBlob
(
key_sum_pd
,
sum_pd_
);
}
return
sum_pd_
;
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDstMemoryFromPrimitive
(
void
*
ptr
)
{
return
this
->
AcquireMemoryFromPrimitive
(
sum_pd_
->
dst_desc
(),
ptr
,
"@dst_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireSecondSrcMemory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
)
{
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_src2_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
sum
>
AcquireSum
()
{
auto
prim_key
=
key_
+
"@sum_p"
;
auto
sum_p
=
std
::
static_pointer_cast
<
mkldnn
::
sum
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
if
(
sum_p
==
nullptr
)
{
sum_p
=
std
::
make_shared
<
mkldnn
::
sum
>
(
*
sum_pd_
);
dev_ctx_
.
SetBlob
(
prim_key
,
sum_p
);
}
return
sum_p
;
}
private:
std
::
shared_ptr
<
mkldnn
::
sum
::
primitive_desc
>
sum_pd_
;
};
template
<
typename
T
>
template
<
typename
T
>
class
ActivationMKLDNNHandler
class
ActivationMKLDNNHandler
:
public
MKLDNNHandlerT
<
T
,
mkldnn
::
eltwise_forward
,
:
public
MKLDNNHandlerT
<
T
,
mkldnn
::
eltwise_forward
,
...
...
python/paddle/fluid/tests/unittests/mkldnn/test_sum_mkldnn_op.py
浏览文件 @
84cc61b2
...
@@ -86,4 +86,6 @@ class TestMKLDNNSumInplaceOp(unittest.TestCase):
...
@@ -86,4 +86,6 @@ class TestMKLDNNSumInplaceOp(unittest.TestCase):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
from
paddle
import
enable_static
enable_static
()
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录