Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8c463700
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8c463700
编写于
4月 02, 2020
作者:
J
joanna.wozna.intel
提交者:
GitHub
4月 02, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add default pass attributes (#23042)
上级
48144e40
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
119 addition
and
4 deletion
+119
-4
paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc
.../fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc
+5
-2
paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc
...framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc
+24
-0
paddle/fluid/framework/ir/pass.h
paddle/fluid/framework/ir/pass.h
+42
-2
paddle/fluid/framework/ir/pass_test.cc
paddle/fluid/framework/ir/pass_test.cc
+48
-0
未找到文件。
paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc
浏览文件 @
8c463700
...
...
@@ -51,6 +51,9 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS
(
cpu_quantize_placement_pass
,
paddle
::
framework
::
ir
::
CPUQuantizePlacementPass
)
// a vector of operator type names to be quantized ("conv2d" etc.)
.
RequirePassAttr
(
"quantize_enabled_op_types"
)
// the second param is the default value for this vector
.
DefaultPassAttr
(
"quantize_enabled_op_types"
,
new
std
::
unordered_set
<
std
::
string
>
())
// a vector of operator ids that are to be excluded from quantization
.
RequirePassAttr
(
"quantize_excluded_op_ids"
);
// the second param is the default value for this vector
.
DefaultPassAttr
(
"quantize_excluded_op_ids"
,
new
std
::
unordered_set
<
int
>
());
paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc
浏览文件 @
8c463700
...
...
@@ -111,6 +111,25 @@ void MainTest(std::initializer_list<std::string> quantize_enabled_op_types,
EXPECT_EQ
(
use_quantizer_true_count
,
expected_use_quantizer_true_count
);
}
void
DefaultAttrTest
(
unsigned
expected_use_quantizer_true_count
)
{
auto
prog
=
BuildProgramDesc
();
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"cpu_quantize_placement_pass"
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
unsigned
use_quantizer_true_count
=
0
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
())
{
auto
*
op
=
node
->
Op
();
if
(
op
->
HasAttr
(
"use_quantizer"
)
&&
boost
::
get
<
bool
>
(
op
->
GetAttr
(
"use_quantizer"
)))
{
++
use_quantizer_true_count
;
}
}
}
EXPECT_EQ
(
use_quantizer_true_count
,
expected_use_quantizer_true_count
);
}
TEST
(
QuantizerPlacementPass
,
enabled_pool
)
{
MainTest
({
"pool2d"
},
{},
2
);
}
TEST
(
QuantizerPlacementPass
,
enabled_conv_excluded_one
)
{
...
...
@@ -122,6 +141,11 @@ TEST(QuantizerPlacementPass, excluded_none) {
MainTest
({},
{},
4
);
}
TEST
(
QuantizerPlacementPass
,
default_attr_value
)
{
// 2 conv + 2 pool
DefaultAttrTest
(
4
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
...
...
paddle/fluid/framework/ir/pass.h
浏览文件 @
8c463700
...
...
@@ -100,8 +100,14 @@ class Pass {
// Set a pointer to the attribute. Pass takes ownership of the attribute.
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
,
"%s already set in the pass"
,
attr_name
);
if
(
default_pass_attrs_
.
count
(
attr_name
)
==
0
)
{
PADDLE_ENFORCE_EQ
(
attrs_
.
count
(
attr_name
),
0
,
platform
::
errors
::
InvalidArgument
(
"Attribute %s already set in the pass"
,
attr_name
));
}
else
{
VLOG
(
3
)
<<
"Setting the attribute "
<<
attr_name
<<
" for the pass "
<<
type_
;
}
attrs_
[
attr_name
]
=
attr
;
attr_dels_
[
attr_name
]
=
[
attr
,
attr_name
]()
{
VLOG
(
3
)
<<
"deleting "
<<
attr_name
;
...
...
@@ -140,11 +146,21 @@ class Pass {
required_graph_attrs_
.
insert
(
attrs
.
begin
(),
attrs
.
end
());
}
// Pass doesn't take ownership. PassRegistrar should delete default_attrs
void
RegisterDefaultPassAttrs
(
std
::
map
<
std
::
string
,
boost
::
any
>
default_attr_values
)
{
for
(
auto
const
&
attr_name
:
default_attr_values
)
{
default_pass_attrs_
.
insert
(
attr_name
.
first
);
}
attrs_
.
insert
(
default_attr_values
.
begin
(),
default_attr_values
.
end
());
}
void
RegisterType
(
const
std
::
string
&
type
)
{
type_
=
type
;
}
mutable
bool
applied_
{
false
};
std
::
string
type_
;
std
::
unordered_set
<
std
::
string
>
required_pass_attrs_
;
std
::
unordered_set
<
std
::
string
>
default_pass_attrs_
;
std
::
unordered_set
<
std
::
string
>
required_graph_attrs_
;
std
::
map
<
std
::
string
,
boost
::
any
>
attrs_
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
attr_dels_
;
...
...
@@ -203,16 +219,38 @@ struct PassRegistrar : public Registrar {
std
::
unique_ptr
<
Pass
>
pass
(
new
PassType
());
pass
->
RegisterRequiredPassAttrs
(
this
->
required_pass_attrs_
);
pass
->
RegisterRequiredGraphAttrs
(
this
->
required_graph_attrs_
);
pass
->
RegisterDefaultPassAttrs
(
this
->
default_attr_values_
);
pass
->
RegisterType
(
pass_type
);
return
pass
;
});
}
~
PassRegistrar
()
{
for
(
auto
&
attr
:
default_attr_values_
)
{
if
(
default_attr_dels_
.
find
(
attr
.
first
)
!=
default_attr_dels_
.
end
())
{
default_attr_dels_
[
attr
.
first
]();
}
}
default_attr_values_
.
clear
();
default_attr_dels_
.
clear
();
}
PassRegistrar
<
PassType
>
&
RequirePassAttr
(
const
std
::
string
&
attr
)
{
required_pass_attrs_
.
insert
(
attr
);
return
*
this
;
}
// PassRegistrar takes ownership of default_attr_value
template
<
typename
AttrType
>
PassRegistrar
<
PassType
>
&
DefaultPassAttr
(
const
std
::
string
&
attr
,
AttrType
&&
default_attr_value
)
{
default_attr_values_
[
attr
]
=
default_attr_value
;
default_attr_dels_
[
attr
]
=
[
default_attr_value
,
attr
]()
{
delete
default_attr_value
;
};
return
*
this
;
}
PassRegistrar
<
PassType
>
&
RequireGraphAttr
(
const
std
::
string
&
attr
)
{
required_graph_attrs_
.
insert
(
attr
);
return
*
this
;
...
...
@@ -221,6 +259,8 @@ struct PassRegistrar : public Registrar {
private:
std
::
unordered_set
<
std
::
string
>
required_pass_attrs_
;
std
::
unordered_set
<
std
::
string
>
required_graph_attrs_
;
std
::
map
<
std
::
string
,
boost
::
any
>
default_attr_values_
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
default_attr_dels_
;
};
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
...
...
paddle/fluid/framework/ir/pass_test.cc
浏览文件 @
8c463700
...
...
@@ -120,6 +120,50 @@ TEST(PassTest, TestPassAttrCheck) {
exception
=
std
::
string
(
e
.
what
());
}
ASSERT_TRUE
(
exception
.
find
(
"shouldn't have cycle"
)
!=
exception
.
npos
);
pass
=
PassRegistry
::
Instance
().
Get
(
"test_pass"
);
pass
->
Set
<
int
>
(
"test_pass_attr"
,
new
int
);
try
{
pass
->
Set
<
int
>
(
"test_pass_attr"
,
new
int
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
&
e
)
{
exception
=
std
::
string
(
e
.
what
());
}
ASSERT_TRUE
(
exception
.
find
(
"Attribute test_pass_attr already set in the pass"
)
!=
exception
.
npos
);
}
class
TestPassWithDefault
:
public
Pass
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
graph
->
Set
<
int
>
(
"copy_default_attr"
,
new
int
);
int
test_pass_attr
=
this
->
Get
<
int
>
(
"default_attr"
);
graph
->
Get
<
int
>
(
"copy_default_attr"
)
=
test_pass_attr
+
1
;
}
};
TEST
(
PassTest
,
TestPassDefaultAttrCheck
)
{
ProgramDesc
prog
;
// check if default value is set
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"test_pass_default_attr"
);
std
::
unique_ptr
<
Graph
>
graph
(
new
Graph
(
prog
));
ASSERT_EQ
(
pass
->
Get
<
int
>
(
"default_attr"
),
1
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
ASSERT_EQ
(
graph
->
Get
<
int
>
(
"copy_default_attr"
),
2
);
// check if new value overrides default value
pass
=
PassRegistry
::
Instance
().
Get
(
"test_pass_default_attr"
);
pass
->
Set
<
int
>
(
"default_attr"
,
new
int
{
3
});
ASSERT_EQ
(
pass
->
Get
<
int
>
(
"default_attr"
),
3
);
}
TEST
(
PassTest
,
TestPassRegistrarDeconstructor
)
{
auto
pass_registrary
=
new
PassRegistrar
<
paddle
::
framework
::
ir
::
TestPassWithDefault
>
(
"test_deconstructor"
);
pass_registrary
->
DefaultPassAttr
(
"deconstructor_attr"
,
new
int
{
1
});
pass_registrary
->~
PassRegistrar
();
}
}
// namespace ir
...
...
@@ -129,3 +173,7 @@ TEST(PassTest, TestPassAttrCheck) {
REGISTER_PASS
(
test_pass
,
paddle
::
framework
::
ir
::
TestPass
)
.
RequirePassAttr
(
"test_pass_attr"
)
.
RequireGraphAttr
(
"test_graph_attr"
);
REGISTER_PASS
(
test_pass_default_attr
,
paddle
::
framework
::
ir
::
TestPassWithDefault
)
.
DefaultPassAttr
(
"default_attr"
,
new
int
{
1
});
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录