Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d5a888e1
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d5a888e1
编写于
2月 27, 2019
作者:
T
Tao Luo
提交者:
GitHub
2月 27, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15943 from kbinias/kbinias/add-placement-pass-tester
MKL-DNN: Add placement pass tester
上级
ba90e052
72253391
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
138 addition
and
1 deletion
+138
-1
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-0
paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc
paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc
+1
-1
paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass_tester.cc
...fluid/framework/ir/mkldnn/mkldnn_placement_pass_tester.cc
+136
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
d5a888e1
...
...
@@ -105,4 +105,5 @@ if (WITH_MKLDNN)
cc_test
(
test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor
)
cc_test
(
test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass
)
cc_test
(
test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass
)
cc_test
(
test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass
)
endif
()
paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc
浏览文件 @
d5a888e1
...
...
@@ -21,7 +21,7 @@ namespace ir {
std
::
unique_ptr
<
ir
::
Graph
>
MKLDNNPlacementPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
VLOG
(
3
)
<<
"Aplies MKL-DNN placement strategy."
;
VLOG
(
3
)
<<
"Ap
p
lies MKL-DNN placement strategy."
;
const
auto
&
op_types_list
=
Get
<
std
::
unordered_set
<
std
::
string
>>
(
"mkldnn_enabled_op_types"
);
for
(
const
Node
*
n
:
graph
->
Nodes
())
{
...
...
paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass_tester.cc
0 → 100644
浏览文件 @
d5a888e1
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h"
#include <gtest/gtest.h>
#include <boost/logic/tribool.hpp>
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
SetOp
(
ProgramDesc
*
prog
,
const
std
::
string
&
type
,
const
std
::
string
&
name
,
const
std
::
vector
<
std
::
string
>&
inputs
,
const
std
::
vector
<
std
::
string
>&
outputs
,
boost
::
tribool
use_mkldnn
)
{
auto
*
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
type
);
if
(
!
boost
::
indeterminate
(
use_mkldnn
))
op
->
SetAttr
(
"use_mkldnn"
,
use_mkldnn
);
if
(
type
==
"conv2d"
)
{
op
->
SetAttr
(
"name"
,
name
);
op
->
SetInput
(
"Input"
,
{
inputs
[
0
]});
op
->
SetInput
(
"Filter"
,
{
inputs
[
1
]});
op
->
SetInput
(
"Bias"
,
{
inputs
[
2
]});
}
else
if
(
type
==
"relu"
)
{
op
->
SetInput
(
"X"
,
inputs
);
}
else
if
(
type
==
"concat"
)
{
op
->
SetAttr
(
"axis"
,
1
);
op
->
SetInput
(
"X"
,
{
inputs
[
0
],
inputs
[
1
]});
}
else
if
(
type
==
"pool2d"
)
{
op
->
SetInput
(
"X"
,
{
inputs
[
0
]});
}
else
{
FAIL
()
<<
"Unexpected operator type."
;
}
op
->
SetOutput
(
"Out"
,
{
outputs
[
0
]});
}
// operator use_mkldnn
// ---------------------------------------
// (a,b)->concat->c none
// (c,weights,bias)->conv->f none
// f->relu->g false
// g->pool->h false
// (h,weights2,bias2)->conv->k true
// k->relu->l true
ProgramDesc
BuildProgramDesc
()
{
ProgramDesc
prog
;
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
({
"a"
,
"b"
,
"c"
,
"weights"
,
"bias"
,
"f"
,
"g"
,
"h"
,
"weights2"
,
"bias2"
,
"k"
,
"l"
}))
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
var
->
SetType
(
proto
::
VarType
::
SELECTED_ROWS
);
if
(
v
==
"weights"
||
v
==
"bias"
)
{
var
->
SetPersistable
(
true
);
}
}
SetOp
(
&
prog
,
"concat"
,
"concat1"
,
std
::
vector
<
std
::
string
>
({
"a"
,
"b"
}),
std
::
vector
<
std
::
string
>
({
"c"
}),
boost
::
indeterminate
);
SetOp
(
&
prog
,
"conv2d"
,
"conv1"
,
std
::
vector
<
std
::
string
>
({
"c"
,
"weights"
,
"bias"
}),
std
::
vector
<
std
::
string
>
({
"f"
}),
boost
::
indeterminate
);
SetOp
(
&
prog
,
"relu"
,
"relu1"
,
std
::
vector
<
std
::
string
>
({
"f"
}),
std
::
vector
<
std
::
string
>
({
"g"
}),
false
);
SetOp
(
&
prog
,
"pool2d"
,
"pool1"
,
std
::
vector
<
std
::
string
>
({
"g"
}),
std
::
vector
<
std
::
string
>
({
"h"
}),
false
);
SetOp
(
&
prog
,
"conv2d"
,
"conv2"
,
std
::
vector
<
std
::
string
>
({
"h"
,
"weights2"
,
"bias2"
}),
std
::
vector
<
std
::
string
>
({
"k"
}),
true
);
SetOp
(
&
prog
,
"relu"
,
"relu2"
,
std
::
vector
<
std
::
string
>
({
"k"
}),
std
::
vector
<
std
::
string
>
({
"l"
}),
true
);
return
prog
;
}
void
MainTest
(
std
::
initializer_list
<
std
::
string
>
mkldnn_enabled_op_types
,
unsigned
expected_use_mkldnn_true_count
)
{
auto
prog
=
BuildProgramDesc
();
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"mkldnn_placement_pass"
);
pass
->
Set
(
"mkldnn_enabled_op_types"
,
new
std
::
unordered_set
<
std
::
string
>
(
mkldnn_enabled_op_types
));
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
unsigned
use_mkldnn_true_count
=
0
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
())
{
auto
*
op
=
node
->
Op
();
if
(
op
->
HasAttr
(
"use_mkldnn"
)
&&
boost
::
get
<
bool
>
(
op
->
GetAttr
(
"use_mkldnn"
)))
{
++
use_mkldnn_true_count
;
}
}
}
EXPECT_EQ
(
use_mkldnn_true_count
,
expected_use_mkldnn_true_count
);
}
TEST
(
MKLDNNPlacementPass
,
enable_conv_relu
)
{
// 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 0 pool
MainTest
({
"conv2d"
,
"relu"
},
3
);
}
TEST
(
MKLDNNPlacementPass
,
enable_relu_pool
)
{
// 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool
MainTest
({
"relu"
,
"pool2d"
},
4
);
}
TEST
(
MKLDNNPlacementPass
,
enable_all
)
{
// 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool
MainTest
({},
4
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
mkldnn_placement_pass
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录