Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
792d3b24
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
792d3b24
编写于
6月 11, 2018
作者:
M
mozga-intel
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MKLDNN layout: Support for activation operator
上级
d7345959
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
212 addition
and
133 deletion
+212
-133
paddle/fluid/operators/activation_mkldnn_op.cc
paddle/fluid/operators/activation_mkldnn_op.cc
+196
-120
paddle/fluid/operators/activation_op.cc
paddle/fluid/operators/activation_op.cc
+16
-13
未找到文件。
paddle/fluid/operators/activation_mkldnn_op.cc
浏览文件 @
792d3b24
...
@@ -12,16 +12,20 @@
...
@@ -12,16 +12,20 @@
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mkldnn_activation_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
paddle
::
framework
::
Tensor
;
using
framework
::
DataLayout
;
using
paddle
::
platform
::
MKLDNNDeviceContext
;
using
framework
::
Tensor
;
using
mkldnn
::
memory
;
using
mkldnn
::
primitive
;
using
mkldnn
::
stream
;
using
platform
::
GetMKLDNNFormat
;
using
platform
::
MKLDNNDeviceContext
;
using
platform
::
to_void_cast
;
namespace
{
namespace
{
std
::
string
gethash
(
const
mkldnn
::
memory
::
dims
&
operand_dims
,
std
::
string
gethash
(
const
mkldnn
::
memory
::
dims
&
operand_dims
,
...
@@ -35,188 +39,260 @@ std::string gethash(const mkldnn::memory::dims &operand_dims,
...
@@ -35,188 +39,260 @@ std::string gethash(const mkldnn::memory::dims &operand_dims,
};
};
return
dim2str
(
operand_dims
)
+
std
::
to_string
(
algorithm
);
return
dim2str
(
operand_dims
)
+
std
::
to_string
(
algorithm
);
}
}
}
// namespace
template
<
typename
Functor
>
class
MKLDNNActivationKernel
:
public
framework
::
OpKernel
<
typename
Functor
::
ELEMENT_TYPE
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
PADDLE_ENFORCE
(
x
->
layout
()
==
DataLayout
::
kMKLDNN
&&
x
->
format
()
!=
memory
::
format
::
format_undef
,
"Wrong layout/format set for Input x tensor"
);
Functor
functor
;
auto
attrs
=
functor
.
GetAttrs
();
for
(
auto
&
attr
:
attrs
)
{
*
attr
.
second
=
ctx
.
Attr
<
float
>
(
attr
.
first
);
}
functor
(
ctx
);
}
};
template
<
typename
T
,
typename
ExecContext
>
template
<
typename
Functor
>
void
eltwise_forward
(
const
ExecContext
&
ctx
,
mkldnn
::
algorithm
algorithm
,
class
MKLDNNActivationGradKernel
const
T
alpha
=
0
,
const
T
beta
=
0
)
{
:
public
framework
::
OpKernel
<
typename
Functor
::
ELEMENT_TYPE
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
diff_y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
PADDLE_ENFORCE
(
diff_y
->
layout
()
==
DataLayout
::
kMKLDNN
&&
diff_y
->
format
()
!=
memory
::
format
::
format_undef
,
"Wrong layout/format set for Input OutGrad tensor"
);
Functor
functor
;
auto
attrs
=
functor
.
GetAttrs
();
for
(
auto
&
attr
:
attrs
)
{
*
attr
.
second
=
ctx
.
Attr
<
float
>
(
attr
.
first
);
}
functor
(
ctx
);
}
};
template
<
typename
T
>
void
eltwise_forward
(
const
framework
::
ExecutionContext
&
ctx
,
mkldnn
::
algorithm
algorithm
,
const
T
alpha
=
0
,
const
T
beta
=
0
)
{
PADDLE_ENFORCE
(
paddle
::
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE
(
paddle
::
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"It must use CPUPlace."
);
"It must use CPUPlace."
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
const
auto
&
mkldnn_engine
=
dev_ctx
.
GetEngine
();
const
auto
&
mkldnn_engine
=
dev_ctx
.
GetEngine
();
// get buffers
const
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
auto
*
src
=
ctx
.
template
Input
<
Tensor
>(
"X"
);
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
const
auto
*
src_data
=
src
->
template
data
<
T
>();
auto
*
dst
=
ctx
.
template
Output
<
Tensor
>(
"Out"
);
const
T
*
x_data
=
x
->
data
<
T
>
(
);
T
*
dst_data
=
dst
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
T
*
y_data
=
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// get memory dim
PADDLE_ENFORCE
(
x
->
dims
().
size
()
==
2
||
x
->
dims
().
size
()
==
4
,
PADDLE_ENFORCE
(
src
->
dims
().
size
()
==
2
||
src
->
dims
().
size
()
==
4
,
"Input dim must be with 2 or 4"
);
"Input dim must be with 2 or 4"
);
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
src
->
dims
());
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
x
->
dims
());
auto
src_format
=
src_tz
.
size
()
==
2
?
mkldnn
::
memory
::
format
::
nc
:
x
->
format
();
const
std
::
string
key
=
gethash
(
src_tz
,
algorithm
);
const
std
::
string
key
=
gethash
(
src_tz
,
algorithm
);
const
std
::
string
key_src_data
=
const
std
::
string
key_src_data
=
key
+
ctx
.
op
().
Output
(
"Out"
)
+
"@eltwise_fwd_src_data"
;
key
+
ctx
.
op
().
Output
(
"Out"
)
+
"@eltwise_fwd_src_data"
;
const
std
::
string
key_src_mem
=
key
+
"@eltwise_fwd_src_mem"
;
const
std
::
string
key_src_layout
=
const
std
::
string
key_dst_mem
=
key
+
"@eltwise_fwd_dst_mem"
;
key
+
ctx
.
op
().
Output
(
"Out"
)
+
"@eltwise_fwd_src_layout"
;
const
std
::
string
key_fwd
=
key
+
"@eltwise_fwd"
;
const
std
::
string
key_with_layout
=
key
+
std
::
to_string
(
src_format
);
const
std
::
string
key_src_mem
=
key_with_layout
+
"@eltwise_fwd_src_mem"
;
const
std
::
string
key_dst_mem
=
key_with_layout
+
"@eltwise_fwd_dst_mem"
;
const
std
::
string
key_fwd
=
key_with_layout
+
"@eltwise_fwd"
;
const
std
::
string
key_fwd_pd
=
key_with_layout
+
"@eltwise_fwd_pd"
;
// save input data and layout to be referred in backward path
auto
p_src_data
=
std
::
make_shared
<
const
T
*>
(
x_data
);
dev_ctx
.
SetBlob
(
key_src_data
,
p_src_data
);
auto
p_src_layout
=
std
::
make_shared
<
memory
::
format
>
(
src_format
);
dev_ctx
.
SetBlob
(
key_src_layout
,
p_src_layout
);
auto
p_fwd
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
>
(
auto
p_fwd
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
>
(
dev_ctx
.
GetBlob
(
key_fwd
));
dev_ctx
.
GetBlob
(
key_fwd
));
// save input data to be referred in backward path
std
::
shared_ptr
<
memory
>
dst_memory
;
auto
p_src_data
=
std
::
make_shared
<
const
T
*>
(
src_data
);
dev_ctx
.
SetBlob
(
key_src_data
,
p_src_data
);
if
(
p_fwd
==
nullptr
)
{
if
(
p_fwd
==
nullptr
)
{
// create memory description
// create mkldnn memory for input X
auto
data_md
=
src_tz
.
size
()
==
2
auto
src_md
=
platform
::
MKLDNNMemDesc
(
?
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
src_format
);
mkldnn
::
memory
::
format
::
nc
)
auto
src_memory
=
std
::
shared_ptr
<
memory
>
(
:
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
new
memory
({
src_md
,
mkldnn_engine
},
to_void_cast
(
x_data
)));
mkldnn
::
memory
::
format
::
nchw
);
// save src_memory to be referred in backward path
dev_ctx
.
SetBlob
(
key_src_mem
,
src_memory
);
// create memory primitives
auto
p_src_mem
=
std
::
make_shared
<
mkldnn
::
memory
>
(
mkldnn
::
memory
(
// create primitive descriptor for activation forward and save it
{
data_md
,
mkldnn_engine
},
platform
::
to_void_cast
(
src_data
)));
auto
forward_desc
=
mkldnn
::
eltwise_forward
::
desc
(
dev_ctx
.
SetBlob
(
key_src_mem
,
p_src_mem
);
mkldnn
::
prop_kind
::
forward_training
,
algorithm
,
src_memory
->
get_primitive_desc
().
desc
(),
alpha
,
beta
);
auto
p_dst_mem
=
std
::
make_shared
<
mkldnn
::
memory
>
(
mkldnn
::
memory
(
auto
forward_pd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
{
data_md
,
mkldnn_engine
},
platform
::
to_void_cast
(
dst_data
)));
forward_desc
,
mkldnn_engine
);
dev_ctx
.
SetBlob
(
key_dst_mem
,
p_dst_mem
);
// save prim desc into global device context to be referred in backward path
auto
fwd_desc
=
mkldnn
::
eltwise_forward
::
desc
(
dev_ctx
.
SetBlob
(
key_fwd_pd
,
forward_pd
);
mkldnn
::
prop_kind
::
forward_training
,
algorithm
,
data_md
,
alpha
,
beta
);
auto
p_fwd_pd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
// create mkldnn memory for output y
fwd_desc
,
mkldnn_engine
);
dst_memory
=
const
std
::
string
key_fwd_pd
=
key
+
"eltwise_fwd_pd"
;
std
::
make_shared
<
memory
>
(
forward_pd
->
dst_primitive_desc
(),
y_data
);
dev_ctx
.
SetBlob
(
key_fwd_pd
,
p_fwd_pd
);
p_fwd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
>
(
dev_ctx
.
SetBlob
(
key_dst_mem
,
dst_memory
);
*
p_fwd_pd
,
*
(
p_src_mem
.
get
()),
*
(
p_dst_mem
.
get
()));
// create activation primitive
p_fwd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
>
(
*
forward_pd
,
*
src_memory
,
*
dst_memory
);
dev_ctx
.
SetBlob
(
key_fwd
,
p_fwd
);
dev_ctx
.
SetBlob
(
key_fwd
,
p_fwd
);
}
else
{
}
else
{
// primitives already exist
// primitives already exist
auto
p_src_mem
=
auto
src_memory
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_src_mem
));
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_src_mem
));
PADDLE_ENFORCE
(
p_src_mem
!=
nullptr
,
PADDLE_ENFORCE
(
src_memory
!=
nullptr
,
"Fail to find eltwise
p_src_mem
in device context."
);
"Fail to find eltwise
src_memory
in device context."
);
auto
p_dst_mem
=
dst_memory
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_dst_mem
));
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_dst_mem
));
PADDLE_ENFORCE
(
p_dst_mem
!=
nullptr
,
PADDLE_ENFORCE
(
dst_memory
!=
nullptr
,
"Fail to find eltwise
p_src_mem
in device context."
);
"Fail to find eltwise
dst_memory
in device context."
);
p_src_mem
->
set_data_handle
(
platform
::
to_void_reinterpret_cast
(
src
_data
));
src_memory
->
set_data_handle
(
platform
::
to_void_cast
(
x
_data
));
p_dst_mem
->
set_data_handle
(
dst
_data
);
dst_memory
->
set_data_handle
(
y
_data
);
}
}
// push primitive to stream and wait until it's executed
// push primitive to stream and wait until it's executed
std
::
vector
<
mkldnn
::
primitive
>
pipeline
=
{
*
(
p_fwd
.
get
())};
std
::
vector
<
primitive
>
pipeline
;
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
pipeline
.
push_back
(
*
p_fwd
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
y
->
set_layout
(
DataLayout
::
kMKLDNN
);
y
->
set_format
(
GetMKLDNNFormat
(
*
dst_memory
));
}
}
template
<
typename
T
,
typename
ExecContext
>
template
<
typename
T
>
void
eltwise_grad
(
const
ExecContext
&
ctx
,
mkldnn
::
algorithm
algorithm
,
void
eltwise_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
T
alpha
=
0
,
const
T
beta
=
0
)
{
mkldnn
::
algorithm
algorithm
,
const
T
alpha
=
0
,
const
T
beta
=
0
)
{
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
const
auto
&
mkldnn_engine
=
dev_ctx
.
GetEngine
();
const
auto
&
mkldnn_engine
=
dev_ctx
.
GetEngine
();
// get buffers
const
auto
*
diff_y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
const
auto
*
out
=
ctx
.
template
Input
<
Tensor
>(
"Out"
);
auto
*
diff_x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dout
=
ctx
.
template
Input
<
Tensor
>(
framework
::
GradVarName
(
"Out"
));
const
auto
*
diff_dst
=
dout
->
template
data
<
T
>();
auto
*
dx
=
const
T
*
diff_y_data
=
diff_y
->
data
<
T
>
();
ctx
.
template
Output
<
framework
::
Tensor
>(
framework
::
GradVarName
(
"X"
));
T
*
diff_x_data
=
diff_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
diff_src
=
dx
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
// get memory dim
std
::
vector
<
int
>
diff_dst_tz
=
framework
::
vectorize2int
(
diff_y
->
dims
());
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
out
->
dims
());
const
std
::
string
key
=
gethash
(
src_tz
,
algorithm
);
auto
diff_y_format
=
const
std
::
string
key_diff_src_mem
=
key
+
"@eltwise_diff_src_mem"
;
diff_dst_tz
.
size
()
==
2
?
mkldnn
::
memory
::
format
::
nc
:
diff_y
->
format
();
const
std
::
string
key_diff_dst_mem
=
key
+
"@eltwise_diff_dst_mem"
;
const
std
::
string
key_grad
=
key
+
"@eltwise_grad"
;
const
std
::
string
key
=
gethash
(
diff_dst_tz
,
algorithm
);
const
std
::
string
key_src_data
=
const
std
::
string
key_src_data
=
key
+
ctx
.
op
().
Input
(
"Out"
)
+
"@eltwise_fwd_src_data"
;
key
+
ctx
.
op
().
Input
(
"Out"
)
+
"@eltwise_fwd_src_data"
;
const
std
::
string
key_src_layout
=
key
+
ctx
.
op
().
Input
(
"Out"
)
+
"@eltwise_fwd_src_layout"
;
const
auto
p_src_layout
=
std
::
static_pointer_cast
<
memory
::
format
>
(
dev_ctx
.
GetBlob
(
key_src_layout
));
const
std
::
string
key_src_mem
=
key
+
std
::
to_string
(
*
p_src_layout
)
+
"@eltwise_fwd_src_mem"
;
const
std
::
string
key_fwd_pd
=
key
+
std
::
to_string
(
*
p_src_layout
)
+
"@eltwise_fwd_pd"
;
const
std
::
string
key_with_layouts
=
key
+
std
::
to_string
(
*
p_src_layout
)
+
"-"
+
std
::
to_string
(
diff_y_format
);
const
std
::
string
key_diff_src_mem
=
key_with_layouts
+
"@eltwise_diff_src_mem"
;
const
std
::
string
key_diff_dst_mem
=
key_with_layouts
+
"@eltwise_diff_dst_mem"
;
const
std
::
string
key_grad
=
key_with_layouts
+
"@eltwise_grad"
;
const
auto
p_src_data
=
const
auto
p_src_data
=
std
::
static_pointer_cast
<
T
*>
(
dev_ctx
.
GetBlob
(
key_src_data
));
std
::
static_pointer_cast
<
T
*>
(
dev_ctx
.
GetBlob
(
key_src_data
));
const
std
::
string
key_src_mem
=
key
+
"@eltwise_fwd_src_mem"
;
auto
src_memory
=
auto
p_src_mem
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_src_mem
));
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_src_mem
));
p_src_mem
->
set_data_handle
(
*
p_src_data
.
get
());
PADDLE_ENFORCE
(
src_memory
!=
nullptr
,
"Fail to find src_memory in device context"
);
src_memory
->
set_data_handle
(
*
p_src_data
.
get
());
std
::
shared_ptr
<
memory
>
diff_src_memory
;
auto
p_grad
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_
forward
::
primitive
>
(
auto
p_grad
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_
backward
>
(
dev_ctx
.
GetBlob
(
key_grad
));
dev_ctx
.
GetBlob
(
key_grad
));
if
(
p_grad
==
nullptr
)
{
if
(
p_grad
==
nullptr
)
{
// create m
emory description
// create m
kldnn memory for input diff_y
auto
d
ata_md
=
src_tz
.
size
()
==
2
auto
d
iff_dst_md
=
platform
::
MKLDNNMemDesc
(
?
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
diff_dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_y_format
);
mkldnn
::
memory
::
format
::
nc
)
auto
diff_dst_memory
=
std
::
shared_ptr
<
memory
>
(
:
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
new
memory
({
diff_dst_md
,
mkldnn_engine
},
to_void_cast
(
diff_y_data
)));
mkldnn
::
memory
::
format
::
nchw
);
dev_ctx
.
SetBlob
(
key_diff_dst_mem
,
diff_dst_memory
);
//
create memory primitives
//
retrieve eltwise primitive desc from device context
std
::
shared_ptr
<
void
>
p_diff_src_mem
=
auto
forward_pd
=
std
::
make_shared
<
mkldnn
::
memory
>
(
mkldnn
::
memory
(
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
{
data_md
,
mkldnn_engine
},
platform
::
to_void_cast
(
diff_src
)
));
dev_ctx
.
GetBlob
(
key_fwd_pd
));
dev_ctx
.
SetBlob
(
key_diff_src_mem
,
p_diff_src_mem
);
PADDLE_ENFORCE
(
forward_pd
!=
nullptr
,
std
::
shared_ptr
<
void
>
p_diff_dst_mem
=
"Fail to find eltwise_fwd_pd in device context"
);
std
::
make_shared
<
mkldnn
::
memory
>
(
mkldnn
::
memory
(
{
data_md
,
mkldnn_engine
},
platform
::
to_void_cast
(
diff_dst
)));
// ceate primitive descriptor for activation backward
dev_ctx
.
SetBlob
(
key_diff_dst_mem
,
p_diff_dst_mem
);
auto
backward_desc
=
mkldnn
::
eltwise_backward
::
desc
(
algorithm
,
diff_dst_memory
->
get_primitive_desc
().
desc
(),
auto
bwd_desc
=
mkldnn
::
eltwise_backward
::
desc
(
algorithm
,
data_md
,
data_md
,
src_memory
->
get_primitive_desc
().
desc
(),
alpha
,
beta
);
alpha
,
beta
);
auto
backward_pd
=
mkldnn
::
eltwise_backward
::
primitive_desc
(
backward_desc
,
mkldnn_engine
,
*
forward_pd
);
const
std
::
string
key_fwd_pd
=
key
+
"eltwise_fwd_pd"
;
auto
*
p_fwd_pd
=
static_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
*>
(
// create mkldnn memory for output diff_src
dev_ctx
.
GetBlob
(
key_fwd_pd
).
get
());
diff_src_memory
=
std
::
make_shared
<
memory
>
(
backward_pd
.
diff_src_primitive_desc
(),
diff_x_data
);
auto
eltwise_bwd_prim_desc
=
mkldnn
::
eltwise_backward
::
primitive_desc
(
dev_ctx
.
SetBlob
(
key_diff_src_mem
,
diff_src_memory
);
bwd_desc
,
mkldnn_engine
,
*
p_fwd_pd
);
// create activation backward primitive
p_grad
=
std
::
make_shared
<
mkldnn
::
eltwise_backward
>
(
p_grad
=
std
::
make_shared
<
mkldnn
::
eltwise_backward
>
(
eltwise_bwd_prim_desc
,
*
static_cast
<
mkldnn
::
memory
*>
(
p_src_mem
.
get
()),
backward_pd
,
*
src_memory
,
*
diff_dst_memory
,
*
diff_src_memory
);
*
(
static_cast
<
mkldnn
::
memory
*>
(
p_diff_dst_mem
.
get
())),
dev_ctx
.
SetBlob
(
key_grad
,
p_grad
);
*
(
static_cast
<
mkldnn
::
memory
*>
(
p_diff_src_mem
.
get
())));
}
else
{
}
else
{
// primitives already exist
// primitives already exist
auto
p_diff_src_mem
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
diff_src_memory
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_diff_src_mem
));
dev_ctx
.
GetBlob
(
key_diff_src_mem
));
auto
p_diff_dst_mem
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
auto
diff_dst_memory
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_diff_dst_mem
));
dev_ctx
.
GetBlob
(
key_diff_dst_mem
));
p_diff_src_mem
->
set_data_handle
(
diff_src_memory
->
set_data_handle
(
platform
::
to_void_reinterpret_cast
(
diff_
src
));
platform
::
to_void_reinterpret_cast
(
diff_
x_data
));
p_diff_dst_mem
->
set_data_handle
(
diff_dst_memory
->
set_data_handle
(
platform
::
to_void_reinterpret_cast
(
diff_
dst
));
platform
::
to_void_reinterpret_cast
(
diff_
y_data
));
}
}
// push primitive to stream and wait until it's executed
// push primitive to stream and wait until it's executed
std
::
vector
<
mkldnn
::
primitive
>
pipeline
=
{
*
(
p_grad
.
get
())};
std
::
vector
<
primitive
>
pipeline
;
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
pipeline
.
push_back
(
*
p_grad
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
(
GetMKLDNNFormat
(
*
diff_src_memory
));
}
}
}
// anonymous namespace
template
<
typename
T
,
mkldnn
::
algorithm
algorithm
>
template
<
typename
T
,
mkldnn
::
algorithm
algorithm
>
struct
MKLDNNActivationFunc
:
public
BaseActivationFunctor
<
T
>
{
struct
MKLDNNActivationFunc
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
ExecContext
>
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
operator
()(
const
ExecContext
&
ctx
)
const
{
eltwise_forward
<
T
>
(
ctx
,
algorithm
);
eltwise_forward
<
T
>
(
ctx
,
algorithm
);
}
}
};
};
template
<
typename
T
,
mkldnn
::
algorithm
algorithm
>
template
<
typename
T
,
mkldnn
::
algorithm
algorithm
>
struct
MKLDNNActivationGradFunc
:
public
BaseActivationFunctor
<
T
>
{
struct
MKLDNNActivationGradFunc
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
ExecContext
>
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
operator
()(
const
ExecContext
&
ctx
)
const
{
eltwise_grad
<
T
>
(
ctx
,
algorithm
);
eltwise_grad
<
T
>
(
ctx
,
algorithm
);
}
}
};
};
...
...
paddle/fluid/operators/activation_op.cc
浏览文件 @
792d3b24
...
@@ -19,6 +19,8 @@ limitations under the License. */
...
@@ -19,6 +19,8 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
paddle
::
framework
::
Tensor
;
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker \
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
...
@@ -27,9 +29,9 @@ namespace operators {
...
@@ -27,9 +29,9 @@ namespace operators {
AddInput("X", "Input of " #OP_NAME " operator"); \
AddInput("X", "Input of " #OP_NAME " operator"); \
AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X"); \
AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X"); \
AddAttr<bool>("use_mkldnn", \
AddAttr<bool>("use_mkldnn", \
"(
default false) Only used in mkldnn kernel")
\
"(
bool, default false) Only used in mkldnn kernel")
\
.SetDefault(false); \
.SetDefault(false); \
AddComment(
OP_COMMENT);
\
AddComment(
#OP_COMMENT);
\
} \
} \
}
}
...
@@ -58,7 +60,6 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
...
@@ -58,7 +60,6 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const
framework
::
OperatorWithKernel
&
oper
,
const
framework
::
OperatorWithKernel
&
oper
,
const
std
::
string
&
name
)
{
const
std
::
string
&
name
)
{
framework
::
LibraryType
library
{
framework
::
LibraryType
::
kPlain
};
framework
::
LibraryType
library
{
framework
::
LibraryType
::
kPlain
};
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
auto
it
=
oper
.
Attrs
().
find
(
"use_mkldnn"
);
auto
it
=
oper
.
Attrs
().
find
(
"use_mkldnn"
);
...
@@ -82,6 +83,7 @@ class ActivationOp : public framework::OperatorWithKernel {
...
@@ -82,6 +83,7 @@ class ActivationOp : public framework::OperatorWithKernel {
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
GetKernelType
(
ctx
,
*
this
,
"X"
);
return
GetKernelType
(
ctx
,
*
this
,
"X"
);
...
@@ -96,6 +98,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
...
@@ -96,6 +98,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Out"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Out"
));
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
GetKernelType
(
ctx
,
*
this
,
"Out"
);
return
GetKernelType
(
ctx
,
*
this
,
"Out"
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录