Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
657abd51
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
657abd51
编写于
5月 25, 2022
作者:
J
jakpiase
提交者:
GitHub
5月 25, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
OneDNN md-in-tensor refactoring part 4: Memory descriptor enabled for more ops (#42946)
* added support for md in more ops * fixed typo
上级
c6f98fa0
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
29 addition
and
55 deletion
+29
-55
paddle/fluid/operators/mkldnn/fill_constant_mkldnn_op.cc
paddle/fluid/operators/mkldnn/fill_constant_mkldnn_op.cc
+4
-2
paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
+3
-5
paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc
paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc
+5
-8
paddle/fluid/operators/mkldnn/stack_mkldnn_op.cc
paddle/fluid/operators/mkldnn/stack_mkldnn_op.cc
+4
-6
paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
+13
-34
未找到文件。
paddle/fluid/operators/mkldnn/fill_constant_mkldnn_op.cc
浏览文件 @
657abd51
...
@@ -79,8 +79,10 @@ class FillConstantMKLDNNKernel : public framework::OpKernel<T> {
...
@@ -79,8 +79,10 @@ class FillConstantMKLDNNKernel : public framework::OpKernel<T> {
{
DNNL_ARG_DST
,
*
src0_memory_p
}});
{
DNNL_ARG_DST
,
*
src0_memory_p
}});
astream
.
wait
();
astream
.
wait
();
out
->
set_layout
(
framework
::
DataLayout
::
kMKLDNN
);
// src0_memory_p's md was just to allow the usage of a binary
out
->
set_format
(
platform
::
GetPlainMKLDNNFormat
(
out
->
dims
().
size
()));
// primitive as a memset, and now we need to create a real one
out
->
set_mem_desc
({
phi
::
vectorize
(
shape
),
platform
::
MKLDNNGetDataType
<
T
>
(),
platform
::
GetPlainMKLDNNFormat
(
shape
.
size
())});
}
}
T
CalculateFillValue
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
T
CalculateFillValue
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
...
...
paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
浏览文件 @
657abd51
...
@@ -124,7 +124,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -124,7 +124,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
&
astream
=
platform
::
MKLDNNDeviceContext
::
tls
().
get_stream
();
auto
&
astream
=
platform
::
MKLDNNDeviceContext
::
tls
().
get_stream
();
if
(
!
workspace_memory
->
get_desc
().
is_zero
())
{
if
(
!
workspace_memory
->
get_desc
().
is_zero
())
{
mid
->
set_
format
(
platform
::
GetMKLDNNFormat
(
*
workspace_memory
));
mid
->
set_
mem_desc
(
workspace_memory
->
get_desc
(
));
lrn_p
->
execute
(
astream
,
{{
DNNL_ARG_SRC
,
*
src_memory
},
lrn_p
->
execute
(
astream
,
{{
DNNL_ARG_SRC
,
*
src_memory
},
{
DNNL_ARG_DST
,
*
dst_memory
},
{
DNNL_ARG_DST
,
*
dst_memory
},
{
DNNL_ARG_WORKSPACE
,
*
workspace_memory
}});
{
DNNL_ARG_WORKSPACE
,
*
workspace_memory
}});
...
@@ -134,8 +134,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -134,8 +134,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
}
astream
.
wait
();
astream
.
wait
();
out
->
set_layout
(
framework
::
DataLayout
::
kMKLDNN
);
out
->
set_mem_desc
(
dst_memory
->
get_desc
());
out
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
dst_memory
));
}
}
};
};
...
@@ -177,8 +176,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -177,8 +176,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
{
DNNL_ARG_WORKSPACE
,
*
workspace
}});
{
DNNL_ARG_WORKSPACE
,
*
workspace
}});
astream
.
wait
();
astream
.
wait
();
in_x_grad
->
set_layout
(
framework
::
DataLayout
::
kMKLDNN
);
in_x_grad
->
set_mem_desc
(
diff_src_memory
->
get_desc
());
in_x_grad
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
diff_src_memory
));
}
}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc
浏览文件 @
657abd51
...
@@ -175,19 +175,17 @@ class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
...
@@ -175,19 +175,17 @@ class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
dnnl
::
memory
::
data_type
dout_type
=
framework
::
ToMKLDNNDataType
(
dnnl
::
memory
::
data_type
dout_type
=
framework
::
ToMKLDNNDataType
(
framework
::
TransToProtoVarType
(
dout
->
dtype
()));
framework
::
TransToProtoVarType
(
dout
->
dtype
()));
dnnl
::
memory
::
desc
md
(
dout_vec_dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
dout
->
format
());
dnnl
::
memory
::
format_tag
reorder_format_tag
=
platform
::
GetMKLDNNFormat
(
md
.
reshape
(
slice_dims
));
platform
::
ReorderMKLDNNHandler
reorder_handler
(
platform
::
ReorderMKLDNNHandler
reorder_handler
(
slice_dims
,
framework
::
TransToProtoVarType
(
dout
->
dtype
()),
dout_type
,
slice_dims
,
framework
::
TransToProtoVarType
(
dout
->
dtype
()),
dout_type
,
onednn_engine
);
onednn_engine
);
auto
reorder_src_memory_p
=
reorder_handler
.
AcquireSrcMemory
(
auto
reorder_src_memory_p
=
reorder_handler
.
AcquireSrcMemory
(
reorder_format_tag
,
platform
::
to_void_cast
(
dout
->
data
<
T
>
()));
dout
->
mem_desc
().
reshape
(
slice_dims
),
platform
::
to_void_cast
(
dout
->
data
<
T
>
()));
auto
reorder_dst_memory_p
=
reorder_handler
.
AcquireDstMemory
(
auto
reorder_dst_memory_p
=
reorder_handler
.
AcquireDstMemory
(
dx
,
dx_vec_dims
,
reorder_format_tag
,
ctx
.
GetPlace
());
dx
,
dx_vec_dims
,
platform
::
GetPlainMKLDNNFormat
(
dx_vec_dims
.
size
()),
ctx
.
GetPlace
());
memset
(
dx
->
data
<
T
>
(),
0
,
reorder_dst_memory_p
->
get_desc
().
get_size
());
memset
(
dx
->
data
<
T
>
(),
0
,
reorder_dst_memory_p
->
get_desc
().
get_size
());
auto
slice_mem_p
=
reorder_handler
.
AcquireSubmemory
(
slice_dims
,
offsets
,
auto
slice_mem_p
=
reorder_handler
.
AcquireSubmemory
(
slice_dims
,
offsets
,
...
@@ -199,8 +197,7 @@ class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
...
@@ -199,8 +197,7 @@ class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
reorder_p
->
execute
(
astream
,
*
reorder_src_memory_p
,
*
slice_mem_p
);
reorder_p
->
execute
(
astream
,
*
reorder_src_memory_p
,
*
slice_mem_p
);
astream
.
wait
();
astream
.
wait
();
dx
->
set_layout
(
framework
::
DataLayout
::
kMKLDNN
);
dx
->
set_mem_desc
(
reorder_dst_memory_p
->
get_desc
());
dx
->
set_format
(
reorder_format_tag
);
}
}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/mkldnn/stack_mkldnn_op.cc
浏览文件 @
657abd51
...
@@ -59,7 +59,7 @@ class StackMKLDNNHandler
...
@@ -59,7 +59,7 @@ class StackMKLDNNHandler
// wrong output format deduction and suboptimal performance as a result
// wrong output format deduction and suboptimal performance as a result
if
(
stack_axis
!=
ndims
)
{
if
(
stack_axis
!=
ndims
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
srcs_md
.
emplace_back
(
memory
::
desc
(
input_dims
,
dt
,
inputs
[
i
]
->
format
()
));
srcs_md
.
push_back
(
inputs
[
i
]
->
mem_desc
(
));
}
}
input_dims
[
stack_axis
]
*=
inputs
.
size
();
input_dims
[
stack_axis
]
*=
inputs
.
size
();
...
@@ -69,8 +69,7 @@ class StackMKLDNNHandler
...
@@ -69,8 +69,7 @@ class StackMKLDNNHandler
extended_input_dims
[
stack_axis
]
=
1
;
extended_input_dims
[
stack_axis
]
=
1
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
srcs_md
.
emplace_back
(
memory
::
desc
(
input_dims
,
dt
,
inputs
[
i
]
->
format
())
srcs_md
.
push_back
(
inputs
[
i
]
->
mem_desc
().
reshape
(
extended_input_dims
));
.
reshape
(
extended_input_dims
));
}
}
// concat primitive choses suboptimal format tag because it cannot
// concat primitive choses suboptimal format tag because it cannot
...
@@ -130,9 +129,8 @@ class StackMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -130,9 +129,8 @@ class StackMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
concat_p
->
execute
(
astream
,
args
);
concat_p
->
execute
(
astream
,
args
);
astream
.
wait
();
astream
.
wait
();
output
->
set_layout
(
DataLayout
::
kMKLDNN
);
output
->
set_mem_desc
(
output
->
set_format
(
platform
::
GetMKLDNNFormat
(
dst_mem
->
get_desc
().
reshape
(
phi
::
vectorize
(
output
->
dims
())));
dst_mem
->
get_desc
().
reshape
(
phi
::
vectorize
(
output
->
dims
()))));
}
}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
浏览文件 @
657abd51
...
@@ -60,17 +60,16 @@ class SumMKLDNNHandler
...
@@ -60,17 +60,16 @@ class SumMKLDNNHandler
auto
src_tz
=
dst_tz
;
auto
src_tz
=
dst_tz
;
std
::
vector
<
dnnl
::
memory
::
desc
>
srcs_md
;
std
::
vector
<
dnnl
::
memory
::
desc
>
srcs_md
;
srcs_md
.
reserve
(
in_vars
.
size
());
for
(
size_t
i
=
0
;
i
<
in_vars
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
in_vars
.
size
();
i
++
)
{
auto
&
input_it
=
in_vars
[
i
]
->
Get
<
framework
::
LoDTensor
>
();
auto
&
input_it
=
in_vars
[
i
]
->
Get
<
framework
::
LoDTensor
>
();
if
(
input_it
.
numel
()
==
0
)
{
if
(
input_it
.
numel
()
==
0
)
{
continue
;
continue
;
}
}
MKLDNNMemoryFormat
input_format
=
input_it
.
format
();
srcs_md
.
push_back
(
input_it
.
mem_desc
());
srcs_md
.
push_back
(
dnnl
::
memory
::
desc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
input_format
));
++
num_inputs_
;
++
num_inputs_
;
}
}
std
::
vector
<
float
>
scales
(
num_inputs_
,
1.0
);
std
::
vector
<
float
>
scales
(
num_inputs_
,
1.0
f
);
auto
dst_md
=
dnnl
::
memory
::
desc
(
dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
auto
dst_md
=
dnnl
::
memory
::
desc
(
dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
MKLDNNMemoryFormat
::
any
);
MKLDNNMemoryFormat
::
any
);
...
@@ -139,47 +138,27 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -139,47 +138,27 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
++
input_index
;
++
input_index
;
}
}
std
::
shared_ptr
<
dnnl
::
memory
>
dst_mem
=
nullptr
;
std
::
unordered_map
<
int
,
dnnl
::
memory
>
args
;
std
::
shared_ptr
<
dnnl
::
memory
>
dst_mem
;
for
(
size_t
i
=
0
;
i
<
srcs_mem
.
size
();
++
i
)
{
args
.
insert
({
DNNL_ARG_MULTIPLE_SRC
+
i
,
*
(
srcs_mem
[
i
])});
}
if
(
in_place
)
{
if
(
in_place
)
{
dst_mem
=
handler
.
AcquireDstMemory
();
dst_mem
=
srcs_mem
[
0
];
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
}
else
{
}
else
{
dst_mem
=
handler
.
AcquireDstMemory
(
output
);
dst_mem
=
handler
.
AcquireDstMemory
(
output
);
}
}
args
.
insert
({
DNNL_ARG_DST
,
*
dst_mem
});
auto
sum_p
=
handler
.
AcquireForwardPrimitive
();
auto
sum_p
=
handler
.
AcquireForwardPrimitive
();
std
::
unordered_map
<
int
,
dnnl
::
memory
>
args
;
for
(
size_t
i
=
0
;
i
<
srcs_mem
.
size
();
++
i
)
{
args
.
insert
({
DNNL_ARG_MULTIPLE_SRC
+
i
,
*
(
srcs_mem
[
i
])});
}
args
.
insert
({
DNNL_ARG_DST
,
*
dst_mem
});
auto
&
astream
=
platform
::
MKLDNNDeviceContext
::
tls
().
get_stream
();
auto
&
astream
=
platform
::
MKLDNNDeviceContext
::
tls
().
get_stream
();
sum_p
->
execute
(
astream
,
args
);
sum_p
->
execute
(
astream
,
args
);
astream
.
wait
();
astream
.
wait
();
// For in-place execution which sum does not have we need to fake it
output
->
set_mem_desc
(
dst_mem
->
get_desc
());
// so from oneDNN dst memory we reorder data into input
if
(
in_place
)
{
auto
&
in_out
=
in_vars
[
0
]
->
Get
<
framework
::
LoDTensor
>
();
auto
output_tz
=
phi
::
vectorize
<
int64_t
>
(
output
->
dims
());
platform
::
ReorderMKLDNNHandler
reorder_handler
(
output_tz
,
framework
::
TransToProtoVarType
(
output
->
dtype
()),
framework
::
ToMKLDNNDataType
(
framework
::
TransToProtoVarType
(
in_out
.
dtype
())),
dev_ctx
.
GetEngine
());
auto
target_mem
=
reorder_handler
.
AcquireDstMemory
(
output
,
in_out
.
format
(),
ctx
.
GetPlace
());
auto
reorder_p
=
reorder_handler
.
AcquireReorder
(
target_mem
,
dst_mem
);
reorder_p
->
execute
(
astream
,
*
dst_mem
,
*
target_mem
);
astream
.
wait
();
}
output
->
set_layout
(
framework
::
DataLayout
::
kMKLDNN
);
output
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
dst_mem
));
}
}
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录