Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
7205d331
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7205d331
编写于
5月 22, 2018
作者:
T
tensor-tang
提交者:
GitHub
5月 22, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10597 from kbinias/mkldnn-activations-improvments
Update activations for MKL-DNN
上级
2a77fc50
24904b91
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
172 addition
and
117 deletion
+172
-117
paddle/fluid/operators/activation_mkldnn_op.cc
paddle/fluid/operators/activation_mkldnn_op.cc
+135
-67
paddle/fluid/operators/activation_op.cc
paddle/fluid/operators/activation_op.cc
+30
-3
paddle/fluid/operators/mkldnn_activation_op.h
paddle/fluid/operators/mkldnn_activation_op.h
+2
-47
paddle/fluid/platform/mkldnn_helper.h
paddle/fluid/platform/mkldnn_helper.h
+5
-0
未找到文件。
paddle/fluid/operators/activation_mkldnn_op.cc
浏览文件 @
7205d331
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "mkldnn.hpp"
#include "mkldnn.hpp"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mkldnn_activation_op.h"
#include "paddle/fluid/operators/mkldnn_activation_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -23,6 +24,18 @@ using paddle::framework::Tensor;
...
@@ -23,6 +24,18 @@ using paddle::framework::Tensor;
using
paddle
::
platform
::
MKLDNNDeviceContext
;
using
paddle
::
platform
::
MKLDNNDeviceContext
;
namespace
{
namespace
{
std
::
string
gethash
(
const
mkldnn
::
memory
::
dims
&
operand_dims
,
const
mkldnn
::
algorithm
algorithm
)
{
auto
dim2str
=
[](
const
mkldnn
::
memory
::
dims
&
operand_dims
)
{
std
::
string
dstr
=
""
;
for
(
size_t
i
=
0
;
i
<
operand_dims
.
size
();
++
i
)
{
dstr
+=
std
::
to_string
(
operand_dims
[
i
])
+
"-"
;
}
return
dstr
;
};
return
dim2str
(
operand_dims
)
+
std
::
to_string
(
algorithm
);
}
template
<
typename
T
,
typename
ExecContext
>
template
<
typename
T
,
typename
ExecContext
>
void
eltwise_forward
(
const
ExecContext
&
ctx
,
mkldnn
::
algorithm
algorithm
,
void
eltwise_forward
(
const
ExecContext
&
ctx
,
mkldnn
::
algorithm
algorithm
,
const
T
alpha
=
0
,
const
T
beta
=
0
)
{
const
T
alpha
=
0
,
const
T
beta
=
0
)
{
...
@@ -37,42 +50,70 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
...
@@ -37,42 +50,70 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
const
auto
*
src_data
=
src
->
template
data
<
T
>();
const
auto
*
src_data
=
src
->
template
data
<
T
>();
auto
*
dst
=
ctx
.
template
Output
<
Tensor
>(
"Out"
);
auto
*
dst
=
ctx
.
template
Output
<
Tensor
>(
"Out"
);
const
T
*
dst_data
=
dst
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
T
*
dst_data
=
dst
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
// get memory dim
// get memory dim
PADDLE_ENFORCE
(
src
->
dims
().
size
()
==
2
||
src
->
dims
().
size
()
==
4
,
PADDLE_ENFORCE
(
src
->
dims
().
size
()
==
2
||
src
->
dims
().
size
()
==
4
,
"Input dim must be with 2 or 4"
);
"Input dim must be with 2 or 4"
);
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
src
->
dims
());
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
src
->
dims
());
// create memory description
const
std
::
string
key
=
gethash
(
src_tz
,
algorithm
);
auto
data_md
=
src_tz
.
size
()
==
2
const
std
::
string
key_src_data
=
?
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
key
+
ctx
.
op
().
Output
(
"Out"
)
+
"@eltwise_fwd_src_data"
;
mkldnn
::
memory
::
format
::
nc
)
const
std
::
string
key_src_mem
=
key
+
"@eltwise_fwd_src_mem"
;
:
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
const
std
::
string
key_dst_mem
=
key
+
"@eltwise_fwd_dst_mem"
;
mkldnn
::
memory
::
format
::
nchw
);
const
std
::
string
key_fwd
=
key
+
"@eltwise_fwd"
;
// create memory primitives
auto
p_fwd
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
>
(
auto
src_memory
=
dev_ctx
.
GetBlob
(
key_fwd
));
mkldnn
::
memory
({
data_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
const_cast
<
float
*>
(
src_data
)));
// save input data to be referred in backward path
auto
dst_memory
=
auto
p_src_data
=
std
::
make_shared
<
const
T
*>
(
src_data
);
mkldnn
::
memory
({
data_md
,
mkldnn_engine
},
dev_ctx
.
SetBlob
(
key_src_data
,
p_src_data
);
static_cast
<
void
*>
(
const_cast
<
float
*>
(
dst_data
)));
if
(
p_fwd
==
nullptr
)
{
auto
forward_desc
=
mkldnn
::
eltwise_forward
::
desc
(
// create memory description
mkldnn
::
prop_kind
::
forward_training
,
algorithm
,
data_md
,
alpha
,
beta
);
auto
data_md
=
src_tz
.
size
()
==
2
?
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
// save prim desc into global device context to be referred in backward path
mkldnn
::
memory
::
format
::
nc
)
const
std
::
string
key
=
ctx
.
op
().
Output
(
"Out"
);
:
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
const
std
::
string
key_eltwise_pd
=
key
+
"@eltwise_pd"
;
mkldnn
::
memory
::
format
::
nchw
);
auto
forward_pd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
forward_desc
,
mkldnn_engine
);
// create memory primitives
dev_ctx
.
SetBlob
(
key_eltwise_pd
,
forward_pd
);
auto
p_src_mem
=
std
::
make_shared
<
mkldnn
::
memory
>
(
mkldnn
::
memory
(
{
data_md
,
mkldnn_engine
},
platform
::
to_void_cast
(
src_data
)));
auto
eltwise
=
mkldnn
::
eltwise_forward
(
*
forward_pd
,
src_memory
,
dst_memory
);
dev_ctx
.
SetBlob
(
key_src_mem
,
p_src_mem
);
auto
p_dst_mem
=
std
::
make_shared
<
mkldnn
::
memory
>
(
mkldnn
::
memory
(
{
data_md
,
mkldnn_engine
},
platform
::
to_void_cast
(
dst_data
)));
dev_ctx
.
SetBlob
(
key_dst_mem
,
p_dst_mem
);
auto
fwd_desc
=
mkldnn
::
eltwise_forward
::
desc
(
mkldnn
::
prop_kind
::
forward_training
,
algorithm
,
data_md
,
alpha
,
beta
);
auto
p_fwd_pd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
fwd_desc
,
mkldnn_engine
);
const
std
::
string
key_fwd_pd
=
key
+
"eltwise_fwd_pd"
;
dev_ctx
.
SetBlob
(
key_fwd_pd
,
p_fwd_pd
);
p_fwd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
>
(
*
p_fwd_pd
,
*
(
p_src_mem
.
get
()),
*
(
p_dst_mem
.
get
()));
dev_ctx
.
SetBlob
(
key_fwd
,
p_fwd
);
}
else
{
// primitives already exist
auto
p_src_mem
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_src_mem
));
PADDLE_ENFORCE
(
p_src_mem
!=
nullptr
,
"Fail to find eltwise p_src_mem in device context."
);
auto
p_dst_mem
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_dst_mem
));
PADDLE_ENFORCE
(
p_dst_mem
!=
nullptr
,
"Fail to find eltwise p_src_mem in device context."
);
p_src_mem
->
set_data_handle
(
platform
::
to_void_reinterpret_cast
(
src_data
));
p_dst_mem
->
set_data_handle
(
dst_data
);
}
// push primitive to stream and wait until it's executed
// push primitive to stream and wait until it's executed
std
::
vector
<
mkldnn
::
primitive
>
pipeline
=
{
eltwise
};
std
::
vector
<
mkldnn
::
primitive
>
pipeline
=
{
*
(
p_fwd
.
get
())
};
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
}
}
...
@@ -83,8 +124,7 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
...
@@ -83,8 +124,7 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
const
auto
&
mkldnn_engine
=
dev_ctx
.
GetEngine
();
const
auto
&
mkldnn_engine
=
dev_ctx
.
GetEngine
();
// get buffers
// get buffers
const
auto
*
x
=
ctx
.
template
Input
<
Tensor
>(
"X"
);
const
auto
*
out
=
ctx
.
template
Input
<
Tensor
>(
"Out"
);
const
auto
*
src
=
x
->
template
data
<
T
>();
auto
*
dout
=
ctx
.
template
Input
<
Tensor
>(
framework
::
GradVarName
(
"Out"
));
auto
*
dout
=
ctx
.
template
Input
<
Tensor
>(
framework
::
GradVarName
(
"Out"
));
const
auto
*
diff_dst
=
dout
->
template
data
<
T
>();
const
auto
*
diff_dst
=
dout
->
template
data
<
T
>();
...
@@ -94,45 +134,73 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
...
@@ -94,45 +134,73 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
const
T
*
diff_src
=
dx
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
const
T
*
diff_src
=
dx
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
// get memory dim
// get memory dim
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
x
->
dims
());
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
out
->
dims
());
// create memory description
const
std
::
string
key
=
gethash
(
src_tz
,
algorithm
);
auto
data_md
=
src_tz
.
size
()
==
2
const
std
::
string
key_diff_src_mem
=
key
+
"@eltwise_diff_src_mem"
;
?
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
const
std
::
string
key_diff_dst_mem
=
key
+
"@eltwise_diff_dst_mem"
;
mkldnn
::
memory
::
format
::
nc
)
const
std
::
string
key_grad
=
key
+
"@eltwise_grad"
;
:
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
mkldnn
::
memory
::
format
::
nchw
);
const
std
::
string
key_src_data
=
key
+
ctx
.
op
().
Input
(
"Out"
)
+
"@eltwise_fwd_src_data"
;
// create memory primitives
const
auto
p_src_data
=
auto
src_memory
=
mkldnn
::
memory
(
std
::
static_pointer_cast
<
T
*>
(
dev_ctx
.
GetBlob
(
key_src_data
));
{
data_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
const_cast
<
float
*>
(
src
)));
auto
diff_src_memory
=
const
std
::
string
key_src_mem
=
key
+
"@eltwise_fwd_src_mem"
;
mkldnn
::
memory
({
data_md
,
mkldnn_engine
},
auto
p_src_mem
=
static_cast
<
void
*>
(
const_cast
<
float
*>
(
diff_src
)));
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_src_mem
));
auto
diff_dst_memory
=
p_src_mem
->
set_data_handle
(
*
p_src_data
.
get
());
mkldnn
::
memory
({
data_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
const_cast
<
float
*>
(
diff_dst
)));
auto
p_grad
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive
>
(
dev_ctx
.
GetBlob
(
key_grad
));
auto
backward_desc
=
mkldnn
::
eltwise_backward
::
desc
(
algorithm
,
data_md
,
data_md
,
alpha
,
beta
);
if
(
p_grad
==
nullptr
)
{
// create memory description
// retrieve eltwise primitive desc from device context
auto
data_md
=
src_tz
.
size
()
==
2
const
std
::
string
key
=
ctx
.
op
().
Input
(
"Out"
);
?
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
const
std
::
string
key_eltwise_pd
=
key
+
"@eltwise_pd"
;
mkldnn
::
memory
::
format
::
nc
)
const
std
::
shared_ptr
<
void
>
forward_pd
=
dev_ctx
.
GetBlob
(
key_eltwise_pd
);
:
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
PADDLE_ENFORCE
(
forward_pd
!=
nullptr
,
mkldnn
::
memory
::
format
::
nchw
);
"Fail to find eltwise_pd in device context"
);
auto
*
p_forward_pd
=
// create memory primitives
static_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
*>
(
forward_pd
.
get
());
std
::
shared_ptr
<
void
>
p_diff_src_mem
=
std
::
make_shared
<
mkldnn
::
memory
>
(
mkldnn
::
memory
(
auto
eltwise_bwd_prim_desc
=
mkldnn
::
eltwise_backward
::
primitive_desc
(
{
data_md
,
mkldnn_engine
},
platform
::
to_void_cast
(
diff_src
)));
backward_desc
,
mkldnn_engine
,
*
p_forward_pd
);
dev_ctx
.
SetBlob
(
key_diff_src_mem
,
p_diff_src_mem
);
std
::
shared_ptr
<
void
>
p_diff_dst_mem
=
auto
eltwise_bwd
=
mkldnn
::
eltwise_backward
(
eltwise_bwd_prim_desc
,
src_memory
,
std
::
make_shared
<
mkldnn
::
memory
>
(
mkldnn
::
memory
(
diff_dst_memory
,
diff_src_memory
);
{
data_md
,
mkldnn_engine
},
platform
::
to_void_cast
(
diff_dst
)));
dev_ctx
.
SetBlob
(
key_diff_dst_mem
,
p_diff_dst_mem
);
auto
bwd_desc
=
mkldnn
::
eltwise_backward
::
desc
(
algorithm
,
data_md
,
data_md
,
alpha
,
beta
);
const
std
::
string
key_fwd_pd
=
key
+
"eltwise_fwd_pd"
;
auto
*
p_fwd_pd
=
static_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
*>
(
dev_ctx
.
GetBlob
(
key_fwd_pd
).
get
());
auto
eltwise_bwd_prim_desc
=
mkldnn
::
eltwise_backward
::
primitive_desc
(
bwd_desc
,
mkldnn_engine
,
*
p_fwd_pd
);
p_grad
=
std
::
make_shared
<
mkldnn
::
eltwise_backward
>
(
eltwise_bwd_prim_desc
,
*
static_cast
<
mkldnn
::
memory
*>
(
p_src_mem
.
get
()),
*
(
static_cast
<
mkldnn
::
memory
*>
(
p_diff_dst_mem
.
get
())),
*
(
static_cast
<
mkldnn
::
memory
*>
(
p_diff_src_mem
.
get
())));
}
else
{
// primitives already exist
auto
p_diff_src_mem
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_diff_src_mem
));
auto
p_diff_dst_mem
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_diff_dst_mem
));
p_diff_src_mem
->
set_data_handle
(
platform
::
to_void_reinterpret_cast
(
diff_src
));
p_diff_dst_mem
->
set_data_handle
(
platform
::
to_void_reinterpret_cast
(
diff_dst
));
}
// push primitive to stream and wait until it's executed
// push primitive to stream and wait until it's executed
std
::
vector
<
mkldnn
::
primitive
>
pipeline
=
{
eltwise_bwd
};
std
::
vector
<
mkldnn
::
primitive
>
pipeline
=
{
*
(
p_grad
.
get
())
};
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
}
}
}
// anonymous namespace
}
// anonymous namespace
...
...
paddle/fluid/operators/activation_op.cc
浏览文件 @
7205d331
...
@@ -41,7 +41,7 @@ namespace operators {
...
@@ -41,7 +41,7 @@ namespace operators {
\
\
protected: \
protected: \
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override { \
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override { \
auto
*
op = new ::paddle::framework::OpDesc(); \
auto
*
op = new ::paddle::framework::OpDesc(); \
op->SetType(#KERNEL_TYPE "_grad"); \
op->SetType(#KERNEL_TYPE "_grad"); \
op->SetInput("Out", Output("Out")); \
op->SetInput("Out", Output("Out")); \
op->SetInput(::paddle::framework::GradVarName("Out"), \
op->SetInput(::paddle::framework::GradVarName("Out"), \
...
@@ -54,23 +54,50 @@ namespace operators {
...
@@ -54,23 +54,50 @@ namespace operators {
} \
} \
}
}
framework
::
OpKernelType
GetKernelType
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
OperatorWithKernel
&
oper
,
const
std
::
string
&
name
)
{
framework
::
LibraryType
library
{
framework
::
LibraryType
::
kPlain
};
#ifdef PADDLE_WITH_MKLDNN
auto
it
=
oper
.
Attrs
().
find
(
"use_mkldnn"
);
if
(
library
==
framework
::
LibraryType
::
kPlain
&&
it
!=
oper
.
Attrs
().
end
()
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
library
=
framework
::
LibraryType
::
kMKLDNN
;
}
#endif
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
name
)
->
type
()),
ctx
.
GetPlace
(),
layout
,
library
);
}
class
ActivationOp
:
public
framework
::
OperatorWithKernel
{
class
ActivationOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
}
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
GetKernelType
(
ctx
,
*
this
,
"X"
);
}
};
};
class
ActivationOpGrad
:
public
framework
::
OperatorWithKernel
{
class
ActivationOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Out"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Out"
));
}
}
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
GetKernelType
(
ctx
,
*
this
,
"Out"
);
}
};
};
__attribute__
((
unused
))
constexpr
char
SigmoidDoc
[]
=
R"DOC(
__attribute__
((
unused
))
constexpr
char
SigmoidDoc
[]
=
R"DOC(
...
...
paddle/fluid/operators/mkldnn_activation_op.h
浏览文件 @
7205d331
...
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
...
@@ -60,52 +62,5 @@ class MKLDNNActivationGradKernel
...
@@ -60,52 +62,5 @@ class MKLDNNActivationGradKernel
}
}
};
};
namespace
{
// NOLINT
framework
::
OpKernelType
GetKernelType
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
OperatorWithKernel
&
oper
)
{
framework
::
LibraryType
library
{
framework
::
LibraryType
::
kPlain
};
#ifdef PADDLE_WITH_MKLDNN
if
(
library
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
library
=
framework
::
LibraryType
::
kMKLDNN
;
}
#endif
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
(),
layout
,
library
);
}
}
// anonymous namespace
class
ActivationWithMKLDNNOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
GetKernelType
(
ctx
,
*
this
);
}
};
class
ActivationWithMKLDNNOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Out"
));
}
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
GetKernelType
(
ctx
,
*
this
);
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/platform/mkldnn_helper.h
浏览文件 @
7205d331
...
@@ -38,6 +38,11 @@ void* to_void_cast(const Type* t) {
...
@@ -38,6 +38,11 @@ void* to_void_cast(const Type* t) {
return
static_cast
<
void
*>
(
const_cast
<
Type
*>
(
t
));
return
static_cast
<
void
*>
(
const_cast
<
Type
*>
(
t
));
}
}
template
<
typename
Type
>
void
*
to_void_reinterpret_cast
(
const
Type
*
t
)
{
return
reinterpret_cast
<
void
*>
(
const_cast
<
Type
*>
(
t
));
}
template
<
class
Type
>
template
<
class
Type
>
using
tf_desc
=
typename
Type
::
desc
;
using
tf_desc
=
typename
Type
::
desc
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录