Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
fb4b4f8d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fb4b4f8d
编写于
8月 27, 2018
作者:
K
Krzysztof Binias
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor code
上级
50d3e6e9
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
26 addition
and
53 deletion
+26
-53
paddle/fluid/operators/batch_norm_mkldnn_op.cc
paddle/fluid/operators/batch_norm_mkldnn_op.cc
+26
-53
未找到文件。
paddle/fluid/operators/batch_norm_mkldnn_op.cc
浏览文件 @
fb4b4f8d
...
...
@@ -62,56 +62,42 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
batch_norm_pd_
->
variance_primitive_desc
(),
ptr
,
"@variance_mem_p"
);
}
std
::
shared_ptr
<
batch_norm_fwd
>
AcquireTestBatchNormFwd
(
std
::
shared_ptr
<
batch_norm_fwd
>
AcquireTest
Training
BatchNormFwd
(
std
::
shared_ptr
<
memory
>
src_memory
,
const
mkldnn
::
primitive
::
at
&
mean_memory
,
const
mkldnn
::
primitive
::
at
&
variance_memory
,
std
::
shared_ptr
<
memory
>
scaleshift_memory
,
std
::
shared_ptr
<
memory
>
dst_memory
)
{
std
::
shared_ptr
<
memory
>
dst_memory
,
std
::
shared_ptr
<
memory
>
mean_memory
,
std
::
shared_ptr
<
memory
>
variance_memory
,
bool
is_test
)
{
auto
prim_key
=
key_
+
"@batch_norm_p"
;
auto
batch_norm_p
=
std
::
static_pointer_cast
<
batch_norm_fwd
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
PADDLE_ENFORCE
(
(
batch_norm_p
!=
nullptr
)
||
(
is_reusing_
==
false
),
"Fail to find batch norm primitive for test in device context"
);
if
(
batch_norm_p
==
nullptr
)
{
batch_norm_p
=
std
::
make_shared
<
batch_norm_fwd
>
(
*
batch_norm_pd_
,
*
src_memory
,
mean_memory
,
variance_memory
,
*
scaleshift_memory
,
*
dst_memory
);
dev_ctx_
.
SetBlob
(
prim_key
,
batch_norm_p
);
}
else
{
is_reusing_
=
true
;
}
return
batch_norm_p
;
}
PADDLE_ENFORCE
((
batch_norm_p
!=
nullptr
)
||
!
is_reusing_
,
"Fail to find batch norm primitive in device context"
);
std
::
shared_ptr
<
batch_norm_fwd
>
AcquireTrainingBatchNormFwd
(
std
::
shared_ptr
<
memory
>
src_memory
,
std
::
shared_ptr
<
memory
>
scaleshift_memory
,
std
::
shared_ptr
<
memory
>
dst_memory
,
std
::
shared_ptr
<
memory
>
mean_memory
,
std
::
shared_ptr
<
memory
>
variance_memory
)
{
auto
prim_key
=
key_
+
"@batch_norm_p"
;
auto
batch_norm_p
=
std
::
static_pointer_cast
<
batch_norm_fwd
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
PADDLE_ENFORCE
(
(
batch_norm_p
!=
nullptr
)
||
(
is_reusing_
==
false
),
"Fail to find batch norm primitive for training in device context"
);
if
(
batch_norm_p
==
nullptr
)
{
batch_norm_p
=
std
::
make_shared
<
batch_norm_fwd
>
(
*
batch_norm_pd_
,
*
src_memory
,
*
scaleshift_memory
,
*
dst_memory
,
*
mean_memory
,
*
variance_memory
);
if
(
is_test
)
{
batch_norm_p
=
std
::
make_shared
<
batch_norm_fwd
>
(
*
batch_norm_pd_
,
*
src_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
*
mean_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
*
variance_memory
,
*
scaleshift_memory
,
*
dst_memory
);
}
else
{
batch_norm_p
=
std
::
make_shared
<
batch_norm_fwd
>
(
*
batch_norm_pd_
,
*
src_memory
,
*
scaleshift_memory
,
*
dst_memory
,
*
mean_memory
,
*
variance_memory
);
}
dev_ctx_
.
SetBlob
(
prim_key
,
batch_norm_p
);
}
else
{
is_reusing_
=
true
;
}
return
batch_norm_p
;
}
//
static
std
::
string
GetHash
(
const
memory
::
dims
&
input_dims
,
float
epsilon
,
unsigned
flag
,
bool
is_test
,
memory
::
format
format
,
const
std
::
string
&
suffix
)
{
const
std
::
string
&
suffix
=
""
)
{
auto
dims2str
=
[](
const
memory
::
dims
&
operand_dims
)
{
std
::
string
dstr
=
""
;
for
(
size_t
i
=
0
;
i
<
operand_dims
.
size
();
++
i
)
{
...
...
@@ -128,19 +114,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
std
::
shared_ptr
<
batch_norm_fwd
::
primitive_desc
>
batch_norm_pd_
;
};
std
::
string
gethash
(
const
memory
::
dims
&
input_dims
,
float
epsilon
,
unsigned
flag
,
bool
is_test
,
memory
::
format
format
)
{
auto
dims2str
=
[](
const
memory
::
dims
&
operand_dims
)
{
std
::
string
dstr
=
""
;
for
(
size_t
i
=
0
;
i
<
operand_dims
.
size
();
++
i
)
{
dstr
+=
std
::
to_string
(
operand_dims
[
i
])
+
"-"
;
}
return
dstr
;
};
return
dims2str
(
input_dims
)
+
std
::
to_string
(
epsilon
)
+
std
::
to_string
(
flag
)
+
std
::
to_string
(
is_test
)
+
std
::
to_string
(
format
);
}
std
::
shared_ptr
<
memory
>
UpdateMemoryData
(
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
const
std
::
string
&
key
,
void
*
new_ptr
)
{
...
...
@@ -274,10 +247,9 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
handler
.
AcquireVarianceMemoryFromPrimitive
(
to_void_cast
(
variance_data
));
batch_norm_p
=
handler
.
AcquireTestBatchNormFwd
(
src_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
*
mean_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
*
variance_memory
,
scaleshift_memory
,
dst_memory
);
batch_norm_p
=
handler
.
AcquireTestTrainingBatchNormFwd
(
src_memory
,
scaleshift_memory
,
dst_memory
,
mean_memory
,
variance_memory
,
true
);
}
else
{
// create mkldnn memory for stats (as output)
std
::
shared_ptr
<
memory
>
mean_memory
=
...
...
@@ -285,9 +257,9 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std
::
shared_ptr
<
memory
>
variance_memory
=
handler
.
AcquireVarianceMemoryFromPrimitive
(
batch_variance_data
);
batch_norm_p
=
handler
.
AcquireTrainingBatchNormFwd
(
batch_norm_p
=
handler
.
AcquireT
estT
rainingBatchNormFwd
(
src_memory
,
scaleshift_memory
,
dst_memory
,
mean_memory
,
variance_memory
);
variance_memory
,
false
);
}
y
->
set_layout
(
DataLayout
::
kMKLDNN
);
...
...
@@ -377,7 +349,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// keys for primitives reuse
const
std
::
string
key_with_hash
=
key
+
gethash
(
src_tz
,
epsilon
,
flags
,
false
,
input_format
);
key
+
BatchNormMKLDNNHandler
::
GetHash
(
src_tz
,
epsilon
,
flags
,
false
,
input_format
);
const
std
::
string
key_batch_norm_bwd_p
=
key_with_hash
+
"@batch_norm_bwd_p"
;
const
std
::
string
key_batch_norm_src_mem_p
=
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录