Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
890c7315
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
890c7315
编写于
6月 16, 2022
作者:
津
津
提交者:
GitHub
6月 16, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[inference]add unary trt convert (#43509)
* add unary
上级
1ec626b1
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
168 addition
and
7 deletion
+168
-7
paddle/fluid/inference/tensorrt/convert/unary_op.cc
paddle/fluid/inference/tensorrt/convert/unary_op.cc
+97
-0
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+61
-2
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_unary.py
...id/tests/unittests/ir/inference/test_trt_convert_unary.py
+10
-5
未找到文件。
paddle/fluid/inference/tensorrt/convert/unary_op.cc
浏览文件 @
890c7315
...
@@ -66,6 +66,23 @@ const std::unordered_map<std::string, nvinfer1::UnaryOperation>
...
@@ -66,6 +66,23 @@ const std::unordered_map<std::string, nvinfer1::UnaryOperation>
UnaryOpConverter
::
ops
=
{
UnaryOpConverter
::
ops
=
{
{
"exp"
,
nvinfer1
::
UnaryOperation
::
kEXP
},
{
"exp"
,
nvinfer1
::
UnaryOperation
::
kEXP
},
{
"log"
,
nvinfer1
::
UnaryOperation
::
kLOG
},
{
"log"
,
nvinfer1
::
UnaryOperation
::
kLOG
},
{
"sqrt"
,
nvinfer1
::
UnaryOperation
::
kSQRT
},
{
"abs"
,
nvinfer1
::
UnaryOperation
::
kABS
},
{
"sin"
,
nvinfer1
::
UnaryOperation
::
kSIN
},
{
"cos"
,
nvinfer1
::
UnaryOperation
::
kCOS
},
{
"tan"
,
nvinfer1
::
UnaryOperation
::
kTAN
},
{
"sinh"
,
nvinfer1
::
UnaryOperation
::
kSINH
},
{
"cosh"
,
nvinfer1
::
UnaryOperation
::
kCOSH
},
{
"asin"
,
nvinfer1
::
UnaryOperation
::
kASIN
},
{
"acos"
,
nvinfer1
::
UnaryOperation
::
kACOS
},
{
"atan"
,
nvinfer1
::
UnaryOperation
::
kATAN
},
{
"asinh"
,
nvinfer1
::
UnaryOperation
::
kASINH
},
{
"atanh"
,
nvinfer1
::
UnaryOperation
::
kATANH
},
{
"ceil"
,
nvinfer1
::
UnaryOperation
::
kCEIL
},
{
"floor"
,
nvinfer1
::
UnaryOperation
::
kFLOOR
},
#if IS_TRT_VERSION_GE(7000)
{
"erf"
,
nvinfer1
::
UnaryOperation
::
kERF
},
#endif
};
};
class
ExpOpConverter
:
public
UnaryOpConverter
{
class
ExpOpConverter
:
public
UnaryOpConverter
{
...
@@ -78,9 +95,89 @@ class LogOpConverter : public UnaryOpConverter {
...
@@ -78,9 +95,89 @@ class LogOpConverter : public UnaryOpConverter {
LogOpConverter
()
{
op_type_
=
"log"
;
}
LogOpConverter
()
{
op_type_
=
"log"
;
}
};
};
class
SqrtOpConverter
:
public
UnaryOpConverter
{
public:
SqrtOpConverter
()
{
op_type_
=
"sqrt"
;
}
};
class
AbsOpConverter
:
public
UnaryOpConverter
{
public:
AbsOpConverter
()
{
op_type_
=
"abs"
;
}
};
class
SinOpConverter
:
public
UnaryOpConverter
{
public:
SinOpConverter
()
{
op_type_
=
"sin"
;
}
};
class
CosOpConverter
:
public
UnaryOpConverter
{
public:
CosOpConverter
()
{
op_type_
=
"cos"
;
}
};
class
TanOpConverter
:
public
UnaryOpConverter
{
public:
TanOpConverter
()
{
op_type_
=
"tan"
;
}
};
class
SinhOpConverter
:
public
UnaryOpConverter
{
public:
SinhOpConverter
()
{
op_type_
=
"sinh"
;
}
};
class
CoshOpConverter
:
public
UnaryOpConverter
{
public:
CoshOpConverter
()
{
op_type_
=
"cosh"
;
}
};
class
AsinOpConverter
:
public
UnaryOpConverter
{
public:
AsinOpConverter
()
{
op_type_
=
"asin"
;
}
};
class
AcosOpConverter
:
public
UnaryOpConverter
{
public:
AcosOpConverter
()
{
op_type_
=
"acos"
;
}
};
class
AtanOpConverter
:
public
UnaryOpConverter
{
public:
AtanOpConverter
()
{
op_type_
=
"atan"
;
}
};
class
AsinhOpConverter
:
public
UnaryOpConverter
{
public:
AsinhOpConverter
()
{
op_type_
=
"asinh"
;
}
};
class
AtanhOpConverter
:
public
UnaryOpConverter
{
public:
AtanhOpConverter
()
{
op_type_
=
"atanh"
;
}
};
class
CeilOpConverter
:
public
UnaryOpConverter
{
public:
CeilOpConverter
()
{
op_type_
=
"ceil"
;
}
};
class
FloorOpConverter
:
public
UnaryOpConverter
{
public:
FloorOpConverter
()
{
op_type_
=
"floor"
;
}
};
#if IS_TRT_VERSION_GE(7000)
class
ErfOpConverter
:
public
UnaryOpConverter
{
public:
ErfOpConverter
()
{
op_type_
=
"erf"
;
}
};
#endif
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
exp
,
ExpOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
exp
,
ExpOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
log
,
LogOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
log
,
LogOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
sqrt
,
SqrtOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
abs
,
AbsOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
sin
,
SinOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
cos
,
CosOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
tan
,
TanOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
sinh
,
SinhOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
cosh
,
CoshOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
asin
,
AsinOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
acos
,
AcosOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
atan
,
AtanOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
asinh
,
AsinhOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
atanh
,
AtanhOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
ceil
,
CeilOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
floor
,
FloorOpConverter
);
#if IS_TRT_VERSION_GE(7000)
REGISTER_TRT_OP_CONVERTER
(
erf
,
ErfOpConverter
);
#endif
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
890c7315
...
@@ -75,6 +75,21 @@ struct SimpleOpTypeSetTeller : public Teller {
...
@@ -75,6 +75,21 @@ struct SimpleOpTypeSetTeller : public Teller {
"relu"
,
"relu"
,
"exp"
,
"exp"
,
"log"
,
"log"
,
"sqrt"
,
"abs"
,
"sin"
,
"cos"
,
"tan"
,
"sinh"
,
"cosh"
,
"asin"
,
"acos"
,
"atan"
,
"asinh"
,
"atanh"
,
"ceil"
,
"floor"
,
"erf"
,
"softmax"
,
"softmax"
,
"sigmoid"
,
"sigmoid"
,
"hard_swish"
,
"hard_swish"
,
...
@@ -148,6 +163,21 @@ struct SimpleOpTypeSetTeller : public Teller {
...
@@ -148,6 +163,21 @@ struct SimpleOpTypeSetTeller : public Teller {
"relu"
,
"relu"
,
"exp"
,
"exp"
,
"log"
,
"log"
,
"sqrt"
,
"abs"
,
"sin"
,
"cos"
,
"tan"
,
"sinh"
,
"cosh"
,
"asin"
,
"acos"
,
"atan"
,
"asinh"
,
"atanh"
,
"ceil"
,
"floor"
,
"erf"
,
"softmax"
,
"softmax"
,
"sigmoid"
,
"sigmoid"
,
"hard_swish"
,
"hard_swish"
,
...
@@ -227,8 +257,31 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
...
@@ -227,8 +257,31 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return
false
;
return
false
;
for
(
auto
&
teller
:
tellers_
)
{
for
(
auto
&
teller
:
tellers_
)
{
if
(
op_type
==
"relu"
||
op_type
==
"relu6"
||
op_type
==
"tanh"
||
std
::
unordered_set
<
std
::
string
>
act_op_list
=
{
"relu"
,
op_type
==
"sigmoid"
||
op_type
==
"exp"
||
op_type
==
"log"
)
{
"elu"
,
"selu"
,
"softsign"
,
"softplus"
,
"stanh"
,
"thresholded_relu"
,
"exp"
,
"log"
,
"sqrt"
,
"abs"
,
"sin"
,
"cos"
,
"tan"
,
"sinh"
,
"cosh"
,
"asin"
,
"acos"
,
"atan"
,
"asinh"
,
"atanh"
,
"ceil"
,
"floor"
,
"erf"
};
if
(
act_op_list
.
find
(
op_type
)
!=
act_op_list
.
end
())
{
auto
*
block
=
desc
.
Block
();
auto
*
block
=
desc
.
Block
();
if
(
block
==
nullptr
)
{
if
(
block
==
nullptr
)
{
VLOG
(
3
)
<<
"The block desc is nullptr, we can't continue to analyze. "
VLOG
(
3
)
<<
"The block desc is nullptr, we can't continue to analyze. "
...
@@ -244,6 +297,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
...
@@ -244,6 +297,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
<<
" op does not support input's dim is 1 in tensorrt."
;
<<
" op does not support input's dim is 1 in tensorrt."
;
return
false
;
return
false
;
}
}
#if !IS_TRT_VERSION_GE(7000)
if
(
op_type
==
"erf"
)
{
VLOG
(
3
)
<<
op_type
<<
" op does not support tensorrt."
;
return
false
;
}
#endif
}
}
if
(
op_type
==
"pool2d"
)
{
if
(
op_type
==
"pool2d"
)
{
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_unary.py
浏览文件 @
890c7315
...
@@ -27,20 +27,25 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
...
@@ -27,20 +27,25 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
return
True
return
True
def
sample_program_configs
(
self
):
def
sample_program_configs
(
self
):
self
.
trt_param
.
workspace_size
=
1073741824
def
generate_input1
(
dims
,
batch
,
attrs
:
List
[
Dict
[
str
,
Any
]]):
def
generate_input1
(
dims
,
batch
,
attrs
:
List
[
Dict
[
str
,
Any
]]):
if
dims
==
1
:
if
dims
==
1
:
return
np
.
ones
([
32
]).
astype
(
np
.
float32
)
return
np
.
random
.
random
([
32
]).
astype
(
np
.
float32
)
elif
dims
==
2
:
elif
dims
==
2
:
return
np
.
ones
([
3
,
32
]).
astype
(
np
.
float32
)
return
np
.
random
.
random
([
3
,
32
]).
astype
(
np
.
float32
)
elif
dims
==
3
:
elif
dims
==
3
:
return
np
.
ones
([
3
,
32
,
32
]).
astype
(
np
.
float32
)
return
np
.
random
.
random
([
3
,
32
,
32
]).
astype
(
np
.
float32
)
else
:
else
:
return
np
.
ones
([
batch
,
3
,
32
,
32
]).
astype
(
np
.
float32
)
return
np
.
random
.
random
([
batch
,
3
,
32
,
32
]).
astype
(
np
.
float32
)
for
dims
in
[
1
,
2
,
3
,
4
]:
for
dims
in
[
1
,
2
,
3
,
4
]:
for
batch
in
[
1
,
4
]:
for
batch
in
[
1
,
4
]:
for
op_type
in
[
"exp"
,
"log"
]:
for
op_type
in
[
"exp"
,
"log"
,
"sqrt"
,
"abs"
,
"sin"
,
"cos"
,
"tan"
,
"sinh"
,
"cosh"
,
"asin"
,
"acos"
,
"atan"
,
"asinh"
,
"atanh"
,
"ceil"
,
"floor"
]:
self
.
dims
=
dims
self
.
dims
=
dims
dics
=
[{}]
dics
=
[{}]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录