Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
98270c18
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
98270c18
编写于
8月 05, 2021
作者:
J
Jacek Czaja
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
- modified UT
上级
2b24a801
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
220 addition
and
241 deletion
+220
-241
paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
...luid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
+3
-1
paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc
...operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc
+4
-4
paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
+6
-5
paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc
paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc
+4
-2
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
+27
-23
paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc
paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc
+20
-46
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+156
-160
未找到文件。
paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
浏览文件 @
98270c18
...
...
@@ -47,7 +47,9 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
float
scale_o
=
ctx
.
Attr
<
float
>
(
"Scale_out"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
platform
::
BinaryMKLDNNHandler
<
T
>
handler
(
BINARY_OP
,
axis
,
mkldnn_engine
,
ctx
.
GetPlace
(),
x
,
y
,
z
,
scale_x
,
scale_y
,
scale_o
);
platform
::
BinaryMKLDNNHandler
<
T
>
handler
(
BINARY_OP
,
axis
,
mkldnn_engine
,
ctx
.
GetPlace
(),
x
,
y
,
z
,
scale_x
,
scale_y
,
scale_o
);
const
auto
src_x_memory
=
handler
.
AcquireSrcMemory
(
x
);
const
auto
src_y_memory
=
handler
.
AcquireSecondSrcMemory
(
y
);
...
...
paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc
浏览文件 @
98270c18
...
...
@@ -48,8 +48,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
if
(
dx
)
{
// dx = dout*y
platform
::
BinaryMKLDNNHandler
<
T
>
handler
(
dnnl
::
algorithm
::
binary_mul
,
axis
,
mkldnn_engine
,
ctx
.
GetPlace
(),
dout
,
y
,
dx
,
1.0
f
,
1.0
f
,
1.0
f
);
dnnl
::
algorithm
::
binary_mul
,
axis
,
mkldnn_engine
,
ctx
.
GetPlace
(),
dout
,
y
,
dx
,
1.0
f
,
1.0
f
,
1.0
f
);
const
auto
src_dout_memory
=
handler
.
AcquireSrcMemory
(
dout
);
const
auto
src_y_memory
=
handler
.
AcquireSecondSrcMemory
(
y
);
...
...
@@ -74,8 +74,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
// Handler is having nullptr passed instead of output tensor as
// we want Dst buffer to be allocated by oneDNN not to use Tensor
platform
::
BinaryMKLDNNHandler
<
T
>
handler
(
dnnl
::
algorithm
::
binary_mul
,
axis
,
mkldnn_engine
,
ctx
.
GetPlace
(),
dout
,
x
,
nullptr
,
1.0
f
,
1.0
f
,
1.0
f
);
dnnl
::
algorithm
::
binary_mul
,
axis
,
mkldnn_engine
,
ctx
.
GetPlace
(),
dout
,
x
,
nullptr
,
1.0
f
,
1.0
f
,
1.0
f
);
const
auto
src_dout_memory
=
handler
.
AcquireSrcMemory
(
dout
);
const
auto
src_x_memory
=
handler
.
AcquireSecondSrcMemory
(
x
);
...
...
paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
浏览文件 @
98270c18
...
...
@@ -79,14 +79,15 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
paddle
::
platform
::
errors
::
PreconditionNotMet
(
"Operator DNNL eletwise_forward must use CPUPlace"
));
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
const
auto
&
mkldnn_engine
=
dev_ctx
.
GetEngine
();
const
auto
&
mkldnn_engine
=
dev_ctx
.
GetEngine
();
const
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
bool
is_inplaced
=
x
->
IsSharedBufferWith
(
*
y
);
platform
::
ActivationMKLDNNHandler
<
T
>
handler
(
algorithm
,
ctx
,
mkldnn_engine
,
ctx
.
GetPlace
(),
x
);
platform
::
ActivationMKLDNNHandler
<
T
>
handler
(
algorithm
,
ctx
,
mkldnn_engine
,
ctx
.
GetPlace
(),
x
);
auto
src_memory_p
=
handler
.
AcquireSrcMemory
(
x
);
auto
dst_memory_p
=
is_inplaced
?
src_memory_p
:
handler
.
AcquireDstMemory
(
y
);
...
...
@@ -105,14 +106,14 @@ template <typename T>
void
eltwise_grad
(
const
framework
::
ExecutionContext
&
ctx
,
mkldnn
::
algorithm
algorithm
)
{
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
const
auto
&
mkldnn_engine
=
dev_ctx
.
GetEngine
();
const
auto
&
mkldnn_engine
=
dev_ctx
.
GetEngine
();
const
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
auto
*
diff_y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
diff_x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
platform
::
ActivationMKLDNNHandler
<
T
>
handler
(
algorithm
,
ctx
,
mkldnn_engine
,
ctx
.
GetPlace
(),
x
,
diff_y
);
platform
::
ActivationMKLDNNHandler
<
T
>
handler
(
algorithm
,
ctx
,
mkldnn_engine
,
ctx
.
GetPlace
(),
x
,
diff_y
);
auto
src_memory_p
=
handler
.
AcquireBackwardSrcMemory
(
x
);
auto
diff_dst_memory_p
=
handler
.
AcquireDiffDstMemory
(
diff_y
);
...
...
paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc
浏览文件 @
98270c18
...
...
@@ -37,10 +37,12 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> {
bool
is_inplaced
=
x
->
IsSharedBufferWith
(
*
out
);
platform
::
ActivationMKLDNNHandler
<
T
>
handler
(
mkldnn
::
algorithm
::
eltwise_linear
,
ctx
,
mkldnn_engine
,
ctx
.
GetPlace
(),
x
);
mkldnn
::
algorithm
::
eltwise_linear
,
ctx
,
mkldnn_engine
,
ctx
.
GetPlace
(),
x
);
auto
src_memory_p
=
handler
.
AcquireSrcMemory
(
x
);
auto
dst_memory_p
=
is_inplaced
?
src_memory_p
:
handler
.
AcquireDstMemory
(
out
);
auto
dst_memory_p
=
is_inplaced
?
src_memory_p
:
handler
.
AcquireDstMemory
(
out
);
auto
activation_p
=
handler
.
AcquireForwardPrimitive
();
auto
&
astream
=
paddle
::
platform
::
MKLDNNDeviceContext
::
tls
().
get_stream
();
...
...
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
浏览文件 @
98270c18
...
...
@@ -33,12 +33,13 @@ using platform::to_void_cast;
template
<
typename
T
>
class
SoftmaxMKLDNNHandler
:
public
platform
::
MKLDNNHandlerNoCachingT
<
T
,
mkldnn
::
softmax_forward
,
mkldnn
::
softmax_backward
>
{
mkldnn
::
softmax_backward
>
{
public:
SoftmaxMKLDNNHandler
(
const
mkldnn
::
engine
mkldnn_engine
,
platform
::
Place
cpu_place
,
const
Tensor
*
input
,
Tensor
*
output
,
const
int
axis
)
:
platform
::
MKLDNNHandlerNoCachingT
<
T
,
mkldnn
::
softmax_forward
,
mkldnn
::
softmax_backward
>
(
:
platform
::
MKLDNNHandlerNoCachingT
<
T
,
mkldnn
::
softmax_forward
,
mkldnn
::
softmax_backward
>
(
mkldnn_engine
,
cpu_place
)
{
PADDLE_ENFORCE_EQ
(
input
->
dims
(),
output
->
dims
(),
...
...
@@ -49,7 +50,8 @@ class SoftmaxMKLDNNHandler
auto
md
=
memory
::
desc
(
softmax_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
input
->
format
());
this
->
AcquireForwardPrimitiveDescriptor
(
prop_kind
::
forward_scoring
,
md
,
axis
);
this
->
AcquireForwardPrimitiveDescriptor
(
prop_kind
::
forward_scoring
,
md
,
axis
);
}
SoftmaxMKLDNNHandler
(
const
framework
::
ExecutionContext
&
ctx
,
...
...
@@ -58,25 +60,26 @@ class SoftmaxMKLDNNHandler
const
Tensor
*
out_grad
,
Tensor
*
in_x_grad
,
const
std
::
string
&
unique_name
)
:
platform
::
MKLDNNHandlerNoCachingT
<
T
,
mkldnn
::
softmax_forward
,
mkldnn
::
softmax_backward
>
(
mkldnn_engine
,
cpu_place
)
{
PADDLE_ENFORCE_EQ
(
out_grad
->
dims
(),
in_x_grad
->
dims
(),
platform
::
errors
::
InvalidArgument
(
"The shape of softmax_grad's input "
"and output must be identical."
));
auto
dims
=
out_grad
->
dims
();
// input and output share the same shape
const
int
axis
=
CanonicalAxis
(
ctx
.
Attr
<
int
>
(
"axis"
),
dims
.
size
());
auto
softmax_tz
=
framework
::
vectorize
<
int64_t
>
(
dims
);
auto
data_softmax_md
=
MKLDNNMemDesc
(
softmax_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
out
->
format
());
auto
diff_softmax_md
=
MKLDNNMemDesc
(
softmax_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
out_grad
->
format
());
this
->
AcquireForwardPrimitiveDescriptor
(
prop_kind
::
forward_scoring
,
data_softmax_md
,
axis
);
this
->
AcquireBackwardPrimitiveDescriptor
(
diff_softmax_md
,
data_softmax_md
,
axis
);
mkldnn
::
softmax_backward
>
(
mkldnn_engine
,
cpu_place
)
{
PADDLE_ENFORCE_EQ
(
out_grad
->
dims
(),
in_x_grad
->
dims
(),
platform
::
errors
::
InvalidArgument
(
"The shape of softmax_grad's input "
"and output must be identical."
));
auto
dims
=
out_grad
->
dims
();
// input and output share the same shape
const
int
axis
=
CanonicalAxis
(
ctx
.
Attr
<
int
>
(
"axis"
),
dims
.
size
());
auto
softmax_tz
=
framework
::
vectorize
<
int64_t
>
(
dims
);
auto
data_softmax_md
=
MKLDNNMemDesc
(
softmax_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
out
->
format
());
auto
diff_softmax_md
=
MKLDNNMemDesc
(
softmax_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
out_grad
->
format
());
this
->
AcquireForwardPrimitiveDescriptor
(
prop_kind
::
forward_scoring
,
data_softmax_md
,
axis
);
this
->
AcquireBackwardPrimitiveDescriptor
(
diff_softmax_md
,
data_softmax_md
,
axis
);
}
};
...
...
@@ -93,7 +96,8 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
const
int
axis
=
CanonicalAxis
(
ctx
.
Attr
<
int
>
(
"axis"
),
input
->
dims
().
size
());
SoftmaxMKLDNNHandler
<
T
>
handler
(
mkldnn_engine
,
ctx
.
GetPlace
(),
input
,
output
,
axis
);
SoftmaxMKLDNNHandler
<
T
>
handler
(
mkldnn_engine
,
ctx
.
GetPlace
(),
input
,
output
,
axis
);
auto
softmax_src_memory_p
=
handler
.
AcquireSrcMemory
(
input
);
// For Inplace src and and dst are the same memory object
...
...
paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc
浏览文件 @
98270c18
...
...
@@ -70,11 +70,16 @@ void RunOperator(const platform::Place &place, const std::string &op_type,
std
::
map
<
const
std
::
string
,
int
>
num_inputs
=
{{
"softmax"
,
1
},
{
"relu"
,
1
},
{
"conv2d"
,
2
},
{
"elementwise_add"
,
2
},
{
"elementwise_mul"
,
2
}};
std
::
string
first_input
=
inplace
==
true
?
output_name
:
"x"
;
std
::
string
first_input_var_name
=
(
op_type
==
"conv2d"
)
?
"Input"
:
"X"
;
std
::
string
second_input_var_name
=
(
op_type
==
"conv2d"
)
?
"Filter"
:
"Y"
;
std
::
string
output_var_name
=
(
op_type
==
"conv2d"
)
?
"Output"
:
"Out"
;
std
::
vector
<
InputVars
>
input_names
=
{
{
first_input
,
scope
.
Var
(
first_input
)
->
GetMutable
<
framework
::
LoDTensor
>
()},
{
"x1"
,
num_inputs
[
op_type
]
>
1
...
...
@@ -113,68 +118,37 @@ void RunOperator(const platform::Place &place, const std::string &op_type,
auto
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
op
=
num_inputs
[
op_type
]
>
1
?
framework
::
OpRegistry
::
CreateOp
(
op_type
,
{{
"X"
,
{
first_input
}},
{
"Y"
,
{
"x1"
}}},
{{
"Out"
,
{
output_name
}}},
{{
"use_mkldnn"
,
{
true
}}})
:
framework
::
OpRegistry
::
CreateOp
(
op_type
,
{{
"X"
,
{
first_input
}}},
{{
"Out"
,
{
output_name
}}},
{{
"use_mkldnn"
,
{
true
}}});
auto
op
=
num_inputs
[
op_type
]
>
1
?
framework
::
OpRegistry
::
CreateOp
(
op_type
,
{{
first_input_var_name
,
{
first_input
}},
{
second_input_var_name
,
{
"x1"
}}},
{{
output_var_name
,
{
output_name
}}},
{{
"use_mkldnn"
,
{
true
}}})
:
framework
::
OpRegistry
::
CreateOp
(
op_type
,
{{
first_input_var_name
,
{
first_input
}}},
{{
output_var_name
,
{
output_name
}}},
{{
"use_mkldnn"
,
{
true
}}});
op
->
Run
(
scope
,
place
);
pool
.
Get
(
place
)
->
Wait
();
}
TEST
(
test_softmax_reuse_cache
,
cpu_place
)
{
framework
::
DDim
dims
({
32
,
64
});
framework
::
DDim
dims
({
1
,
16
,
32
,
64
});
platform
::
CPUPlace
p
;
CacheTester
ct
;
RunOperator
<
float
>
(
p
,
"
softmax"
,
dims
,
"softmax
_out"
);
RunOperator
<
float
>
(
p
,
"
softmax"
,
dims
,
"softmax
_out"
);
RunOperator
<
float
>
(
p
,
"
conv2d"
,
dims
,
"conv
_out"
);
RunOperator
<
float
>
(
p
,
"
conv2d"
,
dims
,
"conv
_out"
);
PADDLE_ENFORCE_EQ
(
ct
.
Analyze
(
4
),
true
,
platform
::
errors
::
InvalidArgument
(
"Wrong number of cached oneDNN objects"
));
}
TEST
(
test_softmax_noreuse_cache
,
cpu_place
)
{
framework
::
DDim
dims
({
32
,
64
});
platform
::
CPUPlace
p
;
CacheTester
ct
;
RunOperator
<
float
>
(
p
,
"softmax"
,
dims
,
"softmax_out"
);
RunOperator
<
float
>
(
p
,
"softmax"
,
dims
,
"softmax_out2"
);
PADDLE_ENFORCE_EQ
(
ct
.
Analyze
(
8
),
true
,
platform
::
errors
::
InvalidArgument
(
"Wrong number of cached oneDNN objects"
));
}
TEST
(
test_softmax_inplace_cache
,
cpu_place
)
{
framework
::
DDim
dims
({
32
,
64
});
platform
::
CPUPlace
p
;
CacheTester
ct
;
RunOperator
<
float
>
(
p
,
"softmax"
,
dims
,
"softmax_out"
);
RunOperator
<
float
>
(
p
,
"softmax"
,
dims
,
"softmax_out"
,
true
);
PADDLE_ENFORCE_EQ
(
ct
.
Analyze
(
7
),
true
,
platform
::
errors
::
InvalidArgument
(
"Wrong number of cached oneDNN objects"
));
}
TEST
(
test_relu_inplace_cache
,
cpu_place
)
{
framework
::
DDim
dims
({
32
,
64
});
platform
::
CPUPlace
p
;
CacheTester
ct
;
RunOperator
<
float
>
(
p
,
"relu"
,
dims
,
"relu_out"
);
RunOperator
<
float
>
(
p
,
"relu"
,
dims
,
"relu_out"
,
true
);
PADDLE_ENFORCE_EQ
(
ct
.
Analyze
(
7
),
true
,
platform
::
errors
::
InvalidArgument
(
"Wrong number of cached oneDNN objects"
));
}
TEST
(
test_elementwise_add_reuse_cache
,
cpu_place
)
{
framework
::
DDim
dims
({
32
,
64
});
framework
::
DDim
dims
({
1
,
16
,
32
,
64
});
platform
::
CPUPlace
p
;
CacheTester
ct
;
RunOperator
<
float
>
(
p
,
"
elementwise_add"
,
dims
,
"elementwise_add
_out"
);
RunOperator
<
float
>
(
p
,
"
relu"
,
dims
,
"elementwise_add_out"
,
true
);
RunOperator
<
float
>
(
p
,
"
conv2d"
,
dims
,
"conv
_out"
);
RunOperator
<
float
>
(
p
,
"
conv2d"
,
dims
,
"conv_out2"
);
PADDLE_ENFORCE_EQ
(
ct
.
Analyze
(
8
),
true
,
platform
::
errors
::
InvalidArgument
(
"Wrong number of cached oneDNN objects"
));
...
...
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
98270c18
...
...
@@ -34,40 +34,36 @@ using framework::Tensor;
using
user_function
=
std
::
function
<
std
::
shared_ptr
<
float
>
(
const
float
*
)
>
;
using
memory
=
mkldnn
::
memory
;
template
<
typename
T
,
typename
TForward
,
typename
TBackward
=
mkldnn_dummy_primitive
,
typename
TBackward_params
=
mkldnn_dummy_primitive
>
class
MKLDNNHandlerNoCachingT
{
public:
MKLDNNHandlerNoCachingT
(
mkldnn
::
engine
engine
,
platform
::
Place
cpu_place
)
:
engine_
(
engine
),
place_
(
cpu_place
),
fwd_pd_
(
nullptr
),
bwd_pd_
(
nullptr
)
{
:
engine_
(
engine
),
place_
(
cpu_place
),
fwd_pd_
(
nullptr
),
bwd_pd_
(
nullptr
)
{
platform
::
MKLDNNDeviceContext
::
tls
().
log_lib_version
();
}
std
::
shared_ptr
<
TForward
>
AcquireForwardPrimitive
()
{
return
std
::
make_shared
<
TForward
>
(
*
fwd_pd_
);
return
std
::
make_shared
<
TForward
>
(
*
fwd_pd_
);
}
std
::
shared_ptr
<
TBackward
>
AcquireBackwardPrimitive
()
{
return
std
::
make_shared
<
TBackward
>
(
*
bwd_pd_
);
return
std
::
make_shared
<
TBackward
>
(
*
bwd_pd_
);
}
std
::
shared_ptr
<
TBackward_params
>
AcquireBackwardWeightsPrimitive
()
{
PADDLE_ENFORCE_NOT_NULL
(
bwd_w_pd_
,
platform
::
errors
::
Unavailable
(
"Error: BWD_PD should be set when "
"getting BWD prim ."
));
return
std
::
make_shared
<
TBackward_params
>
(
*
bwd_w_pd_
);
PADDLE_ENFORCE_NOT_NULL
(
bwd_w_pd_
,
platform
::
errors
::
Unavailable
(
"Error: BWD_PD should be set when "
"getting BWD prim ."
));
return
std
::
make_shared
<
TBackward_params
>
(
*
bwd_w_pd_
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireSrcMemory
(
const
framework
::
Tensor
*
input
)
{
const
T
*
input_data
=
input
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
fwd_pd_
->
src_desc
(),
to_void_cast
<
T
>
(
input_data
));
return
this
->
AcquireMemoryFromPrimitive
(
fwd_pd_
->
src_desc
(),
to_void_cast
<
T
>
(
input_data
));
}
template
<
typename
T_out
=
T
>
...
...
@@ -93,8 +89,8 @@ class MKLDNNHandlerNoCachingT {
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffDstMemory
(
const
framework
::
Tensor
*
diffdst
)
{
const
T
*
ptr
=
diffdst
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
bwd_pd_
->
diff_dst_desc
(),
to_void_cast
<
T
>
(
ptr
));
return
this
->
AcquireMemoryFromPrimitive
(
bwd_pd_
->
diff_dst_desc
(),
to_void_cast
<
T
>
(
ptr
));
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffSrcMemory
(
...
...
@@ -113,7 +109,8 @@ class MKLDNNHandlerNoCachingT {
"Error: BWD_W_PD should be set when getting BWD grad of weights."
));
T
*
ptr
=
diff_weights
->
mutable_data
<
T
>
(
place_
,
bwd_w_pd_
->
diff_weights_desc
().
get_size
());
return
this
->
AcquireMemoryFromPrimitive
(
bwd_w_pd_
->
diff_weights_desc
(),
ptr
);
return
this
->
AcquireMemoryFromPrimitive
(
bwd_w_pd_
->
diff_weights_desc
(),
ptr
);
}
// Buffer is allocated by oneDNN to store computation results
...
...
@@ -126,14 +123,13 @@ class MKLDNNHandlerNoCachingT {
}
protected:
// If your primitive descriptor requires attributes, pass them as a
// first argument and paramters to descriptor constructor in the following
// arguments. Otherwise, all arguments will be forwarded to descriptor
// constructor, including the first one.
template
<
typename
Arg
,
typename
...
Args
>
void
AcquireForwardPrimitiveDescriptor
(
Arg
&&
first_arg
,
Args
&&
...
args
)
{
CreateForwardPrimitiveDescriptor
(
first_arg
,
std
::
forward
<
Args
>
(
args
)...);
CreateForwardPrimitiveDescriptor
(
first_arg
,
std
::
forward
<
Args
>
(
args
)...);
}
// Using sfinae to specialise variadic function. Workaround for not having
...
...
@@ -161,9 +157,9 @@ class MKLDNNHandlerNoCachingT {
void
AcquireBackwardPrimitiveDescriptor
(
Args
&&
...
args
)
{
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL
(
fwd_pd_
,
platform
::
errors
::
Unavailable
(
"Get MKLDNN Forward primitive %s failed."
));
PADDLE_ENFORCE_NOT_NULL
(
fwd_pd_
,
platform
::
errors
::
Unavailable
(
"Get MKLDNN Forward primitive %s failed."
));
auto
bwd_desc
=
typename
TBackward
::
desc
(
std
::
forward
<
Args
>
(
args
)...);
bwd_pd_
=
std
::
make_shared
<
typename
TBackward
::
primitive_desc
>
(
bwd_desc
,
engine_
,
*
fwd_pd_
);
...
...
@@ -173,29 +169,29 @@ class MKLDNNHandlerNoCachingT {
void
AcquireBackwardWeightsPrimitiveDescriptor
(
Args
&&
...
args
)
{
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL
(
fwd_pd_
,
platform
::
errors
::
Unavailable
(
"Get MKLDNN Forward primitive %s failed."
));
auto
bwd_desc
=
typename
TBackward_params
::
desc
(
std
::
forward
<
Args
>
(
args
)...);
bwd_w_pd_
=
std
::
make_shared
<
typename
TBackward_params
::
primitive_desc
>
(
bwd_desc
,
engine_
,
*
fwd_pd_
);
PADDLE_ENFORCE_NOT_NULL
(
fwd_pd_
,
platform
::
errors
::
Unavailable
(
"Get MKLDNN Forward primitive %s failed."
));
auto
bwd_desc
=
typename
TBackward_params
::
desc
(
std
::
forward
<
Args
>
(
args
)...);
bwd_w_pd_
=
std
::
make_shared
<
typename
TBackward_params
::
primitive_desc
>
(
bwd_desc
,
engine_
,
*
fwd_pd_
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMemoryFromPrimitive
(
mkldnn
::
memory
::
desc
md
,
void
*
ptr
)
{
return
std
::
make_shared
<
mkldnn
::
memory
>
(
md
,
engine_
,
ptr
);
return
std
::
make_shared
<
mkldnn
::
memory
>
(
md
,
engine_
,
ptr
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMemoryFromPrimitive
(
mkldnn
::
memory
::
desc
md
)
{
return
std
::
make_shared
<
mkldnn
::
memory
>
(
md
,
engine_
);
return
std
::
make_shared
<
mkldnn
::
memory
>
(
md
,
engine_
);
}
void
AcquireReorder
(
const
std
::
shared_ptr
<
mkldnn
::
memory
>&
user_memory_p
,
const
std
::
shared_ptr
<
mkldnn
::
memory
>&
target_memory_p
)
{
auto
reorder_p
=
std
::
make_shared
<
mkldnn
::
reorder
>
(
*
user_memory_p
,
*
target_memory_p
);
auto
reorder_p
=
std
::
make_shared
<
mkldnn
::
reorder
>
(
*
user_memory_p
,
*
target_memory_p
);
auto
&
astream
=
platform
::
MKLDNNDeviceContext
::
tls
().
get_stream
();
...
...
@@ -212,33 +208,30 @@ class MKLDNNHandlerNoCachingT {
const
mkldnn
::
memory
::
desc
&
target_md
,
void
*
ptr
,
const
std
::
string
&
suffix
,
bool
is_persistent
=
false
,
std
::
function
<
std
::
shared_ptr
<
F
>
(
const
F
*
)
>
custom_reorder_func
=
{})
{
std
::
shared_ptr
<
mkldnn
::
memory
>
target_memory_p
;
if
(
custom_reorder_func
)
{
auto
reordered_data
=
custom_reorder_func
(
reinterpret_cast
<
const
F
*>
(
ptr
));
ptr
=
reinterpret_cast
<
void
*>
(
reordered_data
.
get
());
}
auto
user_memory_p
=
std
::
make_shared
<
dnnl
::
memory
>
(
user_md
,
engine_
,
ptr
);
if
(
user_md
!=
target_md
)
{
target_memory_p
=
std
::
make_shared
<
mkldnn
::
memory
>
(
target_md
,
engine_
);
auto
reorder_p
=
std
::
make_shared
<
dnnl
::
reorder
>
(
*
user_memory_p
,
*
target_memory_p
);
std
::
shared_ptr
<
mkldnn
::
memory
>
target_memory_p
;
if
(
custom_reorder_func
)
{
auto
reordered_data
=
custom_reorder_func
(
reinterpret_cast
<
const
F
*>
(
ptr
));
ptr
=
reinterpret_cast
<
void
*>
(
reordered_data
.
get
());
}
auto
user_memory_p
=
std
::
make_shared
<
dnnl
::
memory
>
(
user_md
,
engine_
,
ptr
);
if
(
user_md
!=
target_md
)
{
target_memory_p
=
std
::
make_shared
<
mkldnn
::
memory
>
(
target_md
,
engine_
);
auto
reorder_p
=
std
::
make_shared
<
dnnl
::
reorder
>
(
*
user_memory_p
,
*
target_memory_p
);
auto
&
astream
=
platform
::
MKLDNNDeviceContext
::
tls
().
get_stream
();
platform
::
RecordEvent
record_reorder
(
"int_reorder"
,
platform
::
EventRole
::
kUniqueOp
);
reorder_p
->
execute
(
astream
,
{{
MKLDNN_ARG_FROM
,
*
user_memory_p
},
{
MKLDNN_ARG_TO
,
*
target_memory_p
}});
astream
.
wait
();
}
else
{
target_memory_p
=
user_memory_p
;
}
auto
&
astream
=
platform
::
MKLDNNDeviceContext
::
tls
().
get_stream
();
platform
::
RecordEvent
record_reorder
(
"int_reorder"
,
platform
::
EventRole
::
kUniqueOp
);
reorder_p
->
execute
(
astream
,
{{
MKLDNN_ARG_FROM
,
*
user_memory_p
},
{
MKLDNN_ARG_TO
,
*
target_memory_p
}});
astream
.
wait
();
}
else
{
target_memory_p
=
user_memory_p
;
}
return
target_memory_p
;
}
mkldnn
::
engine
engine_
;
platform
::
Place
place_
;
std
::
shared_ptr
<
typename
TForward
::
primitive_desc
>
fwd_pd_
;
...
...
@@ -801,63 +794,64 @@ class MKLDNNHandler {
};
template
<
typename
T
>
class
BinaryMKLDNNHandler
:
public
platform
::
MKLDNNHandlerNoCachingT
<
T
,
dnnl
::
binary
>
{
class
BinaryMKLDNNHandler
:
public
platform
::
MKLDNNHandlerNoCachingT
<
T
,
dnnl
::
binary
>
{
public:
BinaryMKLDNNHandler
(
const
dnnl
::
algorithm
algo
,
const
int
axis
,
const
mkldnn
::
engine
engine
,
platform
::
Place
cpu_place
,
const
Tensor
*
x
,
const
Tensor
*
y
,
Tensor
*
z
,
float
scale_x
,
float
scale_y
,
float
scale_z
)
:
platform
::
MKLDNNHandlerNoCachingT
<
T
,
dnnl
::
binary
>
(
engine
,
cpu_place
)
{
PADDLE_ENFORCE_EQ
(
x
->
layout
(),
DataLayout
::
kMKLDNN
,
platform
::
errors
::
InvalidArgument
(
"Wrong layout set for X tensor."
));
PADDLE_ENFORCE_NE
(
x
->
format
(),
MKLDNNMemoryFormat
::
undef
,
platform
::
errors
::
InvalidArgument
(
"Wrong format set for X tensor."
));
PADDLE_ENFORCE_EQ
(
y
->
layout
(),
DataLayout
::
kMKLDNN
,
platform
::
errors
::
InvalidArgument
(
"Wrong layout set for Y tensor."
));
PADDLE_ENFORCE_NE
(
y
->
format
(),
MKLDNNMemoryFormat
::
undef
,
platform
::
errors
::
InvalidArgument
(
"Wrong format set for Y tensor."
));
const
auto
src_x_tz
=
framework
::
vectorize
(
x
->
dims
());
const
auto
src_y_tz
=
framework
::
vectorize
(
y
->
dims
());
// if output tensor(z) is nullptr then we are computing into oneDNN
// managed buffer
auto
rankdiff
=
x
->
dims
().
size
()
-
y
->
dims
().
size
();
const
auto
dst_tz
=
(
z
==
nullptr
)
?
(
rankdiff
>
0
?
src_x_tz
:
src_y_tz
)
:
framework
::
vectorize
(
z
->
dims
());
auto
src0_md
=
dnnl
::
memory
::
desc
(
src_x_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
x
->
format
());
auto
src1_md
=
dnnl
::
memory
::
desc
(
src_y_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
y
->
format
());
if
(
rankdiff
>
0
)
{
// Second input is of smaller rank than first
std
::
vector
<
int64_t
>
dims1_ex
(
rankdiff
,
1
);
dims1_ex
.
insert
(
next
(
dims1_ex
.
begin
(),
(
axis
==
-
1
?
rankdiff
:
axis
)),
src_y_tz
.
begin
(),
src_y_tz
.
end
());
src1_md
=
src1_md
.
reshape
(
dims1_ex
);
}
else
if
(
rankdiff
<
0
)
{
// First input is of smaller than second
std
::
vector
<
int64_t
>
dims0_ex
(
-
rankdiff
,
1
);
dims0_ex
.
insert
(
next
(
dims0_ex
.
begin
(),
(
axis
==
-
1
?
-
rankdiff
:
axis
)),
src_x_tz
.
begin
(),
src_x_tz
.
end
());
src0_md
=
src0_md
.
reshape
(
dims0_ex
);
}
const
auto
dst_md
=
memory
::
desc
(
dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
MKLDNNMemoryFormat
::
any
);
PADDLE_ENFORCE_EQ
(
x
->
layout
(),
DataLayout
::
kMKLDNN
,
platform
::
errors
::
InvalidArgument
(
"Wrong layout set for X tensor."
));
PADDLE_ENFORCE_NE
(
x
->
format
(),
MKLDNNMemoryFormat
::
undef
,
platform
::
errors
::
InvalidArgument
(
"Wrong format set for X tensor."
));
PADDLE_ENFORCE_EQ
(
y
->
layout
(),
DataLayout
::
kMKLDNN
,
platform
::
errors
::
InvalidArgument
(
"Wrong layout set for Y tensor."
));
PADDLE_ENFORCE_NE
(
y
->
format
(),
MKLDNNMemoryFormat
::
undef
,
platform
::
errors
::
InvalidArgument
(
"Wrong format set for Y tensor."
));
const
auto
src_x_tz
=
framework
::
vectorize
(
x
->
dims
());
const
auto
src_y_tz
=
framework
::
vectorize
(
y
->
dims
());
// if output tensor(z) is nullptr then we are computing into oneDNN
// managed buffer
auto
rankdiff
=
x
->
dims
().
size
()
-
y
->
dims
().
size
();
const
auto
dst_tz
=
(
z
==
nullptr
)
?
(
rankdiff
>
0
?
src_x_tz
:
src_y_tz
)
:
framework
::
vectorize
(
z
->
dims
());
auto
src0_md
=
dnnl
::
memory
::
desc
(
src_x_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
x
->
format
());
auto
src1_md
=
dnnl
::
memory
::
desc
(
src_y_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
y
->
format
());
if
(
rankdiff
>
0
)
{
// Second input is of smaller rank than first
std
::
vector
<
int64_t
>
dims1_ex
(
rankdiff
,
1
);
dims1_ex
.
insert
(
next
(
dims1_ex
.
begin
(),
(
axis
==
-
1
?
rankdiff
:
axis
)),
src_y_tz
.
begin
(),
src_y_tz
.
end
());
src1_md
=
src1_md
.
reshape
(
dims1_ex
);
}
else
if
(
rankdiff
<
0
)
{
// First input is of smaller than second
std
::
vector
<
int64_t
>
dims0_ex
(
-
rankdiff
,
1
);
dims0_ex
.
insert
(
next
(
dims0_ex
.
begin
(),
(
axis
==
-
1
?
-
rankdiff
:
axis
)),
src_x_tz
.
begin
(),
src_x_tz
.
end
());
src0_md
=
src0_md
.
reshape
(
dims0_ex
);
}
const
auto
dst_md
=
memory
::
desc
(
dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
MKLDNNMemoryFormat
::
any
);
auto
attributes
=
CreateAttributes
(
algo
,
scale_x
,
scale_y
,
scale_z
);
this
->
AcquireForwardPrimitiveDescriptor
(
attributes
,
algo
,
src0
_md
,
src1_md
,
dst_md
);
auto
attributes
=
CreateAttributes
(
algo
,
scale_x
,
scale_y
,
scale_z
);
this
->
AcquireForwardPrimitiveDescriptor
(
attributes
,
algo
,
src0_md
,
src1
_md
,
dst_md
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireSecondSrcMemory
(
const
framework
::
Tensor
*
input
)
{
const
T
*
input_data
=
input
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
this
->
fwd_pd_
->
src1_desc
(),
to_void_cast
<
T
>
(
input_data
));
return
this
->
AcquireMemoryFromPrimitive
(
this
->
fwd_pd_
->
src1_desc
(),
to_void_cast
<
T
>
(
input_data
));
}
private:
...
...
@@ -981,51 +975,51 @@ class ReductionMKLDNNHandler
template
<
typename
T
>
class
ActivationMKLDNNHandler
:
public
MKLDNNHandlerNoCachingT
<
T
,
mkldnn
::
eltwise_forward
,
mkldnn
::
eltwise_backward
>
{
mkldnn
::
eltwise_backward
>
{
public:
ActivationMKLDNNHandler
(
mkldnn
::
algorithm
algorithm
,
const
framework
::
ExecutionContext
&
ctx
,
const
mkldnn
::
engine
engine
,
Place
cpu_place
,
const
framework
::
Tensor
*
in_x
)
:
platform
::
MKLDNNHandlerNoCachingT
<
T
,
mkldnn
::
eltwise_forward
,
mkldnn
::
eltwise_backward
>
(
engine
,
cpu_place
)
{
float
alpha
=
ctx
.
HasAttr
(
"alpha"
)
?
ctx
.
Attr
<
float
>
(
"alpha"
)
:
0
;
float
beta
=
ctx
.
HasAttr
(
"beta"
)
?
ctx
.
Attr
<
float
>
(
"bet
a"
)
:
0
;
// eltwise_linear means we are in scale op
if
(
algorithm
==
mkldnn
::
algorithm
::
eltwise_linear
)
{
bool
bias_after_scale
=
ctx
.
Attr
<
bool
>
(
"bias_after_scale"
);
auto
*
scale_tensor
=
ctx
.
Input
<
Tensor
>
(
"ScaleTensor
"
);
alpha
=
(
scale_tensor
==
nullptr
)
?
ctx
.
Attr
<
float
>
(
"scale"
)
:
(
float
)
*
(
scale_tensor
->
data
<
T
>
());
beta
=
ctx
.
Attr
<
float
>
(
"bias"
);
// if bias_after_scale == true
// out = scale*X + bias
// else
// out = scale*(X + bias) = scale*X + scale*bias
if
(
!
bias_after_scale
)
beta
*=
alpha
;
}
else
{
// paddle uses beta but mkldnn uses alpha for swish
if
(
algorithm
==
mkldnn
::
algorithm
::
eltwise_swish
)
{
std
::
swap
(
alpha
,
beta
);
}
else
if
(
algorithm
==
dnnl
::
algorithm
::
eltwise_bounded_relu
)
{
alpha
=
ctx
.
Attr
<
float
>
(
"threshold"
);
}
mkldnn
::
eltwise_backward
>
(
engine
,
cpu_place
)
{
float
alpha
=
ctx
.
HasAttr
(
"alpha"
)
?
ctx
.
Attr
<
float
>
(
"alph
a"
)
:
0
;
float
beta
=
ctx
.
HasAttr
(
"beta"
)
?
ctx
.
Attr
<
float
>
(
"beta"
)
:
0
;
// eltwise_linear means we are in scale op
if
(
algorithm
==
mkldnn
::
algorithm
::
eltwise_linear
)
{
bool
bias_after_scale
=
ctx
.
Attr
<
bool
>
(
"bias_after_scale
"
);
auto
*
scale_tensor
=
ctx
.
Input
<
Tensor
>
(
"ScaleTensor"
);
alpha
=
(
scale_tensor
==
nullptr
)
?
ctx
.
Attr
<
float
>
(
"scale"
)
:
(
float
)
*
(
scale_tensor
->
data
<
T
>
()
);
beta
=
ctx
.
Attr
<
float
>
(
"bias"
);
// if bias_after_scale == true
// out = scale*X + bias
// else
// out = scale*(X + bias) = scale*X + scale*bias
if
(
!
bias_after_scale
)
beta
*=
alpha
;
}
else
{
// paddle uses beta but mkldnn uses alpha for swish
if
(
algorithm
==
mkldnn
::
algorithm
::
eltwise_swish
)
{
std
::
swap
(
alpha
,
beta
);
}
else
if
(
algorithm
==
dnnl
::
algorithm
::
eltwise_bounded_relu
)
{
alpha
=
ctx
.
Attr
<
float
>
(
"threshold"
);
}
}
PADDLE_ENFORCE
(
in_x
->
dims
().
size
()
>=
1
||
in_x
->
dims
().
size
()
<=
6
,
platform
::
errors
::
Unimplemented
(
"Input dimension size can be 1, 2, 3, 4, "
"5, or 6, but now the dimension size is"
,
in_x
->
dims
().
size
()));
PADDLE_ENFORCE
(
in_x
->
dims
().
size
()
>=
1
||
in_x
->
dims
().
size
()
<=
6
,
platform
::
errors
::
Unimplemented
(
"Input dimension size can be 1, 2, 3, 4, "
"5, or 6, but now the dimension size is"
,
in_x
->
dims
().
size
()));
auto
src_tz
=
framework
::
vectorize
<
int64_t
>
(
in_x
->
dims
());
auto
src_fmt
=
src_tz
.
size
()
==
2
?
MKLDNNMemoryFormat
::
nc
:
in_x
->
format
();
auto
md
=
mkldnn
::
memory
::
desc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
src_fmt
);
auto
src_tz
=
framework
::
vectorize
<
int64_t
>
(
in_x
->
dims
());
auto
src_fmt
=
src_tz
.
size
()
==
2
?
MKLDNNMemoryFormat
::
nc
:
in_x
->
format
();
auto
md
=
mkldnn
::
memory
::
desc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
src_fmt
);
this
->
AcquireForwardPrimitiveDescriptor
(
mkldnn
::
prop_kind
::
forward_training
,
algorithm
,
md
,
alpha
,
beta
);
this
->
AcquireForwardPrimitiveDescriptor
(
mkldnn
::
prop_kind
::
forward_training
,
algorithm
,
md
,
alpha
,
beta
);
}
ActivationMKLDNNHandler
(
mkldnn
::
algorithm
algorithm
,
...
...
@@ -1033,40 +1027,42 @@ class ActivationMKLDNNHandler
const
mkldnn
::
engine
engine
,
Place
cpu_place
,
const
framework
::
Tensor
*
in_x
,
const
Tensor
*
out_grad
)
:
platform
::
MKLDNNHandlerNoCachingT
<
T
,
mkldnn
::
eltwise_forward
,
mkldnn
::
eltwise_backward
>
(
engine
,
cpu_place
)
{
float
alpha
=
ctx
.
HasAttr
(
"alpha"
)
?
ctx
.
Attr
<
float
>
(
"alpha"
)
:
0
;
float
beta
=
ctx
.
HasAttr
(
"beta"
)
?
ctx
.
Attr
<
float
>
(
"beta"
)
:
0
;
// paddle uses beta but mkldnn uses alpha for swish
if
(
algorithm
==
mkldnn
::
algorithm
::
eltwise_swish
)
{
std
::
swap
(
alpha
,
beta
);
}
else
if
(
algorithm
==
dnnl
::
algorithm
::
eltwise_bounded_relu
)
{
alpha
=
ctx
.
Attr
<
float
>
(
"threshold"
);
}
mkldnn
::
eltwise_backward
>
(
engine
,
cpu_place
)
{
float
alpha
=
ctx
.
HasAttr
(
"alpha"
)
?
ctx
.
Attr
<
float
>
(
"alpha"
)
:
0
;
float
beta
=
ctx
.
HasAttr
(
"beta"
)
?
ctx
.
Attr
<
float
>
(
"beta"
)
:
0
;
// paddle uses beta but mkldnn uses alpha for swish
if
(
algorithm
==
mkldnn
::
algorithm
::
eltwise_swish
)
{
std
::
swap
(
alpha
,
beta
);
}
else
if
(
algorithm
==
dnnl
::
algorithm
::
eltwise_bounded_relu
)
{
alpha
=
ctx
.
Attr
<
float
>
(
"threshold"
);
}
auto
diff_dst_tz
=
framework
::
vectorize
<
int64_t
>
(
out_grad
->
dims
());
auto
diff_dst_tz
=
framework
::
vectorize
<
int64_t
>
(
out_grad
->
dims
());
auto
src_fmt
=
diff_dst_tz
.
size
()
==
2
?
MKLDNNMemoryFormat
::
nc
:
in_x
->
format
();
auto
diff_fmt
=
diff_dst_tz
.
size
()
==
2
?
MKLDNNMemoryFormat
::
nc
:
out_grad
->
format
();
auto
src_fmt
=
diff_dst_tz
.
size
()
==
2
?
MKLDNNMemoryFormat
::
nc
:
in_x
->
format
();
auto
diff_fmt
=
diff_dst_tz
.
size
()
==
2
?
MKLDNNMemoryFormat
::
nc
:
out_grad
->
format
();
auto
dims
=
framework
::
vectorize
(
in_x
->
dims
());
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_fmt
);
auto
src_md
=
platform
::
MKLDNNMemDesc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
src_fmt
);
auto
dims
=
framework
::
vectorize
(
in_x
->
dims
());
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_fmt
);
auto
src_md
=
platform
::
MKLDNNMemDesc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
src_fmt
);
this
->
AcquireForwardPrimitiveDescriptor
(
mkldnn
::
prop_kind
::
forward_training
,
algorithm
,
src_md
,
alpha
,
beta
);
this
->
AcquireBackwardPrimitiveDescriptor
(
algorithm
,
diff_dst_md
,
src_md
,
alpha
,
beta
);
this
->
AcquireForwardPrimitiveDescriptor
(
mkldnn
::
prop_kind
::
forward_training
,
algorithm
,
src_md
,
alpha
,
beta
);
this
->
AcquireBackwardPrimitiveDescriptor
(
algorithm
,
diff_dst_md
,
src_md
,
alpha
,
beta
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireBackwardSrcMemory
(
const
framework
::
Tensor
*
input
)
{
const
T
*
input_data
=
input
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
this
->
bwd_pd_
->
src_desc
(),
to_void_cast
<
T
>
(
input_data
));
return
this
->
AcquireMemoryFromPrimitive
(
this
->
bwd_pd_
->
src_desc
(),
to_void_cast
<
T
>
(
input_data
));
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录