Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
749667e5
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
749667e5
编写于
9月 08, 2022
作者:
A
Aurelius84
提交者:
GitHub
9月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[OpAttr]Refine Teller logic if encounter OpDesc with Variable type Attribute (#45874)
上级
cdda9799
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
43 addition
and
39 deletion
+43
-39
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+43
-31
paddle/fluid/inference/tensorrt/op_teller.h
paddle/fluid/inference/tensorrt/op_teller.h
+0
-8
未找到文件。
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
749667e5
...
...
@@ -303,9 +303,6 @@ bool OpTeller::Tell(const framework::ir::Node* node,
desc
.
HasAttr
(
"skip_quant"
))
return
false
;
// do not support Attribute with Variable(s) Type
if
(
HasUnsupportAttrVar
(
desc
))
return
false
;
for
(
auto
&
teller
:
tellers_
)
{
std
::
unordered_set
<
std
::
string
>
act_op_list
=
{
"relu"
,
"relu6"
,
"sigmoid"
,
...
...
@@ -364,7 +361,30 @@ bool OpTeller::Tell(const framework::ir::Node* node,
}
}
if
(
op_type
==
"dropout"
)
{
/*
* Some OpDescs Attribute support both constant value and dynamic
* runtime value (which is a Variable(s) type). But TensorRT maybe
* only support constant value Attribute, so we shall distinguish
* this case in time and return False in OpTeller.Tell().
* If Attribute is Variable(s), HasAttr() will return False
*/
if
(
!
desc
.
HasAttr
(
"dropout_prob"
,
/*with_attr_var=*/
false
))
{
VLOG
(
3
)
<<
"Skip to convert into TRT while found Attribute('dropout_prob') "
"is Variable type in dropout."
;
return
false
;
}
}
if
(
op_type
==
"pool2d"
)
{
// If Attribute is Variable(s), HasAttr() will return False
if
(
!
desc
.
HasAttr
(
"ksize"
,
/*with_attr_var=*/
false
))
{
VLOG
(
3
)
<<
"Skip to convert into TRT while found Attribute('ksize') is "
"Variable type in pool2d."
;
return
false
;
}
std
::
vector
<
int
>
paddings
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"paddings"
));
if
(
paddings
.
size
()
>
2
)
{
...
...
@@ -797,6 +817,12 @@ bool OpTeller::Tell(const framework::ir::Node* node,
}
if
(
op_type
==
"arg_max"
)
{
if
(
!
desc
.
HasAttr
(
"axis"
,
/*with_attr_var=*/
false
))
{
VLOG
(
3
)
<<
"Skip to convert into TRT while found Attribute('axis') is "
"Variable type in arg_max."
;
return
false
;
}
int
axis
=
desc
.
HasAttr
(
"axis"
)
?
PADDLE_GET_CONST
(
int64_t
,
desc
.
GetAttr
(
"axis"
))
:
-
1
;
...
...
@@ -1061,6 +1087,13 @@ bool OpTeller::Tell(const framework::ir::Node* node,
}
if
(
op_type
==
"squeeze2"
)
{
// If Attribute is Variable(s), HasAttr() will return False
if
(
!
desc
.
HasAttr
(
"axes"
,
/*with_attr_var=*/
false
))
{
VLOG
(
3
)
<<
"Skip to convert into TRT while found Attribute('axes') is "
"Variable type in squeeze2."
;
return
false
;
}
std
::
vector
<
int
>
axes
;
if
(
desc
.
HasAttr
(
"axes"
))
{
axes
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"axes"
));
...
...
@@ -2002,6 +2035,13 @@ bool OpTeller::Tell(const framework::ir::Node* node,
}
if
(
op_type
==
"reduce_sum"
||
op_type
==
"reduce_mean"
)
{
if
(
!
desc
.
HasAttr
(
"dim"
,
/*with_attr_var=*/
false
))
{
VLOG
(
3
)
<<
"Skip to convert into TRT while found Attribute('dim') is "
"Variable type in "
<<
desc
.
Type
();
return
false
;
}
if
(
!
(
desc
.
HasAttr
(
"keep_dim"
)
&&
desc
.
HasAttr
(
"dim"
)
&&
desc
.
HasAttr
(
"reduce_all"
)))
{
VLOG
(
3
)
<<
"the "
<<
op_type
...
...
@@ -2265,34 +2305,6 @@ bool OpTeller::Tell(const framework::ir::Node* node,
return
false
;
}
bool
OpTeller
::
HasUnsupportAttrVar
(
const
framework
::
OpDesc
&
desc
)
const
{
const
std
::
string
op_type
=
desc
.
Type
();
auto
has_attr_var
=
[
&
](
const
std
::
string
&
attr_name
)
->
bool
{
// If Attribute is Variable(s), HasAttr() will return False
return
!
desc
.
HasAttr
(
attr_name
,
/*with_attr_var=*/
false
);
};
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
attrs_info
=
{
{
"dropout"
,
{
"dropout_prob"
}},
{
"pool2d"
,
{
"ksize"
}},
{
"arg_max"
,
{
"axis"
}},
{
"reduce_mean"
,
{
"dim"
}},
{
"reduce_sum"
,
{
"dim"
}},
{
"squeeze2"
,
{
"axes"
}},
};
bool
flag
=
false
;
auto
iter
=
attrs_info
.
find
(
op_type
);
if
(
iter
!=
attrs_info
.
end
())
{
for
(
auto
&
attr_name
:
iter
->
second
)
{
if
(
has_attr_var
(
attr_name
))
{
flag
=
true
;
break
;
}
}
}
return
flag
;
}
OpTeller
::
OpTeller
()
{
tellers_
.
emplace_back
(
new
SimpleOpTypeSetTeller
);
}
}
// namespace tensorrt
}
// namespace inference
...
...
paddle/fluid/inference/tensorrt/op_teller.h
浏览文件 @
749667e5
...
...
@@ -73,14 +73,6 @@ class OpTeller {
private:
OpTeller
();
/*
* Some OpDescs Attribute support both constant value and dynamic
* runtime value (which is a Variable(s) type). But TensorRT maybe
* only support constant value Attribute, so we shall distinguish
* this case in time and return False in OpTeller.Tell().
*/
bool
HasUnsupportAttrVar
(
const
framework
::
OpDesc
&
desc
)
const
;
private:
std
::
vector
<
std
::
unique_ptr
<
Teller
>>
tellers_
;
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录