Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
56008aa1
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
56008aa1
编写于
5月 19, 2021
作者:
J
Jacek Czaja
提交者:
GitHub
5月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[oneDNN] Pool softmax and LRN access to cache optimized (#32922)
上级
af89a943
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
175 addition
and
120 deletion
+175
-120
paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
+102
-31
paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc
paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc
+19
-5
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
+7
-5
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+45
-79
python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py
...paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py
+2
-0
未找到文件。
paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
浏览文件 @
56008aa1
...
@@ -14,21 +14,104 @@ limitations under the License. */
...
@@ -14,21 +14,104 @@ limitations under the License. */
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace
paddle
{
namespace
framework
{
class
Tensor
;
}
// namespace framework
namespace
platform
{
class
MKLDNNDeviceContext
;
}
// namespace platform
}
// namespace paddle
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
paddle
::
framework
::
Tensor
;
using
paddle
::
framework
::
Tensor
;
using
paddle
::
platform
::
MKLDNNDeviceContext
;
using
paddle
::
platform
::
MKLDNNDeviceContext
;
// oneDNN (MKL-DNN) handler for the LRN operator. Both constructors build the
// handler key from the input dims and `unique_name`, and only construct
// primitive descriptors when the corresponding cache check reports a miss.
template <typename T>
class LRNMKLDNNHandler
    : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
                                      mkldnn::lrn_backward> {
 public:
  // Forward-pass handler: acquires the forward primitive descriptor unless
  // isCachedNonBlocking() reports that one is already available.
  LRNMKLDNNHandler(const framework::ExecutionContext& ctx,
                   const MKLDNNDeviceContext& dev_ctx,
                   const mkldnn::engine mkldnn_engine,
                   platform::Place cpu_place, const Tensor* input,
                   const std::string& unique_name)
      : platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
            dev_ctx, mkldnn_engine, cpu_place,
            platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
                                unique_name)) {
    if (this->isCachedNonBlocking()) {
      return;  // forward primitive descriptor already cached
    }

    const int window_size = ctx.Attr<int>("n");
    // MKL-DNN follows the Caffe definition of LRN
    // (http://caffe.berkeleyvision.org/tutorial/layers/lrn.html), in which
    // the sum of squares is divided by the size of the normalization window.
    // PaddlePaddle's LRN does not divide, so we compensate for the difference
    // by multiplying alpha by the window size (n).
    const float scaled_alpha =
        ctx.Attr<float>("alpha") * static_cast<float>(window_size);
    const float beta_attr = ctx.Attr<float>("beta");
    const float k_attr = ctx.Attr<float>("k");
    bool inference_only = ctx.Attr<bool>("is_test");

    auto src_dims = framework::vectorize(input->dims());
    auto input_md = mkldnn::memory::desc(
        src_dims, platform::MKLDNNGetDataType<T>(), input->format());

    this->AcquireForwardPrimitiveDescriptorNonBlocking(
        inference_only ? mkldnn::prop_kind::forward_inference
                       : mkldnn::prop_kind::forward_training,
        mkldnn::algorithm::lrn_across_channels, input_md, window_size,
        scaled_alpha, beta_attr, k_attr);
  }

  // Backward-pass handler: unless isBwdCached() reports a hit, acquires a
  // forward_training forward descriptor and then the backward descriptor.
  // Requires the "is_test" attribute to be false.
  LRNMKLDNNHandler(const framework::ExecutionContext& ctx,
                   const MKLDNNDeviceContext& dev_ctx,
                   platform::Place cpu_place, const Tensor* in_x,
                   const Tensor* out_grad, Tensor* in_x_grad,
                   const std::string& unique_name)
      : platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
            dev_ctx, dev_ctx.GetEngine(), cpu_place,
            platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
                                unique_name)) {
    if (this->isBwdCached()) {
      return;  // backward primitive descriptor already cached
    }

    PADDLE_ENFORCE_EQ(
        ctx.Attr<bool>("is_test"), false,
        platform::errors::PreconditionNotMet(
            "is_test attribute should be set to False in training phase."));

    const int window_size = ctx.Attr<int>("n");
    // Same alpha compensation as in the forward constructor.
    const float scaled_alpha =
        ctx.Attr<float>("alpha") * static_cast<float>(window_size);
    const float beta_attr = ctx.Attr<float>("beta");
    const float k_attr = ctx.Attr<float>("k");

    auto src_dims = framework::vectorize<int64_t>(in_x->dims());
    auto input_md = mkldnn::memory::desc(
        src_dims, platform::MKLDNNGetDataType<T>(), in_x->format());
    auto grad_md = mkldnn::memory::desc(
        src_dims, platform::MKLDNNGetDataType<T>(), out_grad->format());

    // The backward descriptor is built on top of fwd_pd_, so the forward one
    // has to be acquired first.
    this->AcquireForwardPrimitiveDescriptorNonBlocking(
        mkldnn::prop_kind::forward_training,
        mkldnn::algorithm::lrn_across_channels, input_md, window_size,
        scaled_alpha, beta_attr, k_attr);

    this->AcquireBackwardPrimitiveDescriptorNonBlocking(
        mkldnn::algorithm::lrn_across_channels, input_md, grad_md, window_size,
        scaled_alpha, beta_attr, k_attr);
  }

  // Allocates `workspace` with the size of the forward descriptor's workspace
  // and wraps that buffer as an mkldnn::memory (suffix "@wrk_mem_p").
  std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(Tensor* workspace) {
    T* buffer = workspace->mutable_data<T>(
        this->place_, this->fwd_pd_->workspace_desc().get_size());
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->workspace_desc(),
                                            buffer, "@wrk_mem_p");
  }

  // Wraps the data of an existing workspace tensor as an mkldnn::memory
  // described by the forward descriptor's workspace (suffix "@bwd-wrk_mem_p").
  std::shared_ptr<mkldnn::memory> AcquireBackwardWorkspaceMemory(
      const Tensor* workspace) {
    const T* saved_workspace = workspace->data<T>();
    return this->AcquireMemoryFromPrimitive(
        this->fwd_pd_->workspace_desc(),
        platform::to_void_cast<T>(saved_workspace), "@bwd-wrk_mem_p");
  }
};
template
<
typename
T
>
template
<
typename
T
>
class
LRNMKLDNNOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
class
LRNMKLDNNOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -48,8 +131,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -48,8 +131,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
mid
=
ctx
.
Output
<
Tensor
>
(
"MidOut"
);
auto
mid
=
ctx
.
Output
<
Tensor
>
(
"MidOut"
);
platform
::
LRNMKLDNNHandler
<
T
>
handler
(
LRNMKLDNNHandler
<
T
>
handler
(
ctx
,
dev_ctx
,
mkldnn_engine
,
ctx
.
GetPlace
(),
x
,
ctx
,
dev_ctx
,
mkldnn_engine
,
ctx
.
GetPlace
(),
x
,
ctx
.
OutputName
(
"Out"
));
ctx
.
OutputName
(
"Out"
));
auto
src_memory
=
handler
.
AcquireSrcMemory
(
x
);
auto
src_memory
=
handler
.
AcquireSrcMemory
(
x
);
auto
dst_memory
=
handler
.
AcquireDstMemory
(
out
);
auto
dst_memory
=
handler
.
AcquireDstMemory
(
out
);
...
@@ -87,34 +170,22 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -87,34 +170,22 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
paddle
::
platform
::
errors
::
PreconditionNotMet
(
paddle
::
platform
::
errors
::
PreconditionNotMet
(
"Operator DNNL LRNGrad must use CPUPlace"
));
"Operator DNNL LRNGrad must use CPUPlace"
));
PADDLE_ENFORCE_EQ
(
ctx
.
Attr
<
bool
>
(
"is_test"
),
false
,
platform
::
errors
::
PreconditionNotMet
(
"is_test attribute should be set to False in training phase."
));
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
in_
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
mid
=
ctx
.
Input
<
Tensor
>
(
"MidOut"
);
auto
mid
=
ctx
.
Input
<
Tensor
>
(
"MidOut"
);
auto
out_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
out_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
x_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
in_x_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
const
int
n
=
ctx
.
Attr
<
int
>
(
"n"
);
const
float
alpha
=
ctx
.
Attr
<
float
>
(
"alpha"
)
*
static_cast
<
float
>
(
n
);
const
float
beta
=
ctx
.
Attr
<
float
>
(
"beta"
);
const
float
k
=
ctx
.
Attr
<
float
>
(
"k"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
auto
dims
=
paddle
::
framework
::
vectorize
<
int64_t
>
(
x
->
dims
());
LRNMKLDNNHandler
<
T
>
handler
(
ctx
,
dev_ctx
,
ctx
.
GetPlace
(),
in_x
,
out_grad
,
in_x_grad
,
ctx
.
InputName
(
"Out"
));
platform
::
LRNMKLDNNHandler
<
T
>
handler
(
dims
,
n
,
alpha
,
beta
,
k
,
x
->
format
(),
auto
src_memory
=
handler
.
AcquireSrcMemory
(
in_x
);
out_grad
->
format
(),
dev_ctx
,
ctx
.
GetPlace
(),
ctx
.
InputName
(
"Out"
));
auto
src_memory
=
handler
.
AcquireSrcMemory
(
x
);
auto
workspace
=
handler
.
AcquireBackwardWorkspaceMemory
(
mid
);
auto
workspace
=
handler
.
AcquireBackwardWorkspaceMemory
(
mid
);
auto
diff_dst_memory
=
handler
.
AcquireDiffDstMemory
(
out_grad
);
auto
diff_dst_memory
=
handler
.
AcquireDiffDstMemory
(
out_grad
);
auto
diff_src_memory
=
handler
.
AcquireDiffSrcMemory
(
x_grad
);
auto
diff_src_memory
=
handler
.
AcquireDiffSrcMemory
(
in_
x_grad
);
auto
lrn_bwd
=
handler
.
AcquireBackwardPrimitive
();
auto
lrn_bwd
=
handler
.
AcquireBackwardPrimitive
();
...
@@ -125,8 +196,8 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -125,8 +196,8 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
{
MKLDNN_ARG_WORKSPACE
,
*
workspace
}});
{
MKLDNN_ARG_WORKSPACE
,
*
workspace
}});
astream
.
wait
();
astream
.
wait
();
x_grad
->
set_layout
(
framework
::
DataLayout
::
kMKLDNN
);
in_
x_grad
->
set_layout
(
framework
::
DataLayout
::
kMKLDNN
);
x_grad
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
diff_src_memory
));
in_
x_grad
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
diff_src_memory
));
}
}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc
浏览文件 @
56008aa1
...
@@ -43,7 +43,7 @@ class PoolingMKLDNNHandler
...
@@ -43,7 +43,7 @@ class PoolingMKLDNNHandler
platform
::
CreateKey
(
dev_ctx
,
framework
::
vectorize
(
input
->
dims
()),
platform
::
CreateKey
(
dev_ctx
,
framework
::
vectorize
(
input
->
dims
()),
framework
::
ToMKLDNNDataType
(
input
->
type
()),
framework
::
ToMKLDNNDataType
(
input
->
type
()),
unique_name
))
{
unique_name
))
{
if
(
!
this
->
isCached
())
{
if
(
!
this
->
isCached
NonBlocking
())
{
PADDLE_ENFORCE_EQ
(
input
->
layout
(),
DataLayout
::
kMKLDNN
,
PADDLE_ENFORCE_EQ
(
input
->
layout
(),
DataLayout
::
kMKLDNN
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Wrong layout set for Input tensor."
));
"Wrong layout set for Input tensor."
));
...
@@ -100,11 +100,10 @@ class PoolingMKLDNNHandler
...
@@ -100,11 +100,10 @@ class PoolingMKLDNNHandler
const
auto
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
const
auto
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
const
auto
dt
=
framework
::
ToMKLDNNDataType
(
input
->
type
());
const
auto
dt
=
framework
::
ToMKLDNNDataType
(
input
->
type
());
const
auto
fmt
=
input
->
format
();
const
auto
exclude_padding
=
ctx
.
Attr
<
bool
>
(
"exclusive"
);
const
auto
exclude_padding
=
ctx
.
Attr
<
bool
>
(
"exclusive"
);
const
auto
src_md
=
mkldnn
::
memory
::
desc
(
src_tz
,
dt
,
fmt
);
const
auto
src_md
=
mkldnn
::
memory
::
desc
(
src_tz
,
dt
,
input
->
format
()
);
/* create memory descriptor for pooling without specified format
/* create memory descriptor for pooling without specified format
* ('any') which lets a primitive (pooling in this case) choose
* ('any') which lets a primitive (pooling in this case) choose
* the memory format preferred for best performance
* the memory format preferred for best performance
...
@@ -124,7 +123,7 @@ class PoolingMKLDNNHandler
...
@@ -124,7 +123,7 @@ class PoolingMKLDNNHandler
ComputeAdaptivePoolParameters
(
ctx
,
src_tz
,
&
ksize
,
&
strides
);
ComputeAdaptivePoolParameters
(
ctx
,
src_tz
,
&
ksize
,
&
strides
);
this
->
AcquireForwardPrimitiveDescriptor
(
this
->
AcquireForwardPrimitiveDescriptor
NonBlocking
(
is_test
?
mkldnn
::
prop_kind
::
forward_inference
is_test
?
mkldnn
::
prop_kind
::
forward_inference
:
mkldnn
::
prop_kind
::
forward_training
,
:
mkldnn
::
prop_kind
::
forward_training
,
pooling_type
==
"max"
pooling_type
==
"max"
...
@@ -200,6 +199,10 @@ class PoolingMKLDNNHandler
...
@@ -200,6 +199,10 @@ class PoolingMKLDNNHandler
auto
diff_dst_tz
=
auto
diff_dst_tz
=
paddle
::
framework
::
vectorize
<
int64_t
>
(
out_grad
->
dims
());
paddle
::
framework
::
vectorize
<
int64_t
>
(
out_grad
->
dims
());
const
auto
dt
=
framework
::
ToMKLDNNDataType
(
in_x
->
type
());
auto
src_md
=
mkldnn
::
memory
::
desc
(
src_tz
,
dt
,
in_x
->
format
());
auto
dst_md
=
mkldnn
::
memory
::
desc
(
diff_dst_tz
,
dt
,
MKLDNNMemoryFormat
::
any
);
auto
diff_dst_md
=
mkldnn
::
memory
::
desc
(
auto
diff_dst_md
=
mkldnn
::
memory
::
desc
(
diff_dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
out_grad
->
format
());
diff_dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
out_grad
->
format
());
auto
diff_src_md
=
auto
diff_src_md
=
...
@@ -216,7 +219,18 @@ class PoolingMKLDNNHandler
...
@@ -216,7 +219,18 @@ class PoolingMKLDNNHandler
ComputeAdaptivePoolParameters
(
ctx
,
diff_src_tz
,
&
ksize
,
&
strides
);
ComputeAdaptivePoolParameters
(
ctx
,
diff_src_tz
,
&
ksize
,
&
strides
);
const
auto
exclude_padding
=
ctx
.
Attr
<
bool
>
(
"exclusive"
);
const
auto
exclude_padding
=
ctx
.
Attr
<
bool
>
(
"exclusive"
);
this
->
AcquireBackwardPrimitiveDescriptor
(
this
->
AcquireForwardPrimitiveDescriptorNonBlocking
(
mkldnn
::
prop_kind
::
forward_training
,
pooling_type
==
"max"
?
mkldnn
::
algorithm
::
pooling_max
:
(
exclude_padding
?
mkldnn
::
algorithm
::
pooling_avg_exclude_padding
:
mkldnn
::
algorithm
::
pooling_avg_include_padding
),
src_md
,
dst_md
,
strides
,
ksize
,
mkldnn_paddings
[
0
],
mkldnn_paddings
[
1
]);
this
->
AcquireBackwardPrimitiveDescriptorNonBlocking
(
pooling_type
==
"max"
pooling_type
==
"max"
?
mkldnn
::
algorithm
::
pooling_max
?
mkldnn
::
algorithm
::
pooling_max
:
(
exclude_padding
:
(
exclude_padding
...
...
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
浏览文件 @
56008aa1
...
@@ -50,7 +50,7 @@ class SoftmaxMKLDNNHandler
...
@@ -50,7 +50,7 @@ class SoftmaxMKLDNNHandler
:
platform
::
CreateKey
(
:
platform
::
CreateKey
(
dev_ctx
,
framework
::
vectorize
(
input
->
dims
()),
dev_ctx
,
framework
::
vectorize
(
input
->
dims
()),
uniq_name
))
{
uniq_name
))
{
if
(
!
this
->
isCached
())
{
if
(
!
this
->
isCached
NonBlocking
())
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
input
->
dims
(),
output
->
dims
(),
input
->
dims
(),
output
->
dims
(),
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
...
@@ -60,8 +60,8 @@ class SoftmaxMKLDNNHandler
...
@@ -60,8 +60,8 @@ class SoftmaxMKLDNNHandler
auto
md
=
memory
::
desc
(
softmax_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
auto
md
=
memory
::
desc
(
softmax_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
input
->
format
());
input
->
format
());
this
->
AcquireForwardPrimitiveDescriptor
(
prop_kind
::
forward_scoring
,
md
,
this
->
AcquireForwardPrimitiveDescriptor
NonBlocking
(
axis
);
prop_kind
::
forward_scoring
,
md
,
axis
);
}
}
}
}
...
@@ -90,8 +90,10 @@ class SoftmaxMKLDNNHandler
...
@@ -90,8 +90,10 @@ class SoftmaxMKLDNNHandler
auto
diff_softmax_md
=
MKLDNNMemDesc
(
auto
diff_softmax_md
=
MKLDNNMemDesc
(
softmax_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
out_grad
->
format
());
softmax_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
out_grad
->
format
());
this
->
AcquireBackwardPrimitiveDescriptor
(
diff_softmax_md
,
data_softmax_md
,
this
->
AcquireForwardPrimitiveDescriptorNonBlocking
(
axis
);
prop_kind
::
forward_scoring
,
data_softmax_md
,
axis
);
this
->
AcquireBackwardPrimitiveDescriptorNonBlocking
(
diff_softmax_md
,
data_softmax_md
,
axis
);
}
}
}
}
};
};
...
...
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
56008aa1
...
@@ -126,13 +126,20 @@ class MKLDNNHandlerT {
...
@@ -126,13 +126,20 @@ class MKLDNNHandlerT {
return
(
dev_ctx_
.
GetBlob
(
key_p
)
!=
nullptr
);
return
(
dev_ctx_
.
GetBlob
(
key_p
)
!=
nullptr
);
}
}
bool
isCachedNonBlocking
()
{
const
std
::
string
key_pd
=
key_
+
"@fwd_pd"
;
fwd_pd_
=
std
::
static_pointer_cast
<
typename
TForward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_pd
));
return
(
fwd_pd_
!=
nullptr
);
}
bool
isBwdCached
()
{
bool
isBwdCached
()
{
const
std
::
string
key_pd
=
key_
common_
+
"@bwd_pd"
;
const
std
::
string
key_pd
=
key_
+
"@bwd_pd"
;
bwd_pd_
=
std
::
static_pointer_cast
<
typename
TBackward
::
primitive_desc
>
(
bwd_pd_
=
std
::
static_pointer_cast
<
typename
TBackward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_pd
));
dev_ctx_
.
GetBlob
(
key_pd
));
const
std
::
string
key_p
=
key_
+
"@bwd_p"
;
return
(
bwd_pd_
!=
nullptr
);
return
(
dev_ctx_
.
GetBlob
(
key_p
)
!=
nullptr
);
}
}
// If your primitive descriptor requires attributes, pass them as a
// If your primitive descriptor requires attributes, pass them as a
...
@@ -161,6 +168,20 @@ class MKLDNNHandlerT {
...
@@ -161,6 +168,20 @@ class MKLDNNHandlerT {
}
}
}
}
template
<
typename
Arg
,
typename
...
Args
>
void
AcquireForwardPrimitiveDescriptorNonBlocking
(
Arg
&&
first_arg
,
Args
&&
...
args
)
{
// This is used when we can recreate FWD PD in BWD so
// we do not need to pass FWD to BWD
const
std
::
string
key_pd
=
key_
+
"@fwd_pd"
;
fwd_pd_
=
std
::
static_pointer_cast
<
typename
TForward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_pd
));
if
(
fwd_pd_
==
nullptr
)
{
CreateForwardPrimitiveDescriptor
(
first_arg
,
std
::
forward
<
Args
>
(
args
)...);
dev_ctx_
.
SetBlob
(
key_pd
,
fwd_pd_
);
}
}
// Using sfinae to specialise variadic function. Workaround for not having
// Using sfinae to specialise variadic function. Workaround for not having
// if constexpr in C++ 11.
// if constexpr in C++ 11.
template
<
class
First
,
class
...
Args
>
template
<
class
First
,
class
...
Args
>
...
@@ -182,6 +203,8 @@ class MKLDNNHandlerT {
...
@@ -182,6 +203,8 @@ class MKLDNNHandlerT {
std
::
make_shared
<
typename
TForward
::
primitive_desc
>
(
fwd_desc
,
engine_
);
std
::
make_shared
<
typename
TForward
::
primitive_desc
>
(
fwd_desc
,
engine_
);
}
}
// TODO(jczaja): After/if all ops can use the xxxNonBlocking version,
// remove this one
template
<
typename
...
Args
>
template
<
typename
...
Args
>
void
AcquireBackwardPrimitiveDescriptor
(
Args
&&
...
args
)
{
void
AcquireBackwardPrimitiveDescriptor
(
Args
&&
...
args
)
{
const
std
::
string
key_fwd_pd
=
key_common_
+
"@fwd_pd"
;
const
std
::
string
key_fwd_pd
=
key_common_
+
"@fwd_pd"
;
...
@@ -201,6 +224,25 @@ class MKLDNNHandlerT {
...
@@ -201,6 +224,25 @@ class MKLDNNHandlerT {
}
}
}
}
template
<
typename
...
Args
>
void
AcquireBackwardPrimitiveDescriptorNonBlocking
(
Args
&&
...
args
)
{
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptorNonBlocking
PADDLE_ENFORCE_NOT_NULL
(
fwd_pd_
,
platform
::
errors
::
Unavailable
(
"Get MKLDNN Forward primitive %s failed."
,
key_
+
"@fwd_pd"
));
const
std
::
string
key_pd
=
key_
+
"@bwd_pd"
;
bwd_pd_
=
std
::
static_pointer_cast
<
typename
TBackward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_pd
));
if
(
bwd_pd_
==
nullptr
)
{
auto
bwd_desc
=
typename
TBackward
::
desc
(
std
::
forward
<
Args
>
(
args
)...);
bwd_pd_
=
std
::
make_shared
<
typename
TBackward
::
primitive_desc
>
(
bwd_desc
,
engine_
,
*
fwd_pd_
);
dev_ctx_
.
SetBlob
(
key_pd
,
bwd_pd_
);
}
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMemoryFromPrimitive
(
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMemoryFromPrimitive
(
const
std
::
string
&
suffix
)
{
const
std
::
string
&
suffix
)
{
return
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
return
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
...
@@ -781,82 +823,6 @@ class ActivationMKLDNNHandler
...
@@ -781,82 +823,6 @@ class ActivationMKLDNNHandler
}
}
};
};
// oneDNN (MKL-DNN) handler for the LRN operator. The forward constructor
// reads the LRN attributes from the execution context; the backward
// constructor receives them as explicit arguments.
template <typename T>
class LRNMKLDNNHandler
    : public MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward> {
 public:
  // Forward-pass handler: acquires the forward primitive descriptor unless
  // isCached() reports a hit.
  LRNMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
                   const platform::MKLDNNDeviceContext& dev_ctx,
                   const mkldnn::engine mkldnn_engine,
                   platform::Place cpu_place, const Tensor* input,
                   const std::string& unique_name)
      : platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
            dev_ctx, mkldnn_engine, cpu_place,
            platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
                                unique_name)) {
    if (this->isCached()) {
      return;  // nothing to build
    }

    const int window_size = ctx.Attr<int>("n");
    // MKL-DNN follows the Caffe definition of LRN
    // (http://caffe.berkeleyvision.org/tutorial/layers/lrn.html), in which
    // the sum of squares is divided by the size of the normalization window.
    // PaddlePaddle's LRN does not divide, so we compensate for the difference
    // by multiplying alpha by the window size (n).
    const float scaled_alpha =
        ctx.Attr<float>("alpha") * static_cast<float>(window_size);
    const float beta_attr = ctx.Attr<float>("beta");
    const float k_attr = ctx.Attr<float>("k");
    bool inference_only = ctx.Attr<bool>("is_test");

    auto src_dims = paddle::framework::vectorize(input->dims());
    auto input_md = mkldnn::memory::desc(
        src_dims, platform::MKLDNNGetDataType<T>(), input->format());

    this->AcquireForwardPrimitiveDescriptor(
        inference_only ? mkldnn::prop_kind::forward_inference
                       : mkldnn::prop_kind::forward_training,
        mkldnn::algorithm::lrn_across_channels, input_md, window_size,
        scaled_alpha, beta_attr, k_attr);
  }

  // Backward-pass handler: always calls AcquireBackwardPrimitiveDescriptor
  // with source/diff memory descriptors built from `dims` and the two formats.
  LRNMKLDNNHandler(const std::vector<int64_t>& dims, const int n,
                   const float alpha, const float beta, const float k,
                   const MKLDNNMemoryFormat fmt,
                   const MKLDNNMemoryFormat diff_fmt,
                   const platform::MKLDNNDeviceContext& dev_ctx,
                   platform::Place cpu_place, const std::string& unique_name)
      : platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
            dev_ctx, dev_ctx.GetEngine(), cpu_place,
            platform::CreateKey(dev_ctx, dims, unique_name)) {
    auto input_md =
        mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
    auto grad_md =
        mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);

    this->AcquireBackwardPrimitiveDescriptor(
        mkldnn::algorithm::lrn_across_channels, input_md, grad_md, n, alpha,
        beta, k);
  }

  // Allocates `workspace` with the size of the forward descriptor's workspace
  // and wraps that buffer as an mkldnn::memory (suffix "@wrk_mem_p").
  std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(
      framework::Tensor* workspace) {
    T* buffer = workspace->mutable_data<T>(
        this->place_, this->fwd_pd_->workspace_desc().get_size());
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->workspace_desc(),
                                            buffer, "@wrk_mem_p");
  }

  // Wraps the data of an existing workspace tensor as an mkldnn::memory
  // described by the forward descriptor's workspace (suffix "@bwd-wrk_mem_p").
  std::shared_ptr<mkldnn::memory> AcquireBackwardWorkspaceMemory(
      const framework::Tensor* workspace) {
    const T* saved_workspace = workspace->data<T>();
    return this->AcquireMemoryFromPrimitive(
        this->fwd_pd_->workspace_desc(), to_void_cast<T>(saved_workspace),
        "@bwd-wrk_mem_p");
  }
};
template
<
typename
T
>
template
<
typename
T
>
class
TransposeMKLDNNHandler
:
public
MKLDNNHandler
{
class
TransposeMKLDNNHandler
:
public
MKLDNNHandler
{
public:
public:
...
...
python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py
浏览文件 @
56008aa1
...
@@ -63,4 +63,6 @@ class TestLRNMKLDNNOpNHWC(TestLRNMKLDNNOp):
...
@@ -63,4 +63,6 @@ class TestLRNMKLDNNOpNHWC(TestLRNMKLDNNOp):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
from
paddle
import
enable_static
enable_static
()
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录