Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
60647c9a
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
60647c9a
编写于
6月 22, 2018
作者:
T
Tao Luo
提交者:
GitHub
6月 22, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #11519 from jczaja/prv-softmax-mkldnn-grad-operator
MKLDNN: SoftmaxGrad Op
上级
3d1afe2e
98f3ad3b
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
318 addition
and
55 deletion
+318
-55
cmake/external/mkldnn.cmake
cmake/external/mkldnn.cmake
+1
-1
paddle/fluid/operators/softmax_mkldnn_op.cc
paddle/fluid/operators/softmax_mkldnn_op.cc
+167
-50
paddle/fluid/operators/softmax_op.cc
paddle/fluid/operators/softmax_op.cc
+18
-4
paddle/fluid/platform/mkldnn_helper.h
paddle/fluid/platform/mkldnn_helper.h
+132
-0
未找到文件。
cmake/external/mkldnn.cmake
浏览文件 @
60647c9a
...
@@ -54,7 +54,7 @@ ExternalProject_Add(
...
@@ -54,7 +54,7 @@ ExternalProject_Add(
${
EXTERNAL_PROJECT_LOG_ARGS
}
${
EXTERNAL_PROJECT_LOG_ARGS
}
DEPENDS
${
MKLDNN_DEPENDS
}
DEPENDS
${
MKLDNN_DEPENDS
}
GIT_REPOSITORY
"https://github.com/01org/mkl-dnn.git"
GIT_REPOSITORY
"https://github.com/01org/mkl-dnn.git"
GIT_TAG
"
db3424ad44901513c03a1ea31ccaacdf633fbe9f
"
GIT_TAG
"
a29d8487a63afca3d5b8c5bbdbb473cf8ccc6e51
"
PREFIX
${
MKLDNN_SOURCES_DIR
}
PREFIX
${
MKLDNN_SOURCES_DIR
}
UPDATE_COMMAND
""
UPDATE_COMMAND
""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=
${
MKLDNN_INSTALL_DIR
}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=
${
MKLDNN_INSTALL_DIR
}
...
...
paddle/fluid/operators/softmax_mkldnn_op.cc
浏览文件 @
60647c9a
...
@@ -27,8 +27,81 @@ using paddle::platform::MKLDNNMemDesc;
...
@@ -27,8 +27,81 @@ using paddle::platform::MKLDNNMemDesc;
using
mkldnn
::
memory
;
// Note: paddle has also "memory" namespace
using
mkldnn
::
memory
;
// Note: paddle has also "memory" namespace
using
mkldnn
::
primitive
;
using
mkldnn
::
primitive
;
using
mkldnn
::
softmax_forward
;
using
mkldnn
::
softmax_forward
;
using
mkldnn
::
softmax_backward
;
using
mkldnn
::
prop_kind
;
using
mkldnn
::
prop_kind
;
using
mkldnn
::
stream
;
using
mkldnn
::
stream
;
using
platform
::
to_void_cast
;
class
SoftmaxMKLDNNHandler
:
public
platform
::
MKLDNNHandler
{
public:
SoftmaxMKLDNNHandler
(
std
::
shared_ptr
<
mkldnn
::
softmax_forward
::
primitive_desc
>
softmax_pd
,
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
const
std
::
string
&
base_key
)
:
platform
::
MKLDNNHandler
(
dev_ctx
,
engine
,
base_key
),
softmax_pd_
(
softmax_pd
)
{}
SoftmaxMKLDNNHandler
(
std
::
shared_ptr
<
mkldnn
::
softmax_forward
::
primitive_desc
>
softmax_pd
,
std
::
shared_ptr
<
mkldnn
::
softmax_backward
::
primitive_desc
>
softmax_bwd_pd
,
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
const
std
::
string
&
base_key
)
:
platform
::
MKLDNNHandler
(
dev_ctx
,
engine
,
base_key
),
softmax_pd_
(
softmax_pd
),
softmax_bwd_pd_
(
softmax_bwd_pd
)
{
// If we are in Grad operatgor then update a key with BWD suffix to
// distinguish from FWD memory primitives
key_
+=
"-BWD"
;
}
std
::
shared_ptr
<
mkldnn
::
softmax_forward
>
AcquireSoftmax
(
std
::
shared_ptr
<
mkldnn
::
memory
>
dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
src_memory_p
)
{
/*Generate key*/
auto
prim_key
=
key_
+
"@softmax_p"
;
auto
softmax_p
=
std
::
static_pointer_cast
<
mkldnn
::
softmax_forward
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
PADDLE_ENFORCE
((
softmax_p
!=
nullptr
)
||
(
is_reusing_
==
false
),
"Fail to find softmax primitive in device context"
);
if
(
softmax_p
==
nullptr
)
{
softmax_p
=
std
::
make_shared
<
mkldnn
::
softmax_forward
>
(
*
(
softmax_pd_
.
get
()),
*
(
static_cast
<
mkldnn
::
memory
*>
(
src_memory_p
.
get
())),
*
(
static_cast
<
mkldnn
::
memory
*>
(
dst_memory_p
.
get
())));
dev_ctx_
.
SetBlob
(
prim_key
,
softmax_p
);
}
else
{
is_reusing_
=
true
;
}
return
softmax_p
;
}
std
::
shared_ptr
<
mkldnn
::
softmax_backward
>
AcquireSoftmaxBackward
(
std
::
shared_ptr
<
mkldnn
::
memory
>
dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
diff_dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
diff_src_memory_p
)
{
auto
prim_key
=
key_
+
"@softmax_bwd_p"
;
auto
softmax_bwd_p
=
std
::
static_pointer_cast
<
mkldnn
::
softmax_backward
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
PADDLE_ENFORCE
((
softmax_bwd_p
!=
nullptr
)
||
(
is_reusing_
==
false
),
"Fail to find softmax backward primitive in device context"
);
if
(
softmax_bwd_p
==
nullptr
)
{
softmax_bwd_p
=
std
::
make_shared
<
mkldnn
::
softmax_backward
>
(
*
softmax_bwd_pd_
,
*
(
dst_memory_p
.
get
()),
*
(
diff_dst_memory_p
.
get
()),
*
(
diff_src_memory_p
.
get
()));
dev_ctx_
.
SetBlob
(
prim_key
,
softmax_bwd_p
);
}
else
{
is_reusing_
=
true
;
}
return
softmax_bwd_p
;
}
private:
std
::
shared_ptr
<
mkldnn
::
softmax_forward
::
primitive_desc
>
softmax_pd_
;
std
::
shared_ptr
<
mkldnn
::
softmax_backward
::
primitive_desc
>
softmax_bwd_pd_
;
};
template
<
typename
T
>
template
<
typename
T
>
class
SoftmaxMKLDNNKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
class
SoftmaxMKLDNNKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
...
@@ -54,56 +127,27 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
...
@@ -54,56 +127,27 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
// Same memory descriptor to be used for input and output
// Same memory descriptor to be used for input and output
memory
::
dims
softmax_tz
=
{
src_tz
[
0
],
src_tz
[
1
]};
memory
::
dims
softmax_tz
=
{
src_tz
[
0
],
src_tz
[
1
]};
// Generate keys for storing/retriving primitives for this operator
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Each MKLDNN operator may have diffrent hashing function
const
std
::
string
key
=
auto
gethash
=
[](
memory
::
dims
&
operand_dims
)
{
platform
::
MKLDNNHandler
::
GetHash
(
softmax_tz
,
ctx
.
op
().
Output
(
"Out"
));
return
std
::
string
(
std
::
to_string
(
operand_dims
[
0
])
+
"-"
+
const
std
::
string
key_softmax_pd
=
key
+
"@softmax_pd"
;
std
::
to_string
(
operand_dims
[
1
]));
};
const
std
::
string
key
=
gethash
(
softmax_tz
);
const
std
::
string
key_softmax_p
=
key
+
"@softmax_p"
;
const
std
::
string
key_softmax_src_mem_p
=
key
+
"@softmax_src_mem_p"
;
const
std
::
string
key_softmax_dst_mem_p
=
key
+
"@softmax_dst_mem_p"
;
std
::
shared_ptr
<
void
>
softmax_p
=
dev_ctx
.
GetBlob
(
key_softmax_p
);
if
(
softmax_p
==
nullptr
)
{
// Currently only NC data format is supported
// Currently only NC data format is supported
auto
softmax_md
=
auto
softmax_md
=
MKLDNNMemDesc
(
MKLDNNMemDesc
({
softmax_tz
},
memory
::
f32
,
memory
::
format
::
nc
);
{
softmax_tz
},
platform
::
MKLDNNGetDataType
<
T
>
()
,
memory
::
format
::
nc
);
// Normalization is made after innermost dimension eg. C out of NC
// Normalization is made after innermost dimension eg. C out of NC
auto
softmax_desc
=
softmax_forward
::
desc
(
prop_kind
::
forward_scoring
,
auto
softmax_desc
=
softmax_forward
::
desc
(
prop_kind
::
forward_scoring
,
softmax_md
,
1
/*dim: C*/
);
softmax_md
,
1
/*dim: C*/
);
// create memory primitives
auto
softmax_pd
=
std
::
make_shared
<
mkldnn
::
softmax_forward
::
primitive_desc
>
(
auto
softmax_src_memory_p
=
std
::
make_shared
<
memory
>
(
softmax_desc
,
mkldnn_engine
);
memory
::
primitive_desc
{
softmax_md
,
mkldnn_engine
},
dev_ctx
.
SetBlob
(
key_softmax_pd
,
softmax_pd
);
static_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
dev_ctx
.
SetBlob
(
key_softmax_src_mem_p
,
softmax_src_memory_p
);
SoftmaxMKLDNNHandler
handler
(
softmax_pd
,
dev_ctx
,
mkldnn_engine
,
key
);
auto
softmax_dst_memory_p
=
std
::
make_shared
<
memory
>
(
auto
softmax_src_memory_p
=
memory
::
primitive_desc
{
softmax_md
,
mkldnn_engine
},
handler
.
AcquireSrcMemory
(
softmax_md
,
to_void_cast
<
T
>
(
input_data
));
static_cast
<
void
*>
(
output_data
));
auto
softmax_dst_memory_p
=
dev_ctx
.
SetBlob
(
key_softmax_dst_mem_p
,
softmax_dst_memory_p
);
handler
.
AcquireDstMemory
(
softmax_md
,
to_void_cast
<
T
>
(
output_data
));
auto
softmax_p
=
auto
softmax_forward_pd
=
handler
.
AcquireSoftmax
(
softmax_dst_memory_p
,
softmax_src_memory_p
);
std
::
make_shared
<
softmax_forward
::
primitive_desc
>
(
softmax_desc
,
mkldnn_engine
);
softmax_p
=
std
::
make_shared
<
softmax_forward
>
(
*
(
softmax_forward_pd
.
get
()),
*
(
static_cast
<
memory
*>
(
softmax_src_memory_p
.
get
())),
*
(
static_cast
<
memory
*>
(
softmax_dst_memory_p
.
get
())));
dev_ctx
.
SetBlob
(
key_softmax_p
,
softmax_p
);
}
else
{
// Primitives already exist
auto
src_memory_p
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_softmax_src_mem_p
));
PADDLE_ENFORCE
(
src_memory_p
!=
nullptr
,
"Fail to find softmax src mem_p in device context"
);
auto
dst_memory_p
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_softmax_dst_mem_p
));
PADDLE_ENFORCE
(
dst_memory_p
!=
nullptr
,
"Fail to find softmax dst mem_p in device context"
);
src_memory_p
->
set_data_handle
(
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
dst_memory_p
->
set_data_handle
(
output_data
);
}
std
::
vector
<
primitive
>
pipeline
{
std
::
vector
<
primitive
>
pipeline
{
*
(
static_cast
<
softmax_forward
::
primitive
*>
(
softmax_p
.
get
()))};
*
(
static_cast
<
softmax_forward
::
primitive
*>
(
softmax_p
.
get
()))};
...
@@ -120,6 +164,77 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
...
@@ -120,6 +164,77 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
}
}
};
};
template
<
typename
T
>
class
SoftmaxMKLDNNGradKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
paddle
::
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"It must use CPUPlace."
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
auto
mkldnn_engine
=
dev_ctx
.
GetEngine
();
const
Tensor
*
output
=
ctx
.
Input
<
Tensor
>
(
"Out"
);
const
T
*
dst_data
=
output
->
data
<
T
>
();
auto
*
dout
=
ctx
.
template
Input
<
Tensor
>(
framework
::
GradVarName
(
"Out"
));
const
auto
*
diff_dst_ptr
=
dout
->
template
data
<
T
>();
auto
*
dx
=
ctx
.
template
Output
<
framework
::
Tensor
>(
framework
::
GradVarName
(
"X"
));
T
*
diff_src_ptr
=
dx
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
std
::
vector
<
int
>
dst_tz
=
paddle
::
framework
::
vectorize2int
(
output
->
dims
());
std
::
vector
<
int
>
src_tz
(
dst_tz
);
PADDLE_ENFORCE
(
output
->
dims
().
size
()
==
2UL
,
"The input of softmax op must be a 2D matrix."
);
// MKL-DNN does support softmax over selected axis. Having 2D Tensor,
// we will make normalization after final eg. axis: 1
PADDLE_ENFORCE
(((
src_tz
[
0
]
==
dst_tz
[
0
])
&&
(
src_tz
[
1
]
==
dst_tz
[
1
])),
"Softmax input and output dimensions should match"
);
// Same memory descriptor to be used for input and output
memory
::
dims
softmax_tz
=
{
src_tz
[
0
],
src_tz
[
1
]};
// Currently only supports NC data format
// retrieve eltwise primitive desc from device context
const
std
::
string
key
=
platform
::
MKLDNNHandler
::
GetHash
(
softmax_tz
,
ctx
.
op
().
Input
(
"Out"
));
const
std
::
string
key_softmax_pd
=
key
+
"@softmax_pd"
;
auto
softmax_pd
=
std
::
static_pointer_cast
<
mkldnn
::
softmax_forward
::
primitive_desc
>
(
dev_ctx
.
GetBlob
(
key_softmax_pd
));
PADDLE_ENFORCE
(
softmax_pd
!=
nullptr
,
"Fail to find softmax_pd in device context"
);
// TODO(jczaja): Add layouts support when there is a need to do so
// Two dimensional softmax does support NC format
auto
data_softmax_md
=
MKLDNNMemDesc
(
{
softmax_tz
},
platform
::
MKLDNNGetDataType
<
T
>
(),
memory
::
format
::
nc
);
auto
diff_softmax_md
=
MKLDNNMemDesc
(
{
softmax_tz
},
platform
::
MKLDNNGetDataType
<
T
>
(),
memory
::
format
::
nc
);
// Normalization is made after innermost dimension eg. C out of NC
auto
softmax_bwd_desc
=
softmax_backward
::
desc
(
diff_softmax_md
,
data_softmax_md
,
1
/* dim: C*/
);
auto
softmax_bwd_pd
=
std
::
make_shared
<
mkldnn
::
softmax_backward
::
primitive_desc
>
(
softmax_bwd_desc
,
mkldnn_engine
,
*
softmax_pd
);
SoftmaxMKLDNNHandler
handler
(
softmax_pd
,
softmax_bwd_pd
,
dev_ctx
,
mkldnn_engine
,
key
);
auto
dst_memory_p
=
handler
.
AcquireDstMemory
(
data_softmax_md
,
to_void_cast
<
T
>
(
dst_data
));
auto
diff_dst_memory_p
=
handler
.
AcquireDiffDstMemory
(
diff_softmax_md
,
to_void_cast
<
T
>
(
diff_dst_ptr
));
auto
diff_src_memory_p
=
handler
.
AcquireDiffSrcMemory
(
diff_softmax_md
,
to_void_cast
<
T
>
(
diff_src_ptr
));
// Get primitve from device context
auto
softmax_bwd_p
=
handler
.
AcquireSoftmaxBackward
(
dst_memory_p
,
diff_dst_memory_p
,
diff_src_memory_p
);
std
::
vector
<
primitive
>
pipeline
{
*
softmax_bwd_p
};
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
@@ -127,3 +242,5 @@ namespace ops = paddle::operators;
...
@@ -127,3 +242,5 @@ namespace ops = paddle::operators;
REGISTER_OP_KERNEL
(
softmax
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
REGISTER_OP_KERNEL
(
softmax
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
ops
::
SoftmaxMKLDNNKernel
<
float
>
);
ops
::
SoftmaxMKLDNNKernel
<
float
>
);
REGISTER_OP_KERNEL
(
softmax_grad
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
ops
::
SoftmaxMKLDNNGradKernel
<
float
>
);
paddle/fluid/operators/softmax_op.cc
浏览文件 @
60647c9a
...
@@ -145,16 +145,30 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
...
@@ -145,16 +145,30 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
// choose cudnn kernel if the runtime supported.
// choose cudnn kernel if the runtime supported.
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
}
#endif
#endif
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
#ifdef PADDLE_WITH_MKLDNN
return
framework
::
OpKernelType
(
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
(),
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
framework
::
StringToDataLayout
(
data_format
),
library_
);
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
}
#endif
auto
input_data_type
=
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
if
(
input_data_type
==
framework
::
proto
::
VarType
::
FP16
)
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"float16 can only be used on GPU place"
);
}
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
(),
layout_
,
library_
);
}
}
};
};
...
...
paddle/fluid/platform/mkldnn_helper.h
浏览文件 @
60647c9a
...
@@ -105,5 +105,137 @@ inline mkldnn::memory::format GetMKLDNNFormat(
...
@@ -105,5 +105,137 @@ inline mkldnn::memory::format GetMKLDNNFormat(
memory
.
dst_primitive_desc
().
desc
().
data
.
format
);
memory
.
dst_primitive_desc
().
desc
().
data
.
format
);
}
}
class
MKLDNNHandler
{
public:
MKLDNNHandler
(
const
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
const
std
::
string
&
base_key
)
:
dev_ctx_
(
dev_ctx
),
engine_
(
engine
),
key_
(
base_key
),
is_reusing_
(
false
)
{}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireSrcMemory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
)
{
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_src_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireWeightsMemory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
)
{
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_weights_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDstMemory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
)
{
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_dst_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffDstMemory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
)
{
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_diff_dst_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffSrcMemory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
)
{
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_diff_src_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMemoryFromPrimitive
(
mkldnn
::
memory
::
primitive_desc
mdp
,
void
*
ptr
,
const
std
::
string
&
suffix
)
{
auto
local_key
=
key_
+
suffix
;
auto
mem_p
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx_
.
GetBlob
(
local_key
));
PADDLE_ENFORCE
((
mem_p
!=
nullptr
)
||
(
is_reusing_
==
false
),
"Fail to find mem primitive in device context"
);
if
(
mem_p
==
nullptr
)
{
mem_p
=
std
::
make_shared
<
mkldnn
::
memory
>
(
mdp
,
ptr
);
dev_ctx_
.
SetBlob
(
local_key
,
mem_p
);
}
else
{
mem_p
->
set_data_handle
(
ptr
);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_
=
true
;
}
return
mem_p
;
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMemory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
,
const
std
::
string
&
suffix
)
{
/*Generate key*/
auto
local_key
=
key_
+
suffix
;
auto
mem_p
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx_
.
GetBlob
(
local_key
));
PADDLE_ENFORCE
((
mem_p
!=
nullptr
)
||
(
is_reusing_
==
false
),
"Fail to find mem primitive in device context"
);
if
(
mem_p
==
nullptr
)
{
mem_p
=
std
::
make_shared
<
mkldnn
::
memory
>
(
mkldnn
::
memory
::
primitive_desc
{
md
,
engine_
},
ptr
);
dev_ctx_
.
SetBlob
(
local_key
,
mem_p
);
}
else
{
mem_p
->
set_data_handle
(
ptr
);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_
=
true
;
}
return
mem_p
;
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMemory
(
mkldnn
::
memory
::
primitive_desc
&
mpd
,
mkldnn
::
memory
::
primitive_desc
&
user_mpd
,
const
std
::
shared_ptr
<
mkldnn
::
memory
>
user_memory_p
,
const
std
::
string
&
suffix
,
std
::
vector
<
mkldnn
::
primitive
>&
pipeline
)
{
// create reorder primitive if the input format is not the preferred one
auto
local_key
=
key_
+
suffix
;
auto
key_reorder_p
=
key_
+
suffix
+
"reorder_p"
;
auto
target_memory_p
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx_
.
GetBlob
(
local_key
));
PADDLE_ENFORCE
((
target_memory_p
!=
nullptr
)
||
(
is_reusing_
==
false
),
"Fail to find mem primitive in device context"
);
if
(
target_memory_p
==
nullptr
)
{
target_memory_p
=
user_memory_p
;
std
::
shared_ptr
<
mkldnn
::
primitive
>
reorder_p
;
if
(
mpd
!=
user_mpd
)
{
target_memory_p
=
std
::
make_shared
<
mkldnn
::
memory
>
(
mpd
);
auto
reorder_p
=
std
::
make_shared
<
mkldnn
::
reorder
>
(
*
user_memory_p
,
*
target_memory_p
);
dev_ctx_
.
SetBlob
(
key_reorder_p
,
reorder_p
);
pipeline
.
push_back
(
*
reorder_p
);
}
dev_ctx_
.
SetBlob
(
local_key
,
target_memory_p
);
}
else
{
// Make reorder if needed
auto
reorder_p
=
std
::
static_pointer_cast
<
mkldnn
::
reorder
>
(
dev_ctx_
.
GetBlob
(
key_reorder_p
));
if
(
reorder_p
!=
nullptr
)
{
pipeline
.
push_back
(
*
reorder_p
);
}
is_reusing_
=
true
;
}
return
target_memory_p
;
}
static
std
::
string
GetHash
(
mkldnn
::
memory
::
dims
&
operand_dims
,
const
std
::
string
&
suffix
)
{
auto
dims2str
=
[](
const
mkldnn
::
memory
::
dims
&
operand_dims
)
{
std
::
string
dstr
=
""
;
for
(
size_t
i
=
0
;
i
<
operand_dims
.
size
();
++
i
)
{
dstr
+=
std
::
to_string
(
operand_dims
[
i
])
+
"-"
;
}
return
dstr
;
};
return
dims2str
(
operand_dims
)
+
suffix
;
};
protected:
const
MKLDNNDeviceContext
&
dev_ctx_
;
mkldnn
::
engine
engine_
;
std
::
string
key_
;
bool
is_reusing_
;
};
}
// namespace platform
}
// namespace platform
}
// namespace paddle
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录