Unverified commit e7724a2c
Authored by Adam on Jul 17, 2020; committed via GitHub on Jul 17, 2020
Parent: 9bf70039

Refactor of conv fp32 oneDNN operator (#25137) (#25572)

Showing 2 changed files with 467 additions and 266 deletions:
  paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc   +345  -261
  paddle/fluid/platform/mkldnn_reuse.h              +122  -5
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
...
...
@@ -26,42 +26,24 @@ using mkldnn::memory;
 using mkldnn::primitive;
 using mkldnn::reorder;
 using mkldnn::stream;
-using platform::to_void_cast;
 using platform::GetMKLDNNFormat;
+using platform::to_void_cast;

 inline void GetWeightsTz(std::vector<int64_t>& weights_tz,  // NOLINT
-                         int groups, bool is_conv3d) {
+                         const int groups) {
   if (groups > 1) {
-    if (is_conv3d) {
-      int output = weights_tz[0];
-      int input = weights_tz[1];
-      int dimension = weights_tz[2];
-      int height = weights_tz[3];
-      int width = weights_tz[4];
-      weights_tz.resize(6);
-      weights_tz[0] = groups;
-      weights_tz[1] = output / groups;
-      weights_tz[2] = input;
-      weights_tz[3] = dimension;
-      weights_tz[4] = height;
-      weights_tz[5] = width;
-    } else {
-      int output = weights_tz[0];
-      int input = weights_tz[1];
-      int height = weights_tz[2];
-      int width = weights_tz[3];
-      weights_tz.resize(5);
-      weights_tz[0] = groups;
-      weights_tz[1] = output / groups;
-      weights_tz[2] = input;
-      weights_tz[3] = height;
-      weights_tz[4] = width;
-    }
+    // if (is_conv3d) [o, i, d, h, w]->[g, o/g, i, d, h, w]
+    // else [o, i, h, w] -> [g, o/g, i, h, w]
+    weights_tz.push_back(0);
+    std::rotate(weights_tz.begin(), weights_tz.end() - 1, weights_tz.end());
+    weights_tz[0] = groups;
+    weights_tz[1] = weights_tz[1] / groups;
   }
 }

-inline MKLDNNMemoryFormat GetWeightsFormat(MKLDNNMemoryFormat format,
-                                           int groups, bool is_conv3d) {
+inline MKLDNNMemoryFormat GetWeightsFormat(const MKLDNNMemoryFormat format,
+                                           const int groups,
+                                           const bool is_conv3d) {
   if (is_conv3d) {
     return (groups == 1) ? format : MKLDNNMemoryFormat::goidhw;
   } else {
...
...
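The refactored GetWeightsTz above replaces the hand-written 4-D and 5-D branches with a single push_back/rotate sequence. A minimal standalone sketch of that dims trick (plain C++, outside Paddle; the function name is illustrative):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Same reshaping as the refactored GetWeightsTz: [o, i, h, w] ->
// [g, o/g, i, h, w], and the 5-D conv3d case falls out for free:
// append one slot, rotate it to the front, then fix dims 0 and 1.
void GroupWeightDims(std::vector<int64_t>& dims, int64_t groups) {
  if (groups > 1) {
    dims.push_back(0);
    std::rotate(dims.begin(), dims.end() - 1, dims.end());
    dims[0] = groups;
    dims[1] = dims[1] / groups;  // dims[1] held the old output count o
  }
}

int main() {
  std::vector<int64_t> oihw = {32, 16, 3, 3};
  GroupWeightDims(oihw, 4);  // -> [4, 8, 16, 3, 3]
  for (auto d : oihw) std::printf("%lld ", static_cast<long long>(d));
  std::printf("\n");
  return 0;
}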
@@ -90,284 +72,386 @@ static mkldnn::memory::data_type GetDstType(bool is_int8,
   return dst_dt;
 }

-template <typename T, typename K>
-class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
- public:
-  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
-                   "It must use CPUPlace.");
-    bool is_INT8 =
-        std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
-    if (!is_INT8) {
-      ComputeFP32(ctx);
-    } else {
-      std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
-      bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
-      bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
-      auto residual_param = ctx.Input<Tensor>("ResidualData");
-      auto dst_dt = GetDstType(true, force_fp32_output, fuse_activation,
-                               fuse_residual_conn, residual_param);
-      if (dst_dt == mkldnn::memory::data_type::f32) {
-        ComputeINT8<float>(ctx);
-      } else if (dst_dt == mkldnn::memory::data_type::u8) {
-        ComputeINT8<uint8_t>(ctx);
-      } else if (dst_dt == mkldnn::memory::data_type::s8) {
-        ComputeINT8<int8_t>(ctx);
-      }
-    }
-  }
-
-  void ComputeFP32(const paddle::framework::ExecutionContext& ctx) const {
-    const bool is_test = ctx.Attr<bool>("is_test");
-
-    auto& dev_ctx =
-        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
-
-    auto* input = ctx.Input<Tensor>("Input");
-    auto* filter = ctx.Input<Tensor>("Filter");
-    auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
-    auto* output = ctx.Output<Tensor>("Output");
-
-    PADDLE_ENFORCE_EQ(
-        input->layout(), DataLayout::kMKLDNN,
-        platform::errors::InvalidArgument(
-            "The input tensor's layout should be %d, but got %d.",
-            DataLayout::kMKLDNN, input->layout()));
-    PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
-                      platform::errors::InvalidArgument(
-                          "Wrong format set for Input tensor"));
-
-    PADDLE_ENFORCE_EQ(
-        filter->layout(), DataLayout::kMKLDNN,
-        platform::errors::InvalidArgument(
-            "The Filter tensor's layout should be %d, but got %d.",
-            DataLayout::kMKLDNN, filter->layout()));
-    PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
-                      "Wrong format set for Filter tensor");
-
-    PADDLE_ENFORCE_GE(
-        input->dims().size(), 4,
-        "Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
-    PADDLE_ENFORCE_LE(
-        input->dims().size(), 5,
-        "Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
-
-    PADDLE_ENFORCE_GE(
-        filter->dims().size(), 4,
-        "Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
-    PADDLE_ENFORCE_LE(
-        filter->dims().size(), 5,
-        "Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
-
-    if (bias) {
-      PADDLE_ENFORCE_EQ(
-          bias->layout(), DataLayout::kMKLDNN,
-          platform::errors::InvalidArgument(
-              "The Bias tensor's layout should be %d, but got %d.",
-              DataLayout::kMKLDNN, bias->layout()));
-      PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::undef,
-                        "Wrong format set for Bias tensor");
-      PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
-                        "Bias must only have 1 dimension, i.e. X");
-    }
-
-    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
-    std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
-    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
-    std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
-    std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
-    std::vector<int64_t> dilations(begin(dilations_temp),
-                                   end(dilations_temp));
-
-    std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
-    float fuse_alpha = ctx.Attr<float>("fuse_alpha");
-    float fuse_beta = ctx.Attr<float>("fuse_beta");
-    bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
-    int groups = ctx.Attr<int>("groups");
-    std::string padding_algorithm =
-        ctx.Attr<std::string>("padding_algorithm");
-    bool is_conv3d = strides.size() == 3U;
-
-    auto input_dims = input->dims();
-    auto data_dims = framework::slice_ddim(input_dims, 2, input_dims.size());
-    auto filter_dims = filter->dims();
-    auto filter_data_dims =
-        framework::slice_ddim(filter_dims, 2, filter_dims.size());
-    auto ksize = framework::vectorize(filter_data_dims);
-
-    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
-                             data_dims, strides, ksize);
-
-    std::vector<primitive> pipeline;
-
-    PADDLE_ENFORCE(
-        is_conv3d
-            ? dilations.size() == 3 && dilations[0] == 1 &&
-                  dilations[1] == 1 && dilations[2] == 1
-            : dilations.size() == 2 && dilations[0] == 1 &&
-                  dilations[1] == 1,
-        "dilation in convolution is not implemented yet");
-
-    const T* input_data = input->data<T>();
-    const T* filter_data = filter->data<T>();
-
-    auto src_tz = paddle::framework::vectorize(input->dims());
-    auto weights_tz = paddle::framework::vectorize(filter->dims());
-    int g = std::max(groups, 1);
-    GetWeightsTz(weights_tz, g, is_conv3d);
-    auto dst_tz = paddle::framework::vectorize(output->dims());
-
-    // Get unique name for storing MKLDNN primitives
-    const std::string key = platform::CreateKey(
-        src_tz, ctx.InputName("Input") + ctx.InputName("Filter"));
-
-    auto src_format = input->format();
-    MKLDNNMemoryFormat weights_format =
-        GetWeightsFormat(filter->format(), g, is_conv3d);
-
-    auto user_src_md = platform::MKLDNNMemDesc(
-        {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
-    auto user_weights_md = platform::MKLDNNMemDesc(
-        {weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
-
-    /* create memory descriptor for convolution without specified format
-     * ('any') which lets a primitive (convolution in this case) choose
-     * the memory format preferred for best performance
-     */
-    // TODO(jczaja): This is workaround to make grad op UT's numerical
-    // gradient computation proper as this op is called directly without
-    // fetch op following it , so numercial grad is computed (in python)
-    // using block formats which will give wrong results
-    std::string data_format = ctx.Attr<std::string>("data_format");
-    auto chosen_memory_format =
-        is_test ? MKLDNNMemoryFormat::any
-                : platform::data_format_to_memory_format(data_format);
-
-    weights_format = MKLDNNMemoryFormat::any;
-    // Check the format for user's special output
-    if (chosen_memory_format != MKLDNNMemoryFormat::any) {
-      if (is_conv3d) {
-        chosen_memory_format = platform::MKLDNNFormatForSize(
-            src_tz.size(), chosen_memory_format);
-      }
-    }
-
-    auto src_md = platform::MKLDNNMemDesc(
-        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
-    auto weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
-    std::vector<int64_t> bias_tz;
-    auto dst_md = platform::MKLDNNMemDesc(
-        dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
-
-    platform::ConvMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
-
-    // create a conv primitive descriptor and save it for usage in backward
-    std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
-    auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
-                                 : mkldnn::prop_kind::forward_training;
-    if (bias) {
-      bias_tz = paddle::framework::vectorize(bias->dims());
-      auto bias_md = platform::MKLDNNMemDesc(
-          bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
-      conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
-          src_md, weights_md, bias_md, dst_md, strides, paddings,
-          mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
-          fuse_residual_conn, fwd_prop_kind);
-    } else {
-      conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
-          src_md, weights_md, boost::none, dst_md, strides, paddings,
-          mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
-          fuse_residual_conn, fwd_prop_kind);
-    }
-
-    // create mkldnn memory from input tensors (data/weights)
-    auto user_src_memory_p =
-        handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
-    auto user_weights_memory_p = handler.AcquireWeightsMemory(
-        user_weights_md, to_void_cast<T>(filter_data));
-
-    // create reorder primitive if the input format is not the preferred one
-    auto src_memory_p =
-        handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
-    auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
-        user_weights_memory_p, pipeline, is_test);
-
-    std::shared_ptr<mkldnn::memory> dst_memory_p, user_residual_memory_p;
-
-    if (fuse_residual_conn) {
-      auto residual_param = ctx.Input<Tensor>("ResidualData");
-      auto residual_param_data = residual_param->data<T>();
-      PADDLE_ENFORCE_NE(
-          residual_param_data, nullptr,
-          "Provide data if you want MKLDNN conv+elementwise_add fusion");
-      PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
-                        "Output and elementwise parameter need to have the "
-                        "same dimension sizes");
-      if (residual_param->format() != handler.GetDstFormat()) {
-        auto output_data = output->mutable_data<T>(
-            ctx.GetPlace(), handler.GetDstMemorySize());
-        auto residual_data_tz =
-            paddle::framework::vectorize(residual_param->dims());
-        auto residual_data_type =
-            paddle::framework::ToMKLDNNDataType(residual_param->type());
-        auto user_residual_md = platform::MKLDNNMemDesc(
-            residual_data_tz, residual_data_type, residual_param->format());
-        user_residual_memory_p = handler.AcquireResidualDataMemory(
-            user_residual_md, to_void_cast<T>(residual_param_data));
-        dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
-            user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
-      } else {
-        // Changing ShareDataWith to TensorCopy results in performance drop
-        // on ResNet architectures
-        // (https://github.com/PaddlePaddle/Paddle/issues/22964)
-        output->ShareDataWith(*residual_param);
-        auto output_data = output->mutable_data<T>(ctx.GetPlace());
-        dst_memory_p = handler.AcquireDstMemoryFromPrimitive(
-            to_void_cast<T>(output_data));
-      }
-    } else {
-      auto output_data = output->mutable_data<T>(ctx.GetPlace(),
-                                                 handler.GetDstMemorySize());
-      dst_memory_p = handler.AcquireDstMemoryFromPrimitive(
-          to_void_cast<T>(output_data));
-    }
-
-    auto conv_p = handler.AcquireConvolution();
-
-    mkldnn::stream astream(mkldnn_engine);
-    if (bias) {
-      const T* bias_data = bias->data<T>();
-      auto user_bias_md = platform::MKLDNNMemDesc(
-          {bias_tz}, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
-      auto user_bias_memory_p =
-          handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));
-      auto bias_memory_p = handler.AcquireBiasMemoryFromPrimitive(
-          user_bias_memory_p, pipeline);
-
-      conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
-                                {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
-                                {MKLDNN_ARG_BIAS, *bias_memory_p},
-                                {MKLDNN_ARG_DST, *dst_memory_p}});
-    } else {
-      conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
-                                {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
-                                {MKLDNN_ARG_DST, *dst_memory_p}});
-    }
+template <typename T>
+class ConvMKLDNNHandlerT
+    : public platform::MKLDNNHandlerT<T, mkldnn::convolution_forward> {
+ public:
+  ConvMKLDNNHandlerT(const paddle::framework::ExecutionContext& ctx,
+                     const platform::MKLDNNDeviceContext& dev_ctx,
+                     const mkldnn::engine mkldnn_engine,
+                     platform::Place cpu_place, const Tensor* input,
+                     const Tensor* filter, const Tensor* bias, Tensor* output,
+                     const std::string& unique_name)
+      : platform::MKLDNNHandlerT<T, mkldnn::convolution_forward>(
+            dev_ctx, mkldnn_engine, cpu_place,
+            platform::CreateKey(framework::vectorize(input->dims()),
+                                unique_name)) {
+    if (!this->isCached()) {
+      PADDLE_ENFORCE_EQ(
+          input->layout(), DataLayout::kMKLDNN,
+          platform::errors::InvalidArgument(
+              "The input tensor's layout should be %d, but got %d.",
+              DataLayout::kMKLDNN, input->layout()));
+      PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
+                        platform::errors::InvalidArgument(
+                            "Wrong format set for Input tensor"));
+
+      PADDLE_ENFORCE_EQ(
+          filter->layout(), DataLayout::kMKLDNN,
+          platform::errors::InvalidArgument(
+              "The Filter tensor's layout should be %d, but got %d.",
+              DataLayout::kMKLDNN, filter->layout()));
+      PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
+                        platform::errors::InvalidArgument(
+                            "Wrong format set for Filter tensor"));
+
+      PADDLE_ENFORCE_GE(
+          input->dims().size(), 4,
+          platform::errors::InvalidArgument(
+              "Input must be with 4 or 5 dimensions, i.e. NCHW or "
+              "NCDHW, but got dimension = %d .",
+              input->dims().size()));
+      PADDLE_ENFORCE_LE(
+          input->dims().size(), 5,
+          platform::errors::InvalidArgument(
+              "Input must be with 4 or 5 dimensions, i.e. NCHW or "
+              "NCDHW, but got dimension = %d .",
+              input->dims().size()));
+
+      PADDLE_ENFORCE_GE(
+          filter->dims().size(), 4,
+          platform::errors::InvalidArgument(
+              "Filter must be with 4 or 5 dimensions, i.e. OIHW or "
+              "OIDHW, but got dimension = %d .",
+              filter->dims().size()));
+      PADDLE_ENFORCE_LE(
+          filter->dims().size(), 5,
+          platform::errors::InvalidArgument(
+              "Filter must be with 4 or 5 dimensions, i.e. OIHW or "
+              "OIDHW, but got dimension = %d .",
+              filter->dims().size()));
+
+      if (bias) {
+        PADDLE_ENFORCE_EQ(
+            bias->layout(), DataLayout::kMKLDNN,
+            platform::errors::InvalidArgument(
+                "The Bias tensor's layout should be %d, but got %d.",
+                DataLayout::kMKLDNN, bias->layout()));
+        PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::undef,
+                          platform::errors::InvalidArgument(
+                              "Got wrong format for Bias tensor."));
+        PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
+                          platform::errors::InvalidArgument(
+                              "Bias must only have 1 dimension, "
+                              "i.e. X, but got dimension = %d .",
+                              bias->dims().size()));
+      }
+
+      const std::string fuse_activation =
+          ctx.Attr<std::string>("fuse_activation");
+      const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
+      const float fuse_beta = ctx.Attr<float>("fuse_beta");
+      const bool fuse_residual_conn =
+          ctx.Attr<bool>("fuse_residual_connection");
+      const int groups = ctx.Attr<int>("groups");
+      const std::string padding_algorithm =
+          ctx.Attr<std::string>("padding_algorithm");
+
+      const auto input_dims = input->dims();
+      const auto data_dims =
+          framework::slice_ddim(input_dims, 2, input_dims.size());
+      const auto filter_dims = filter->dims();
+      const auto filter_data_dims =
+          framework::slice_ddim(filter_dims, 2, filter_dims.size());
+      const auto ksize = framework::vectorize(filter_data_dims);
+      const bool is_test = ctx.Attr<bool>("is_test");
+
+      auto strides_temp = ctx.Attr<std::vector<int>>("strides");
+      std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
+      auto paddings_temp = ctx.Attr<std::vector<int>>("paddings");
+      std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
+      auto dilations_temp = ctx.Attr<std::vector<int>>("dilations");
+      std::vector<int64_t> dilations(begin(dilations_temp),
+                                     end(dilations_temp));
+
+      UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
+                               data_dims, strides, ksize);
+
+      const bool is_conv3d = strides.size() == 3U;
+
+      PADDLE_ENFORCE_EQ(
+          is_conv3d
+              ? dilations.size() == 3 && dilations[0] == 1 &&
+                    dilations[1] == 1 && dilations[2] == 1
+              : dilations.size() == 2 && dilations[0] == 1 &&
+                    dilations[1] == 1,
+          true, platform::errors::Unimplemented(
+                    "Dilation in oneDNN convolution is not implemented yet"));
+
+      const auto src_tz = paddle::framework::vectorize(input->dims());
+
+      auto weights_tz = paddle::framework::vectorize(filter->dims());
+      GetWeightsTz(weights_tz, groups);
+
+      const auto dst_tz = paddle::framework::vectorize(output->dims());
+
+      const mkldnn::memory::dims stride_dims = strides;
+      const auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);
+
+      /* create memory descriptor for convolution without specified format
+       * ('any') which lets a primitive (convolution in this case) choose
+       * the memory format preferred for best performance
+       */
+      // TODO(jczaja): This is workaround to make grad op UT's numerical
+      // gradient computation proper as this op is called directly without
+      // fetch op following it , so numercial grad is computed (in python)
+      // using block formats which will give wrong results
+      const std::string data_format = ctx.Attr<std::string>("data_format");
+      auto chosen_memory_format =
+          is_test ? MKLDNNMemoryFormat::any
+                  : platform::data_format_to_memory_format(data_format);
+
+      // Check the format for user's special output
+      if (chosen_memory_format != MKLDNNMemoryFormat::any) {
+        if (is_conv3d) {
+          chosen_memory_format = platform::MKLDNNFormatForSize(
+              src_tz.size(), chosen_memory_format);
+        }
+      }
+
+      const auto src_md = platform::MKLDNNMemDesc(
+          src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+      const auto weights_md = platform::MKLDNNMemDesc(
+          weights_tz, platform::MKLDNNGetDataType<T>(),
+          MKLDNNMemoryFormat::any);
+      const auto dst_md = platform::MKLDNNMemDesc(
+          dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+
+      const auto fwd_prop_kind = is_test
+                                     ? mkldnn::prop_kind::forward_inference
+                                     : mkldnn::prop_kind::forward_training;
+
+      const mkldnn::primitive_attr conv_attr = CreatePostOps(
+          fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn);
+
+      if (bias) {
+        auto bias_tz = framework::vectorize(bias->dims());
+        auto bias_md = platform::MKLDNNMemDesc(
+            bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
+        this->AcquireForwardPrimitiveDescriptor(
+            conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
+            src_md, weights_md, bias_md, dst_md, stride_dims,
+            mkldnn_paddings[0], mkldnn_paddings[1]);
+      } else {
+        this->AcquireForwardPrimitiveDescriptor(
+            conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
+            src_md, weights_md, dst_md, stride_dims, mkldnn_paddings[0],
+            mkldnn_paddings[1]);
+      }
+    }
+  }
+
+  mkldnn::primitive_attr CreatePostOps(
+      std::string fuse_activation, float fuse_alpha, float fuse_beta,
+      bool fuse_residual_conn,
+      const std::vector<float> output_shift_scale = {},
+      float sum_scale = 1.0f) {
+    mkldnn::primitive_attr conv_attr;
+    mkldnn::post_ops post_operations;
+    if (output_shift_scale.size() > 0) {
+      int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
+      conv_attr.set_output_scales(mask, output_shift_scale);
+    }
+    // Fusion with Elementwise layer relies on adding a sum post-operation
+    // with the scale parameter. It is assumed that when
+    // fuse_residual_connection is true, the output tensor contains the data
+    // coming from residual connection. The result of this post_op is:
+    // Output = scale * Output + Conv_Out.
+    if (fuse_residual_conn) {
+      post_operations.append_sum(sum_scale);
+    }
+    // Fusion with ReLU layer is executed through the PostOps feature. Create
+    // a PostOps object and configure it to execute an eltwise relu operation.
+    if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
+      constexpr float scale = 1.0f;
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
+                                     fuse_alpha, fuse_beta);
+    } else if (fuse_activation == "relu6") {
+      constexpr float scale = 1.0f;
+      post_operations.append_eltwise(
+          scale, mkldnn::algorithm::eltwise_bounded_relu, fuse_alpha,
+          fuse_beta);
+    } else if (fuse_activation == "swish") {
+      constexpr float scale = 1.0f;
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_swish,
+                                     fuse_alpha, fuse_beta);
+    }
+    conv_attr.set_post_ops(post_operations);
+    return conv_attr;
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireSrcMemoryWithReorder(
+      const framework::Tensor* input) {
+    const T* input_data = input->data<T>();
+    auto user_src_md = platform::MKLDNNMemDesc(
+        framework::vectorize(input->dims()), platform::MKLDNNGetDataType<T>(),
+        input->format());
+    return this->AcquireMemoryWithReorder(
+        user_src_md, this->fwd_pd_->src_desc(), to_void_cast<T>(input_data),
+        "@src_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder(
+      const framework::Tensor* filter, const int groups, const bool is_conv3d,
+      const bool is_test) {
+    // This is workaround to make execution faster, delete
+    // if statement after including md inside Tensor
+    auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
+    if (is_test && weights_mem_p) {
+      return weights_mem_p;
+    } else {
+      const T* filter_data = filter->data<T>();
+      auto weights_tz = framework::vectorize(filter->dims());
+      GetWeightsTz(weights_tz, groups);
+      auto user_src_md = platform::MKLDNNMemDesc(
+          weights_tz, platform::MKLDNNGetDataType<T>(),
+          GetWeightsFormat(filter->format(), groups, is_conv3d));
+      return this->AcquireMemoryWithReorder(
+          user_src_md, this->fwd_pd_->weights_desc(),
+          to_void_cast<T>(filter_data), "@weights_mem_p", is_test);
+    }
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
+      const framework::Tensor* bias, const bool is_test) {
+    const T* bias_data = bias->data<T>();
+    auto user_bias_md = platform::MKLDNNMemDesc(
+        framework::vectorize(bias->dims()), platform::MKLDNNGetDataType<T>(),
+        MKLDNNMemoryFormat::x);
+    return this->AcquireMemoryWithReorder(
+        user_bias_md, this->fwd_pd_->bias_desc(), to_void_cast<T>(bias_data),
+        "@bias_mem_p", is_test);
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireResidualMemory(
+      const framework::Tensor* residual_param) {
+    const T* residual_data = residual_param->data<T>();
+    auto user_residual_md = platform::MKLDNNMemDesc(
+        framework::vectorize(residual_param->dims()),
+        framework::ToMKLDNNDataType(residual_param->type()),
+        residual_param->format());
+    return this->AcquireMemoryFromPrimitive(user_residual_md,
+                                            to_void_cast<T>(residual_data),
+                                            "@user_residual_data_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDstMemoryWithResidual(
+      framework::Tensor* output, const framework::Tensor* residual_param) {
+    std::shared_ptr<dnnl::memory> dst_memory_p;
+    if (residual_param->format() !=
+        platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) {
+      auto residual_memory_p = this->AcquireResidualMemory(residual_param);
+      dst_memory_p = this->AcquireDstMemory(output);
+      this->AcquireReorder(residual_memory_p, dst_memory_p, "@residual_dst");
+    } else {
+      // Changing ShareDataWith to TensorCopy results in performance drop
+      // on ResNet architectures
+      // (https://github.com/PaddlePaddle/Paddle/issues/22964)
+      output->ShareDataWith(*residual_param);
+      dst_memory_p = this->AcquireDstMemory(output);
+    }
+    return dst_memory_p;
+  }
+};
+
+template <typename T, typename K>
+class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
+                      paddle::platform::errors::PreconditionNotMet(
+                          "Operator DNNL Conv must use CPUPlace"));
+    bool is_INT8 =
+        std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
+    if (!is_INT8) {
+      ComputeFP32(ctx);
+    } else {
+      std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
+      bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
+      bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
+      auto residual_param = ctx.Input<Tensor>("ResidualData");
+      auto dst_dt = GetDstType(true, force_fp32_output, fuse_activation,
+                               fuse_residual_conn, residual_param);
+      if (dst_dt == mkldnn::memory::data_type::f32) {
+        ComputeINT8<float>(ctx);
+      } else if (dst_dt == mkldnn::memory::data_type::u8) {
+        ComputeINT8<uint8_t>(ctx);
+      } else if (dst_dt == mkldnn::memory::data_type::s8) {
+        ComputeINT8<int8_t>(ctx);
+      }
+    }
+  }
+
+  void ComputeFP32(const paddle::framework::ExecutionContext& ctx) const {
+    auto& dev_ctx =
+        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
+
+    const bool is_test = ctx.Attr<bool>("is_test");
+    const bool is_conv3d = ctx.Attr<std::vector<int>>("strides").size() == 3U;
+    const bool fuse_residual_conn =
+        ctx.Attr<bool>("fuse_residual_connection");
+
+    const auto* input = ctx.Input<Tensor>("Input");
+    const auto* filter = ctx.Input<Tensor>("Filter");
+    const auto* bias =
+        ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
+    auto* output = ctx.Output<Tensor>("Output");
+
+    ConvMKLDNNHandlerT<T> handler(
+        ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, filter, bias,
+        output, ctx.InputName("Input") + ctx.InputName("Filter"));
+
+    auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
+    auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
+        filter, ctx.Attr<int>("groups"), is_conv3d, is_test);
+
+    std::shared_ptr<dnnl::memory> dst_memory_p;
+    if (fuse_residual_conn) {
+      auto* residual_param = ctx.Input<Tensor>("ResidualData");
+      dst_memory_p =
+          handler.AcquireDstMemoryWithResidual(output, residual_param);
+    } else {
+      dst_memory_p = handler.AcquireDstMemory(output);
+    }
+
+    auto conv_p = handler.AcquireForwardPrimitive();
+
+    std::unordered_map<int, dnnl::memory> args = {
+        {MKLDNN_ARG_SRC, *src_memory_p},
+        {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
+        {MKLDNN_ARG_DST, *dst_memory_p}};
+
+    if (bias) {
+      auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test);
+      args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
+    }
+
+    mkldnn::stream astream(mkldnn_engine);
+    conv_p->execute(astream, args);
     astream.wait();
     output->set_layout(DataLayout::kMKLDNN);
     output->set_format(GetMKLDNNFormat(*dst_memory_p));
   }

   template <typename T_out>
   void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const {
     const bool is_test = ctx.Attr<bool>("is_test");
...
...
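CreatePostOps in the new handler is where the fp32 fusion logic now lives: an optional sum post-op for the residual connection followed by an optional eltwise post-op for the fused activation. The same oneDNN pattern in isolation, as a minimal sketch (assumes the oneDNN 1.x C++ API that Paddle pinned at the time; the function name is illustrative):

#include "dnnl.hpp"

// Builds a primitive_attr the way the diff's CreatePostOps does for fp32:
// Output = sum_scale * Output + Conv_Out, then an eltwise activation.
dnnl::primitive_attr MakeConvAttr(bool fuse_residual, float sum_scale,
                                  bool fuse_relu, float alpha, float beta) {
  dnnl::post_ops ops;
  if (fuse_residual) {
    ops.append_sum(sum_scale);  // accumulate into the existing dst data
  }
  if (fuse_relu) {
    ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_relu, alpha, beta);
  }
  dnnl::primitive_attr attr;
  attr.set_post_ops(ops);
  return attr;  // pass to the convolution primitive_desc constructor
}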
@@ -516,7 +600,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto weights_tz = paddle::framework::vectorize(filter->dims());
     int g = std::max(groups, 1);
-    GetWeightsTz(weights_tz, g, is_conv3d);
+    GetWeightsTz(weights_tz, g);
     auto dst_tz = paddle::framework::vectorize(output->dims());
     PADDLE_ENFORCE_EQ(
...
...
@@ -562,9 +646,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
             ((g) == 1) ? MKLDNNMemoryFormat::oihw : MKLDNNMemoryFormat::goihw);
     /* create memory descriptor for convolution without specified format
      * ('any') which lets a primitive (convolution in this case) choose
      * the memory format preferred for best performance
      */
     auto chosen_memory_format = MKLDNNMemoryFormat::any;
     std::vector<int64_t> bias_tz;
...
...
@@ -823,7 +907,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto weights_tz = paddle::framework::vectorize(filter->dims());
     int g = std::max(groups, 1);
-    GetWeightsTz(weights_tz, g, is_conv3d);
+    GetWeightsTz(weights_tz, g);
     auto dst_tz = paddle::framework::vectorize(output_grad->dims());
     auto src_format = input->format();
...
...
@@ -836,7 +920,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const std::string key = platform::CreateKey(
         src_tz, ctx.InputName("Input") + ctx.InputName("Filter"));
-    const std::string key_conv_pd = key + "@conv_pd";
+    const std::string key_conv_pd = key + "@forward_pd";
     std::vector<primitive> pipeline;

     // Create user memory descriptors
...
...
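The "@conv_pd" to "@forward_pd" rename in ConvMKLDNNGradOpKernel matters because the grad op now looks up the forward primitive descriptor under the same suffix the refactored MKLDNNHandlerT uses when it caches one. A rough sketch of the lookup side, assuming the blob cache exposed by MKLDNNDeviceContext (variable names here are illustrative):

// Grad op fetching the forward PD cached by the handler under the
// shared "@forward_pd" suffix (sketch; error handling omitted):
const std::string key_conv_pd = key + "@forward_pd";
auto conv_pd = std::static_pointer_cast<
    mkldnn::convolution_forward::primitive_desc>(
    dev_ctx.GetBlob(key_conv_pd));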
paddle/fluid/platform/mkldnn_reuse.h
...
...
@@ -108,8 +108,20 @@ class MKLDNNHandlerT {
   }

  protected:
-  template <typename... Args>
-  void AcquireForwardPrimitiveDescriptor(Args&&... args) {
+  bool isCached() {
+    const std::string key_pd = key_common_ + "@forward_pd";
+    fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
+        dev_ctx_.GetBlob(key_pd));
+
+    const std::string key_p = key_ + "@forward_p";
+    return (dev_ctx_.GetBlob(key_p) != nullptr);
+  }
+
+  // If your primitive descriptor requires attributes, pass them as a
+  // first argument and paramters to descriptor constructor in the following
+  // arguments. Otherwise, all arguments will be forwarded to descriptor
+  // constructor, including the first one.
+  template <typename Arg, typename... Args>
+  void AcquireForwardPrimitiveDescriptor(Arg&& first_arg, Args&&... args) {
     // Forward PD has to be passed to Grad op that
     // may be executed by diffrent thread, hence
     // for that one we use key that does not contain TID
...
...
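The new isCached() is what lets ConvMKLDNNHandlerT's constructor skip all attribute parsing and descriptor setup on a cache hit: it re-loads fwd_pd_ from the thread-shared key as a side effect, then tests whether the forward primitive itself was stored under the per-thread key. The core check, reduced to a standalone sketch with an ordinary map standing in for MKLDNNDeviceContext's blob store:

#include <memory>
#include <string>
#include <unordered_map>

using BlobMap = std::unordered_map<std::string, std::shared_ptr<void>>;

// Mirrors MKLDNNHandlerT::isCached(): reload the forward PD (may be null);
// cached-ness is decided by the presence of the forward primitive itself.
bool IsCached(const BlobMap& blobs, const std::string& key_common,
              const std::string& key, std::shared_ptr<void>* fwd_pd_out) {
  auto it_pd = blobs.find(key_common + "@forward_pd");
  *fwd_pd_out = (it_pd == blobs.end()) ? nullptr : it_pd->second;
  return blobs.count(key + "@forward_p") > 0;
}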
@@ -123,14 +135,34 @@ class MKLDNNHandlerT {
       fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
           dev_ctx_.GetBlob(key_pd));
       if (fwd_pd_ == nullptr) {
-        auto fwd_desc = typename TForward::desc(std::forward<Args>(args)...);
-        fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
-            fwd_desc, engine_);
+        CreateForwardPrimitiveDescriptor(first_arg,
+                                         std::forward<Args>(args)...);
         dev_ctx_.SetBlob(key_pd, fwd_pd_);
       }
     }
   }

+  // Using sfinae to specialise variadic function. Workaround for not having
+  // if constexpr in C++ 11.
+  template <class First, class... Args>
+  typename std::enable_if<std::is_same<typename std::decay<First>::type,
+                                       dnnl::primitive_attr>::value>::type
+  CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
+    auto fwd_desc = typename TForward::desc(std::forward<Args>(args)...);
+    fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
+        fwd_desc, first, engine_);
+  }
+
+  template <class First, class... Args>
+  typename std::enable_if<!std::is_same<typename std::decay<First>::type,
+                                        dnnl::primitive_attr>::value>::type
+  CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
+    auto fwd_desc = typename TForward::desc(std::forward<First>(first),
+                                            std::forward<Args>(args)...);
+    fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(fwd_desc,
+                                                                  engine_);
+  }
+
   template <typename... Args>
   void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
     const std::string key_fwd_pd = key_common_ + "@forward_pd";
...
...
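The pair of CreateForwardPrimitiveDescriptor overloads above uses enable_if on the decayed type of the first argument, so a leading dnnl::primitive_attr is peeled off and routed to the attr-taking primitive_desc constructor; as the comment notes, this is a C++11 stand-in for if constexpr. The same dispatch in a self-contained sketch (Attr is a placeholder for dnnl::primitive_attr):

#include <iostream>
#include <type_traits>
#include <utility>

struct Attr {};  // stand-in for dnnl::primitive_attr

// Overload chosen when the first argument decays to Attr: peel it off so it
// can be passed separately (as the attr constructor argument).
template <class First, class... Args>
typename std::enable_if<
    std::is_same<typename std::decay<First>::type, Attr>::value>::type
Create(First&&, Args&&... args) {
  std::cout << "attr overload, " << sizeof...(args) << " forwarded args\n";
}

// Overload chosen otherwise: every argument is forwarded unchanged.
template <class First, class... Args>
typename std::enable_if<
    !std::is_same<typename std::decay<First>::type, Attr>::value>::type
Create(First&&, Args&&... args) {
  std::cout << "plain overload, " << sizeof...(args) << " forwarded args\n";
}

int main() {
  Create(Attr{}, 1, 2);  // picks the attr overload
  Create(1, 2, 3);       // picks the plain overload
  return 0;
}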
@@ -162,6 +194,91 @@ class MKLDNNHandlerT {
     return mem_p;
   }

+  std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
+      mkldnn::memory::desc md, const std::string& suffix) {
+    const auto local_key = key_ + suffix;
+    auto mem_p =
+        std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
+    if (mem_p == nullptr) {
+      mem_p = std::make_shared<mkldnn::memory>(md, engine_);
+      dev_ctx_.SetBlob(local_key, mem_p);
+    }
+    return mem_p;
+  }
+
+  void AcquireReorder(const std::shared_ptr<mkldnn::memory>& user_memory_p,
+                      const std::shared_ptr<mkldnn::memory>& target_memory_p,
+                      const std::string& suffix) {
+    const auto key_reorder_p = key_ + suffix + "reorder_p";
+
+    auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
+        dev_ctx_.GetBlob(key_reorder_p));
+    if (reorder_p == nullptr) {
+      reorder_p =
+          std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
+      dev_ctx_.SetBlob(key_reorder_p, reorder_p);
+    }
+
+    mkldnn::stream astream(engine_);
+    reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
+                                 {MKLDNN_ARG_TO, *target_memory_p}});
+    astream.wait();
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireMemoryWithReorder(
+      const mkldnn::memory::desc& user_md,
+      const mkldnn::memory::desc& target_md, void* ptr,
+      const std::string& suffix, bool is_persistent = false) {
+    const auto target_key = key_ + suffix + "_target";
+    const auto key_reorder_p = key_ + suffix + "reorder_p";
+    const auto user_key = key_ + suffix + "_user";
+
+    auto target_memory_p =
+        std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(target_key));
+
+    if (target_memory_p == nullptr) {
+      auto user_memory_p =
+          std::make_shared<dnnl::memory>(user_md, engine_, ptr);
+      if (user_md != target_md) {
+        target_memory_p = std::make_shared<mkldnn::memory>(target_md, engine_);
+        auto reorder_p =
+            std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
+        dev_ctx_.SetBlob(key_reorder_p, reorder_p);
+
+        mkldnn::stream astream(engine_);
+        reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
+                                     {MKLDNN_ARG_TO, *target_memory_p}});
+        astream.wait();
+      } else {
+        target_memory_p = user_memory_p;
+      }
+      dev_ctx_.SetBlob(user_key, user_memory_p);
+      dev_ctx_.SetBlob(target_key, target_memory_p);
+    } else if (!is_persistent) {
+      mkldnn::stream astream(engine_);
+
+      auto user_memory_p =
+          std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(user_key));
+      user_memory_p->set_data_handle(ptr);
+
+      auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
+          dev_ctx_.GetBlob(key_reorder_p));
+      if (reorder_p != nullptr) {
+        reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
+                                     {MKLDNN_ARG_TO, *target_memory_p}});
+        astream.wait();
+      }
+    }
+    return target_memory_p;
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireMemory(const std::string& suffix) {
+    const auto local_key = key_ + suffix;
+    return std::static_pointer_cast<mkldnn::memory>(
+        dev_ctx_.GetBlob(local_key));
+  }
+
   const MKLDNNDeviceContext& dev_ctx_;
   mkldnn::engine engine_;
   platform::Place place_;
...
...
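AcquireMemoryWithReorder above is a caching wrapper around oneDNN's reorder primitive: on first use it builds the user memory, the target memory, and the reorder, and stores all three in the blob cache; on later calls it only swaps in the new data handle and re-executes the cached reorder unless the memory is persistent. The underlying reorder step, stripped of caching, looks roughly like this (oneDNN 1.x API; the 2x2 nc-to-cn example is illustrative):

#include <vector>
#include "dnnl.hpp"

// Reorders a 2x2 fp32 buffer from nc to cn layout -- the same
// user-memory -> target-memory step the handler caches per key.
int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  std::vector<float> src_buf = {1.f, 2.f, 3.f, 4.f};
  auto src_md = dnnl::memory::desc({2, 2}, dnnl::memory::data_type::f32,
                                   dnnl::memory::format_tag::nc);
  auto dst_md = dnnl::memory::desc({2, 2}, dnnl::memory::data_type::f32,
                                   dnnl::memory::format_tag::cn);
  dnnl::memory src(src_md, eng, src_buf.data());
  dnnl::memory dst(dst_md, eng);  // library-allocated target buffer

  dnnl::reorder(src, dst).execute(
      strm, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
  strm.wait();
  return 0;
}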