Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
8c463700
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8c463700
编写于
4月 02, 2020
作者:
J
joanna.wozna.intel
提交者:
GitHub
4月 02, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add default pass attributes (#23042)
上级
48144e40
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
119 addition
and
4 deletion
+119
-4
paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc
.../fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc
+5
-2
paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc
...framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc
+24
-0
paddle/fluid/framework/ir/pass.h
paddle/fluid/framework/ir/pass.h
+42
-2
paddle/fluid/framework/ir/pass_test.cc
paddle/fluid/framework/ir/pass_test.cc
+48
-0
未找到文件。
paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc
浏览文件 @
8c463700
...
...
@@ -51,6 +51,9 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS
(
cpu_quantize_placement_pass
,
paddle
::
framework
::
ir
::
CPUQuantizePlacementPass
)
// a vector of operator type names to be quantized ("conv2d" etc.)
.
RequirePassAttr
(
"quantize_enabled_op_types"
)
// the second param is the default value for this vector
.
DefaultPassAttr
(
"quantize_enabled_op_types"
,
new
std
::
unordered_set
<
std
::
string
>
())
// a vector of operator ids that are to be excluded from quantization
.
RequirePassAttr
(
"quantize_excluded_op_ids"
);
// the second param is the default value for this vector
.
DefaultPassAttr
(
"quantize_excluded_op_ids"
,
new
std
::
unordered_set
<
int
>
());
paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc
浏览文件 @
8c463700
...
...
@@ -111,6 +111,25 @@ void MainTest(std::initializer_list<std::string> quantize_enabled_op_types,
EXPECT_EQ
(
use_quantizer_true_count
,
expected_use_quantizer_true_count
);
}
void
DefaultAttrTest
(
unsigned
expected_use_quantizer_true_count
)
{
auto
prog
=
BuildProgramDesc
();
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"cpu_quantize_placement_pass"
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
unsigned
use_quantizer_true_count
=
0
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
())
{
auto
*
op
=
node
->
Op
();
if
(
op
->
HasAttr
(
"use_quantizer"
)
&&
boost
::
get
<
bool
>
(
op
->
GetAttr
(
"use_quantizer"
)))
{
++
use_quantizer_true_count
;
}
}
}
EXPECT_EQ
(
use_quantizer_true_count
,
expected_use_quantizer_true_count
);
}
TEST
(
QuantizerPlacementPass
,
enabled_pool
)
{
MainTest
({
"pool2d"
},
{},
2
);
}
TEST
(
QuantizerPlacementPass
,
enabled_conv_excluded_one
)
{
...
...
@@ -122,6 +141,11 @@ TEST(QuantizerPlacementPass, excluded_none) {
MainTest
({},
{},
4
);
}
TEST
(
QuantizerPlacementPass
,
default_attr_value
)
{
// 2 conv + 2 pool
DefaultAttrTest
(
4
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
...
...
paddle/fluid/framework/ir/pass.h
浏览文件 @
8c463700
...
...
@@ -100,8 +100,14 @@ class Pass {
// Set a pointer to the attribute. Pass takes ownership of the attribute.
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
,
"%s already set in the pass"
,
attr_name
);
if
(
default_pass_attrs_
.
count
(
attr_name
)
==
0
)
{
PADDLE_ENFORCE_EQ
(
attrs_
.
count
(
attr_name
),
0
,
platform
::
errors
::
InvalidArgument
(
"Attribute %s already set in the pass"
,
attr_name
));
}
else
{
VLOG
(
3
)
<<
"Setting the attribute "
<<
attr_name
<<
" for the pass "
<<
type_
;
}
attrs_
[
attr_name
]
=
attr
;
attr_dels_
[
attr_name
]
=
[
attr
,
attr_name
]()
{
VLOG
(
3
)
<<
"deleting "
<<
attr_name
;
...
...
@@ -140,11 +146,21 @@ class Pass {
required_graph_attrs_
.
insert
(
attrs
.
begin
(),
attrs
.
end
());
}
// Pass doesn't take ownership. PassRegistrar should delete default_attrs
void
RegisterDefaultPassAttrs
(
std
::
map
<
std
::
string
,
boost
::
any
>
default_attr_values
)
{
for
(
auto
const
&
attr_name
:
default_attr_values
)
{
default_pass_attrs_
.
insert
(
attr_name
.
first
);
}
attrs_
.
insert
(
default_attr_values
.
begin
(),
default_attr_values
.
end
());
}
void
RegisterType
(
const
std
::
string
&
type
)
{
type_
=
type
;
}
mutable
bool
applied_
{
false
};
std
::
string
type_
;
std
::
unordered_set
<
std
::
string
>
required_pass_attrs_
;
std
::
unordered_set
<
std
::
string
>
default_pass_attrs_
;
std
::
unordered_set
<
std
::
string
>
required_graph_attrs_
;
std
::
map
<
std
::
string
,
boost
::
any
>
attrs_
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
attr_dels_
;
...
...
@@ -203,16 +219,38 @@ struct PassRegistrar : public Registrar {
std
::
unique_ptr
<
Pass
>
pass
(
new
PassType
());
pass
->
RegisterRequiredPassAttrs
(
this
->
required_pass_attrs_
);
pass
->
RegisterRequiredGraphAttrs
(
this
->
required_graph_attrs_
);
pass
->
RegisterDefaultPassAttrs
(
this
->
default_attr_values_
);
pass
->
RegisterType
(
pass_type
);
return
pass
;
});
}
~
PassRegistrar
()
{
for
(
auto
&
attr
:
default_attr_values_
)
{
if
(
default_attr_dels_
.
find
(
attr
.
first
)
!=
default_attr_dels_
.
end
())
{
default_attr_dels_
[
attr
.
first
]();
}
}
default_attr_values_
.
clear
();
default_attr_dels_
.
clear
();
}
PassRegistrar
<
PassType
>
&
RequirePassAttr
(
const
std
::
string
&
attr
)
{
required_pass_attrs_
.
insert
(
attr
);
return
*
this
;
}
// PassRegistrar takes ownership of default_attr_value
template
<
typename
AttrType
>
PassRegistrar
<
PassType
>
&
DefaultPassAttr
(
const
std
::
string
&
attr
,
AttrType
&&
default_attr_value
)
{
default_attr_values_
[
attr
]
=
default_attr_value
;
default_attr_dels_
[
attr
]
=
[
default_attr_value
,
attr
]()
{
delete
default_attr_value
;
};
return
*
this
;
}
PassRegistrar
<
PassType
>
&
RequireGraphAttr
(
const
std
::
string
&
attr
)
{
required_graph_attrs_
.
insert
(
attr
);
return
*
this
;
...
...
@@ -221,6 +259,8 @@ struct PassRegistrar : public Registrar {
private:
std
::
unordered_set
<
std
::
string
>
required_pass_attrs_
;
std
::
unordered_set
<
std
::
string
>
required_graph_attrs_
;
std
::
map
<
std
::
string
,
boost
::
any
>
default_attr_values_
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
default_attr_dels_
;
};
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
...
...
paddle/fluid/framework/ir/pass_test.cc
浏览文件 @
8c463700
...
...
@@ -120,6 +120,50 @@ TEST(PassTest, TestPassAttrCheck) {
exception
=
std
::
string
(
e
.
what
());
}
ASSERT_TRUE
(
exception
.
find
(
"shouldn't have cycle"
)
!=
exception
.
npos
);
pass
=
PassRegistry
::
Instance
().
Get
(
"test_pass"
);
pass
->
Set
<
int
>
(
"test_pass_attr"
,
new
int
);
try
{
pass
->
Set
<
int
>
(
"test_pass_attr"
,
new
int
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
&
e
)
{
exception
=
std
::
string
(
e
.
what
());
}
ASSERT_TRUE
(
exception
.
find
(
"Attribute test_pass_attr already set in the pass"
)
!=
exception
.
npos
);
}
class
TestPassWithDefault
:
public
Pass
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
graph
->
Set
<
int
>
(
"copy_default_attr"
,
new
int
);
int
test_pass_attr
=
this
->
Get
<
int
>
(
"default_attr"
);
graph
->
Get
<
int
>
(
"copy_default_attr"
)
=
test_pass_attr
+
1
;
}
};
TEST
(
PassTest
,
TestPassDefaultAttrCheck
)
{
ProgramDesc
prog
;
// check if default value is set
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"test_pass_default_attr"
);
std
::
unique_ptr
<
Graph
>
graph
(
new
Graph
(
prog
));
ASSERT_EQ
(
pass
->
Get
<
int
>
(
"default_attr"
),
1
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
ASSERT_EQ
(
graph
->
Get
<
int
>
(
"copy_default_attr"
),
2
);
// check if new value overrides default value
pass
=
PassRegistry
::
Instance
().
Get
(
"test_pass_default_attr"
);
pass
->
Set
<
int
>
(
"default_attr"
,
new
int
{
3
});
ASSERT_EQ
(
pass
->
Get
<
int
>
(
"default_attr"
),
3
);
}
TEST
(
PassTest
,
TestPassRegistrarDeconstructor
)
{
auto
pass_registrary
=
new
PassRegistrar
<
paddle
::
framework
::
ir
::
TestPassWithDefault
>
(
"test_deconstructor"
);
pass_registrary
->
DefaultPassAttr
(
"deconstructor_attr"
,
new
int
{
1
});
pass_registrary
->~
PassRegistrar
();
}
}
// namespace ir
...
...
@@ -129,3 +173,7 @@ TEST(PassTest, TestPassAttrCheck) {
REGISTER_PASS
(
test_pass
,
paddle
::
framework
::
ir
::
TestPass
)
.
RequirePassAttr
(
"test_pass_attr"
)
.
RequireGraphAttr
(
"test_graph_attr"
);
REGISTER_PASS
(
test_pass_default_attr
,
paddle
::
framework
::
ir
::
TestPassWithDefault
)
.
DefaultPassAttr
(
"default_attr"
,
new
int
{
1
});
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录