Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2012aeb6
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2012aeb6
编写于
4月 02, 2022
作者:
W
Wilber
提交者:
GitHub
4月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add trt pool and ut (#41258)
上级
ad0c106c
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
289 addition
and
119 deletion
+289
-119
paddle/infrt/dialect/tensorrt/convert.h
paddle/infrt/dialect/tensorrt/convert.h
+155
-1
paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td
paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td
+2
-1
paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
+0
-27
paddle/infrt/dialect/tensorrt/trt_ops.td
paddle/infrt/dialect/tensorrt/trt_ops.td
+4
-1
paddle/infrt/kernel/tensorrt/trt_helper.h
paddle/infrt/kernel/tensorrt/trt_helper.h
+5
-5
paddle/infrt/kernel/tensorrt/trt_kernels.cc
paddle/infrt/kernel/tensorrt/trt_kernels.cc
+3
-0
paddle/infrt/kernel/tensorrt/trt_layers.h
paddle/infrt/kernel/tensorrt/trt_layers.h
+54
-2
paddle/infrt/tests/dialect/tensorrt/disabled_trt.mlir
paddle/infrt/tests/dialect/tensorrt/disabled_trt.mlir
+0
-37
paddle/infrt/tests/dialect/tensorrt/disabled_trt_activation.mlir
...infrt/tests/dialect/tensorrt/disabled_trt_activation.mlir
+21
-0
paddle/infrt/tests/dialect/tensorrt/disabled_trt_fc.mlir
paddle/infrt/tests/dialect/tensorrt/disabled_trt_fc.mlir
+24
-45
paddle/infrt/tests/dialect/tensorrt/disabled_trt_pool.mlir
paddle/infrt/tests/dialect/tensorrt/disabled_trt_pool.mlir
+21
-0
未找到文件。
paddle/infrt/dialect/tensorrt/convert.h
浏览文件 @
2012aeb6
...
@@ -14,17 +14,49 @@
...
@@ -14,17 +14,49 @@
#pragma once
#pragma once
#include <glog/logging.h>
#include <glog/logging.h>
#include <llvm/Support/ErrorHandling.h>
#include <llvm/include/mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Transforms/DialectConversion.h>
#include "paddle/infrt/dialect/infrt/common/types.h"
#include "paddle/infrt/dialect/infrt/common/types.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/pd/ir/pd_ops.h"
#include "paddle/infrt/dialect/pd/ir/pd_ops.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include "paddle/infrt/kernel/tensorrt/trt_helper.h"
namespace
infrt
{
namespace
infrt
{
namespace
trt
{
namespace
trt
{
#ifdef INFRT_WITH_TRT
#define STRING_TO_ENUM_TYPE(enum_type) enum_type
#define STRING_TO_ENUM_VALUE(enum_value) enum_value
#include <NvInfer.h>
#else // INFRT_WITH_TRT
#define STRING_TO_ENUM_TYPE(enum_type) std::string
#define STRING_TO_ENUM_VALUE(enum_value) #enum_value
#endif // INFRT_WITH_TRT
template
<
typename
T
>
::
mlir
::
IntegerAttr
createNvinferEnumAttr
(
::
mlir
::
PatternRewriter
&
rewriter
,
// NOLINT
T
enum_value
)
{
return
rewriter
.
getSI32IntegerAttr
((
int32_t
)
enum_value
);
}
template
<
>
::
mlir
::
IntegerAttr
createNvinferEnumAttr
<
std
::
string
>
(
::
mlir
::
PatternRewriter
&
rewriter
,
std
::
string
enum_value
)
{
// NOLINT
(
void
)
enum_value
;
return
rewriter
.
getSI32IntegerAttr
(
-
1
);
}
static
mlir
::
Value
createTRTConv2dOp
(
mlir
::
PatternRewriter
&
rewriter
,
// NOLINT
static
mlir
::
Value
createTRTConv2dOp
(
mlir
::
PatternRewriter
&
rewriter
,
// NOLINT
mlir
::
Operation
*
op
)
{
mlir
::
Operation
*
op
)
{
auto
conv_op
=
::
llvm
::
dyn_cast
<
infrt
::
pd
::
Conv2dOp
>
(
op
);
auto
conv_op
=
::
llvm
::
dyn_cast
<
infrt
::
pd
::
Conv2dOp
>
(
op
);
...
@@ -205,5 +237,127 @@ static mlir::Value createTRTShuffledOp(
...
@@ -205,5 +237,127 @@ static mlir::Value createTRTShuffledOp(
return
rewriter
.
create
<
trt
::
ShuffleOp
>
(
return
rewriter
.
create
<
trt
::
ShuffleOp
>
(
op
->
getLoc
(),
resultTypes
,
operands
,
attributes
);
op
->
getLoc
(),
resultTypes
,
operands
,
attributes
);
}
}
inline
mlir
::
IntegerAttr
CreatePoolingType
(
mlir
::
PatternRewriter
&
builder
,
// NOLINT
mlir
::
StringAttr
pool_type
)
{
// pool_type.
auto
ptype
=
pool_type
.
str
();
if
(
ptype
==
"max"
)
{
return
createNvinferEnumAttr
(
builder
,
nvinfer1
::
PoolingType
::
kMAX
);
}
else
if
(
ptype
==
"avg"
)
{
return
createNvinferEnumAttr
(
builder
,
nvinfer1
::
PoolingType
::
kAVERAGE
);
}
else
{
llvm_unreachable
(
"unknown pool_type."
);
return
{};
}
}
inline
mlir
::
IntegerAttr
CreatePaddingMode
(
mlir
::
PatternRewriter
&
builder
,
// NOLINT
mlir
::
StringAttr
padding_algorithm
,
mlir
::
BoolAttr
ceil_mode
)
{
// TODO(Inference): Phi pool kernel seems not process ceil_mode.
auto
padding_algo
=
padding_algorithm
.
str
();
if
(
padding_algo
==
"SAME"
)
{
return
createNvinferEnumAttr
(
builder
,
nvinfer1
::
PaddingMode
::
kSAME_UPPER
);
}
if
(
ceil_mode
.
getValue
()
&&
padding_algo
!=
"SAME"
)
{
return
createNvinferEnumAttr
(
builder
,
nvinfer1
::
PaddingMode
::
kEXPLICIT_ROUND_UP
);
}
else
{
return
createNvinferEnumAttr
(
builder
,
nvinfer1
::
PaddingMode
::
kEXPLICIT_ROUND_DOWN
);
}
}
inline
::
llvm
::
SmallVector
<::
mlir
::
Value
,
4
>
CreatePaddleTrtPoolingOp
(
mlir
::
PatternRewriter
&
builder
,
// NOLINT
mlir
::
Value
input
,
mlir
::
StringAttr
pool_type
,
mlir
::
ArrayAttr
ksize
,
mlir
::
BoolAttr
global_pooling
,
mlir
::
ArrayAttr
strides
,
mlir
::
ArrayAttr
paddings
,
mlir
::
BoolAttr
exclusive
,
mlir
::
BoolAttr
adaptive
,
mlir
::
BoolAttr
ceil_mode
,
mlir
::
StringAttr
data_format
,
mlir
::
StringAttr
padding_algorithm
)
{
::
llvm
::
SmallVector
<::
mlir
::
Value
,
4
>
tblgen_repl_values
;
// TODO(inference): Support NHWC.
if
(
data_format
.
str
()
!=
"NCHW"
)
{
CHECK
(
false
)
<<
"The pool2d converter now only support NCHW."
;
}
// TODO(Wilber): How to support dynamic shape?
auto
*
input_producer
=
input
.
getDefiningOp
();
// Process pool_type.
auto
pool_type_attr
=
CreatePoolingType
(
builder
,
pool_type
);
// Update padding.
auto
padding_algorithm_str
=
padding_algorithm
.
str
();
auto
paddings_attr
=
paddings
;
if
(
padding_algorithm_str
==
"EXPLICIT"
)
{
// Do nothing on paddings.
}
else
if
(
padding_algorithm_str
==
"SAME"
)
{
// We should process this case in trt network build phase.
}
else
if
(
padding_algorithm_str
==
"VALID"
)
{
// Set padding to zero.
paddings_attr
=
builder
.
getI32ArrayAttr
({
0
,
0
});
}
else
{
CHECK
(
false
)
<<
"Unknown padding_algotithm."
;
}
// if global_pooling == true or adaptive == true, padding will be ignored
if
(
global_pooling
.
getValue
()
||
adaptive
.
getValue
())
{
paddings_attr
=
builder
.
getI32ArrayAttr
({
0
,
0
});
}
// if global_pooling == true, then we should update kernel size to input dims.
if
(
global_pooling
.
getValue
()
==
true
)
{
// Update ksize to input dims.
}
// The adaptive logic should be processed when we get the context of
// INetworkDefinition, so we place the logic in infrt runtime(trt compile
// time).
// The `exclusive` may be a naive attr, which can be forward to trt.
auto
padding_mode_attr
=
CreatePaddingMode
(
builder
,
padding_algorithm
,
ceil_mode
);
if
(
global_pooling
.
getValue
()
==
true
)
{
CHECK
(
false
)
<<
"Temporarily not support global_pool"
;
return
tblgen_repl_values
;
}
PoolingOp
pool_op
;
{
auto
ods_loc
=
builder
.
getFusedLoc
({
input_producer
->
getLoc
()});
builder
.
create
<
PoolingOp
>
(
ods_loc
,
input
.
getType
(),
input
,
pool_type_attr
,
ksize
,
strides
,
paddings_attr
,
padding_mode_attr
,
exclusive
,
adaptive
,
padding_algorithm
);
}
for
(
auto
v
:
::
llvm
::
SmallVector
<::
mlir
::
Value
,
4
>
{
pool_op
.
getODSResults
(
0
)})
{
tblgen_repl_values
.
push_back
(
v
);
}
return
tblgen_repl_values
;
}
}
// namespace trt
}
// namespace trt
}
// namespace infrt
}
// namespace infrt
paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td
浏览文件 @
2012aeb6
...
@@ -31,9 +31,10 @@ def PD2TRT_Conv2d_Lower : Pat<
...
@@ -31,9 +31,10 @@ def PD2TRT_Conv2d_Lower : Pat<
(PD_Conv2dOp:$old_value $Input, $Filter, $strides, $paddings, $padding_algorithm, $groups, $dilations, $data_format),
(PD_Conv2dOp:$old_value $Input, $Filter, $strides, $paddings, $padding_algorithm, $groups, $dilations, $data_format),
(createTRTConv2dOp $old_value)>;
(createTRTConv2dOp $old_value)>;
def createTrtPoolingOp : NativeCodeCall<"::infrt::trt::CreatePaddleTrtPoolingOp($_builder, $0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10)">;
def PD2TRT_Pooling_Lower : Pat<
def PD2TRT_Pooling_Lower : Pat<
(PD_Pool2dOp $Input, $pooling_type, $ksize, $global_pooling, $strides, $paddings, $exclusive, $adaptive, $ceil_mode, $data_format, $padding_algorithm),
(PD_Pool2dOp $Input, $pooling_type, $ksize, $global_pooling, $strides, $paddings, $exclusive, $adaptive, $ceil_mode, $data_format, $padding_algorithm),
(
TRT_PoolingOp $Input, (INFRT_createI32Attr<"0">)/*kmax*/, $ksize, $strides, $paddings
, $padding_algorithm)>;
(
createTrtPoolingOp $Input, $pooling_type, $ksize, $global_pooling, $strides, $paddings, $exclusive, $adaptive, $ceil_mode, $data_format
, $padding_algorithm)>;
def PD2TRT_MatrixMultipl_Lower : Pat<
def PD2TRT_MatrixMultipl_Lower : Pat<
(PD_MulOp $Input1, $Input2, $x_num_col_dims, $y_num_col_dims),
(PD_MulOp $Input1, $Input2, $x_num_col_dims, $y_num_col_dims),
...
...
paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
浏览文件 @
2012aeb6
...
@@ -28,33 +28,6 @@
...
@@ -28,33 +28,6 @@
namespace
infrt
{
namespace
infrt
{
namespace
trt
{
namespace
trt
{
#ifdef INFRT_WITH_TRT
#define STRING_TO_ENUM_TYPE(enum_type) enum_type
#define STRING_TO_ENUM_VALUE(enum_value) enum_value
#include <NvInfer.h>
#else // INFRT_WITH_TRT
#define STRING_TO_ENUM_TYPE(enum_type) std::string
#define STRING_TO_ENUM_VALUE(enum_value) #enum_value
#endif // INFRT_WITH_TRT
template
<
typename
T
>
::
mlir
::
IntegerAttr
createNvinferEnumAttr
(
::
mlir
::
PatternRewriter
&
rewriter
,
// NOLINT
T
enum_value
)
{
return
rewriter
.
getSI32IntegerAttr
((
int32_t
)
enum_value
);
}
template
<
>
::
mlir
::
IntegerAttr
createNvinferEnumAttr
<
std
::
string
>
(
::
mlir
::
PatternRewriter
&
rewriter
,
std
::
string
enum_value
)
{
// NOLINT
(
void
)
enum_value
;
return
rewriter
.
getSI32IntegerAttr
(
-
1
);
}
#include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc" // NOLINT
#include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc" // NOLINT
struct
PD2TRT_GraphLower
:
public
::
mlir
::
RewritePattern
{
struct
PD2TRT_GraphLower
:
public
::
mlir
::
RewritePattern
{
...
...
paddle/infrt/dialect/tensorrt/trt_ops.td
浏览文件 @
2012aeb6
...
@@ -101,7 +101,10 @@ def TRT_PoolingOp : TRT_Op<"Pooling", [NoSideEffect]> {
...
@@ -101,7 +101,10 @@ def TRT_PoolingOp : TRT_Op<"Pooling", [NoSideEffect]> {
I32ArrayAttr:$window_size,
I32ArrayAttr:$window_size,
I32ArrayAttr:$strides,
I32ArrayAttr:$strides,
I32ArrayAttr:$paddings,
I32ArrayAttr:$paddings,
StrAttr:$padding_mode
I32Attr:$padding_mode,
BoolAttr:$exclusive,
BoolAttr:$adaptive,
StrAttr:$padding_algorithm
);
);
let results = (outs
let results = (outs
DenseTensor:$output_tensor
DenseTensor:$output_tensor
...
...
paddle/infrt/kernel/tensorrt/trt_helper.h
浏览文件 @
2012aeb6
...
@@ -28,13 +28,13 @@ namespace infrt {
...
@@ -28,13 +28,13 @@ namespace infrt {
namespace
kernel
{
namespace
kernel
{
namespace
tensorrt
{
namespace
tensorrt
{
static
nvinfer1
::
DataType
TensorTypeToWeightType
(
phi
::
DataType
tensor_type
)
{
static
nvinfer1
::
DataType
TensorTypeToWeightType
(
::
phi
::
DataType
tensor_type
)
{
switch
(
tensor_type
)
{
switch
(
tensor_type
)
{
case
phi
::
DataType
::
FLOAT32
:
case
::
phi
::
DataType
::
FLOAT32
:
return
nvinfer1
::
DataType
::
kFLOAT
;
return
nvinfer1
::
DataType
::
kFLOAT
;
case
phi
::
DataType
::
INT32
:
case
::
phi
::
DataType
::
INT32
:
return
nvinfer1
::
DataType
::
kINT32
;
return
nvinfer1
::
DataType
::
kINT32
;
case
phi
::
DataType
::
FLOAT16
:
case
::
phi
::
DataType
::
FLOAT16
:
return
nvinfer1
::
DataType
::
kHALF
;
return
nvinfer1
::
DataType
::
kHALF
;
default:
default:
llvm_unreachable
(
"should not reach here"
);
llvm_unreachable
(
"should not reach here"
);
...
@@ -52,7 +52,7 @@ static nvinfer1::Dims ArrayAttrToNvDims(const mlir::ArrayAttr& int_array_attr) {
...
@@ -52,7 +52,7 @@ static nvinfer1::Dims ArrayAttrToNvDims(const mlir::ArrayAttr& int_array_attr) {
return
dims
;
return
dims
;
}
}
static
nvinfer1
::
Weights
TensorToWeights
(
phi
::
DenseTensor
*
tensor
)
{
static
nvinfer1
::
Weights
TensorToWeights
(
::
phi
::
DenseTensor
*
tensor
)
{
CHECK_NOTNULL
(
tensor
);
CHECK_NOTNULL
(
tensor
);
nvinfer1
::
Weights
ret
;
nvinfer1
::
Weights
ret
;
ret
.
type
=
TensorTypeToWeightType
(
tensor
->
dtype
());
ret
.
type
=
TensorTypeToWeightType
(
tensor
->
dtype
());
...
...
paddle/infrt/kernel/tensorrt/trt_kernels.cc
浏览文件 @
2012aeb6
...
@@ -129,6 +129,7 @@ namespace tensorrt {
...
@@ -129,6 +129,7 @@ namespace tensorrt {
// TODO(wilber): Find a way to add layer.
// TODO(wilber): Find a way to add layer.
for
(
auto
&
operation
:
block
.
without_terminator
())
{
for
(
auto
&
operation
:
block
.
without_terminator
())
{
VLOG
(
1
)
<<
"process "
<<
operation
.
getName
().
getStringRef
().
str
()
<<
" ..."
;
if
(
trt
::
ActivationOp
op
=
llvm
::
dyn_cast
<
trt
::
ActivationOp
>
(
operation
))
{
if
(
trt
::
ActivationOp
op
=
llvm
::
dyn_cast
<
trt
::
ActivationOp
>
(
operation
))
{
ActivationFunc
(
ActivationFunc
(
op
,
network
.
get
(),
value_to_trt_tensor_map
,
value_to_tensor_map
);
op
,
network
.
get
(),
value_to_trt_tensor_map
,
value_to_tensor_map
);
...
@@ -138,6 +139,8 @@ namespace tensorrt {
...
@@ -138,6 +139,8 @@ namespace tensorrt {
}
else
if
(
trt
::
ConvolutionOp
op
=
}
else
if
(
trt
::
ConvolutionOp
op
=
llvm
::
dyn_cast
<
trt
::
ConvolutionOp
>
(
operation
))
{
llvm
::
dyn_cast
<
trt
::
ConvolutionOp
>
(
operation
))
{
ConvFunc
(
op
,
network
.
get
(),
value_to_trt_tensor_map
,
value_to_tensor_map
);
ConvFunc
(
op
,
network
.
get
(),
value_to_trt_tensor_map
,
value_to_tensor_map
);
}
else
if
(
trt
::
PoolingOp
op
=
llvm
::
dyn_cast
<
trt
::
PoolingOp
>
(
operation
))
{
PoolFunc
(
op
,
network
.
get
(),
value_to_trt_tensor_map
,
value_to_tensor_map
);
}
else
{
}
else
{
CHECK
(
false
)
<<
"not supported operation."
;
CHECK
(
false
)
<<
"not supported operation."
;
}
}
...
...
paddle/infrt/kernel/tensorrt/trt_layers.h
浏览文件 @
2012aeb6
...
@@ -15,13 +15,15 @@
...
@@ -15,13 +15,15 @@
#pragma once
#pragma once
#include <NvInfer.h>
#include <NvInfer.h>
#include <llvm/ADT/StringRef.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/Value.h>
#include <string>
#include <string>
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include "paddle/infrt/kernel/tensorrt/trt_helper.h"
#include "paddle/infrt/kernel/tensorrt/trt_helper.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/dense_tensor.h"
namespace
infrt
{
namespace
infrt
{
...
@@ -63,7 +65,12 @@ inline void ConvFunc(trt::ConvolutionOp& op, // NOLINT
...
@@ -63,7 +65,12 @@ inline void ConvFunc(trt::ConvolutionOp& op, // NOLINT
nvinfer1
::
Dims
dims
=
ArrayAttrToNvDims
(
size_attrs
);
nvinfer1
::
Dims
dims
=
ArrayAttrToNvDims
(
size_attrs
);
auto
kernel_weights
=
auto
kernel_weights
=
TensorToWeights
(
value_to_tensor_map
[
op
.
kernel_weights
()]);
TensorToWeights
(
value_to_tensor_map
[
op
.
kernel_weights
()]);
auto
bias_weights
=
TensorToWeights
(
value_to_tensor_map
[
op
.
bias_weights
()]);
nvinfer1
::
Weights
bias_weights
;
if
(
op
.
bias_weights
()
==
mlir
::
Value
())
{
bias_weights
=
nvinfer1
::
Weights
{};
}
else
{
bias_weights
=
TensorToWeights
(
value_to_tensor_map
[
op
.
bias_weights
()]);
}
auto
*
layer
=
auto
*
layer
=
network
->
addConvolutionNd
(
*
value_to_trt_tensor_map
[
input_tensor_repr
],
network
->
addConvolutionNd
(
*
value_to_trt_tensor_map
[
input_tensor_repr
],
...
@@ -77,6 +84,51 @@ inline void ConvFunc(trt::ConvolutionOp& op, // NOLINT
...
@@ -77,6 +84,51 @@ inline void ConvFunc(trt::ConvolutionOp& op, // NOLINT
value_to_trt_tensor_map
[
out_repr
]
=
out_tensor
;
value_to_trt_tensor_map
[
out_repr
]
=
out_tensor
;
}
}
inline
void
PoolFunc
(
trt
::
PoolingOp
&
op
,
// NOLINT
nvinfer1
::
INetworkDefinition
*
network
,
ValueToITensorMap
&
value_to_trt_tensor_map
,
// NOLINT
ValueToTensorMap
&
value_to_tensor_map
)
{
// NOLINT
mlir
::
Value
input_tensor_repr
=
op
.
input_tensor
();
nvinfer1
::
ITensor
*
input_itensor
=
value_to_trt_tensor_map
[
input_tensor_repr
];
// nvinfer1::Dims input_shape = input_itensor->getDimensions();
// int input_dims = input_shape.nbDims;
auto
padding_mode
=
op
.
padding_mode
();
auto
pool_type
=
op
.
pool_type
();
mlir
::
ArrayAttr
paddings
=
op
.
paddings
();
mlir
::
ArrayAttr
strides
=
op
.
strides
();
mlir
::
ArrayAttr
ksize
=
op
.
window_size
();
bool
exclusive
=
op
.
exclusive
();
bool
adaptive
=
op
.
adaptive
();
auto
padding_algorithm
=
op
.
padding_algorithm
().
str
();
if
(
padding_algorithm
==
"SAME"
)
{
// TODO(wilber)
CHECK
(
false
)
<<
"Not supported `same` padding algorithm"
;
}
if
(
adaptive
)
{
// TODO(Inference)
CHECK
(
false
)
<<
"Not supported adaptive pool"
;
}
nvinfer1
::
Dims
window_size
=
ArrayAttrToNvDims
(
ksize
);
auto
*
layer
=
network
->
addPoolingNd
(
*
input_itensor
,
static_cast
<
nvinfer1
::
PoolingType
>
(
pool_type
),
window_size
);
CHECK_NOTNULL
(
layer
);
layer
->
setPaddingMode
(
static_cast
<
nvinfer1
::
PaddingMode
>
(
padding_mode
));
layer
->
setPaddingNd
(
ArrayAttrToNvDims
(
paddings
));
layer
->
setStrideNd
(
ArrayAttrToNvDims
(
strides
));
layer
->
setAverageCountExcludesPadding
(
exclusive
);
mlir
::
Value
out_repr
=
op
.
output_tensor
();
nvinfer1
::
ITensor
*
out_tensor
=
layer
->
getOutput
(
0
);
value_to_trt_tensor_map
[
out_repr
]
=
out_tensor
;
}
inline
void
FcFunc
(
trt
::
FullyConnectedOp
&
op
,
// NOLINT
inline
void
FcFunc
(
trt
::
FullyConnectedOp
&
op
,
// NOLINT
nvinfer1
::
INetworkDefinition
*
network
,
nvinfer1
::
INetworkDefinition
*
network
,
ValueToITensorMap
&
value_to_trt_tensor_map
,
// NOLINT
ValueToITensorMap
&
value_to_trt_tensor_map
,
// NOLINT
...
...
paddle/infrt/tests/dialect/tensorrt/disabled_trt.mlir
已删除
100644 → 0
浏览文件 @
ad0c106c
// RUN: infrtexec -i %s | FileCheck %s
// CHECK-LABEL: @run_trt
func @run_trt(%0 : !infrt.dense_tensor<GPU, FP32, NCHW>, %ctx : !phi.context<GPU>) {
%a = "trt.create_engine"(%0) ({
%1 = "trt.Activation"(%0) {activation_type = 1 : si32, alpha = 1.0 : f32, beta = 6.0 : f32} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
"infrt.return"(%1) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
}) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> !trt.engine
"trt.inspect_engine"(%a) {} : (!trt.engine) -> ()
%res = "trt.compute"(%a, %ctx) {} : (!trt.engine, !phi.context<GPU>) -> (!infrt.tensor_list)
%size = "dt.tensor_list_get_size"(%res) {} : (!infrt.tensor_list) -> (i32)
"infrt.print.i32"(%size) {} : (i32) -> ()
%ts0 = "dt.tensor_list_get_tensor"(%res) {id = 0 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
"phi_dt.print_tensor" (%ts0) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
infrt.return
}
// CHECK-LABEL: @main
func @main() {
%ctx = "phi_dt.create_context.gpu" (): () -> !phi.context<GPU>
%t = "phi_dt.create_dense_tensor.gpu" (%ctx) {
precision=#infrt.precision<FP32>,
layout=#infrt.layout<NCHW>,
dims=[1:i64, 3:i64, 1:i64, 1:i64], lod=[1:i64]}: (!phi.context<GPU>) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32, 2.4:f32, 1.3:f32]} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
"phi_dt.print_tensor" (%t) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
//%res =
infrt.call @run_trt(%t, %ctx) : (!infrt.dense_tensor<GPU, FP32, NCHW>, !phi.context<GPU>) -> ()
//-> (!infrt.dense_tensor<GPU, FP32, NCHW>)
infrt.return
}
paddle/infrt/tests/dialect/tensorrt/disabled_trt_activation.mlir
0 → 100644
浏览文件 @
2012aeb6
module {
func @main_graph(%arg0: !infrt.dense_tensor<CPU, FP32, ANY>) -> !infrt.dense_tensor<CPU, FP32, ANY> {
%0 = "phi_dt.create_context.gpu"() : () -> !phi.context<GPU>
%1 = "phi_dt.memcpy.gpu"(%arg0, %0) {d2h = false} : (!infrt.dense_tensor<CPU, FP32, ANY>, !phi.context<GPU>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
%2 = "trt.create_engine"(%1) ( {
%6 = "trt.Activation"(%1) {activation_type = 1 : si32, alpha = 0.000000e+00 : f32, beta = 0.000000e+00 : f32} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
infrt.return %6 : !infrt.dense_tensor<GPU, FP32, NCHW>
}) {run_once = true} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> !trt.engine
%3 = "trt.compute"(%2, %0) : (!trt.engine, !phi.context<GPU>) -> !infrt.tensor_list
%4 = "dt.tensor_list_get_tensor"(%3) {id = 0 : i32} : (!infrt.tensor_list) -> !infrt.dense_tensor<GPU, FP32, NCHW>
%5 = "phi_dt.memcpy.gpu"(%4, %0) {d2h = true} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !phi.context<GPU>) -> !infrt.dense_tensor<CPU, FP32, ANY>
infrt.return %5 : !infrt.dense_tensor<CPU, FP32, ANY>
}
func @main() {
%0 = "phi_dt.create_context.cpu"() : () -> !phi.context<CPU>
%1 = "phi_dt.create_inited_dense_tensor.cpu.f32"(%0) {dims = [3, 6, 1, 1], layout = #infrt.layout<NCHW>, lod = [0], value = 1.500000e+00 : f32} : (!phi.context<CPU>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
%2 = infrt.call @main_graph(%1) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
phi_dt.print_tensor(%2 : !infrt.dense_tensor<CPU, FP32, NCHW>)
infrt.return
}
}
paddle/infrt/tests/dialect/tensorrt/disabled_trt_fc.mlir
浏览文件 @
2012aeb6
// RUN: infrtexec -i %s | FileCheck %s
module {
func @main_graph(%arg0: !infrt.dense_tensor<CPU, FP32, ANY>) -> !infrt.dense_tensor<CPU, FP32, ANY> {
// CHECK-LABEL: @main
%ctx = "phi_dt.create_context.cpu" (): () -> !phi.context<CPU>
func @main() {
%0 = "phi_dt.create_context.gpu"() : () -> !phi.context<GPU>
%ctx = "phi_dt.create_context.gpu" (): () -> !phi.context<GPU>
%1 = "phi_dt.memcpy.gpu"(%arg0, %0) {d2h = false} : (!infrt.dense_tensor<CPU, FP32, ANY>, !phi.context<GPU>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
%cpu_ctx = "phi_dt.create_context.cpu" (): () -> !phi.context<CPU>
%4 = "phi_dt.create_inited_dense_tensor.cpu.f32" (%ctx) {value=1.5:f32, layout=#infrt.layout<NCHW>, lod=[0], dims=[2, 6]}: (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
%3 = "phi_dt.create_inited_dense_tensor.cpu.f32" (%ctx) {value=1.5:f32, layout=#infrt.layout<NCHW>, lod=[0], dims=[2]}: (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
%input_tensor = "phi_dt.create_dense_tensor.gpu" (%ctx) {
%5 = "trt.create_engine"(%1, %4, %3) ( {
precision=#infrt.precision<FP32>,
%10 = "trt.FullyConnected"(%1, %4, %3) {out_channel_num = 2 : si32} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
layout=#infrt.layout<NCHW>,
infrt.return %10 : !infrt.dense_tensor<GPU, FP32, NCHW>
dims=[1:i64, 3:i64, 1:i64, 1:i64], lod=[1:i64]}: (!phi.context<GPU>) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
}) {run_once = true} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !trt.engine
"phi_dt.fill_dense_tensor.f32"(%input_tensor) {value=[3.8:f32, 2.4:f32, 1.3:f32]} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
%6 = "trt.compute"(%5, %0) : (!trt.engine, !phi.context<GPU>) -> !infrt.tensor_list
//"phi_dt.print_tensor" (%input_tensor) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
%7 = "dt.tensor_list_get_tensor"(%6) {id = 0 : i32} : (!infrt.tensor_list) -> !infrt.dense_tensor<GPU, FP32, NCHW>
%8 = "phi_dt.memcpy.gpu"(%7, %0) {d2h = true} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !phi.context<GPU>) -> !infrt.dense_tensor<CPU, FP32, ANY>
%kernel_weight = "phi_dt.create_dense_tensor.cpu"(%cpu_ctx) {
infrt.return %8 : !infrt.dense_tensor<CPU, FP32, ANY>
precision=#infrt.precision<FP32>,
}
layout=#infrt.layout<NCHW>,
dims=[2:i64, 3:i64], lod=[1:i64]} : (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
func @main() {
"phi_dt.fill_dense_tensor.f32"(%kernel_weight) {value=[1.:f32, 2.:f32, 3.:f32, 4.:f32, 5.:f32, 6.:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
%ctx = "phi_dt.create_context.cpu" (): () -> !phi.context<CPU>
//"phi_dt.print_tensor" (%kernel_weight) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
%input_tensor = "phi_dt.create_inited_dense_tensor.cpu.f32" (%ctx) {value=1.5:f32, layout=#infrt.layout<NCHW>, lod=[0], dims=[3, 6, 1, 1]}: (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
%res = infrt.call @main_graph(%input_tensor) {} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
%kernel_bias = "phi_dt.create_dense_tensor.cpu"(%cpu_ctx) {
"phi_dt.print_tensor" (%res) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
precision=#infrt.precision<FP32>,
infrt.return
layout=#infrt.layout<NCHW>,
}
dims=[2:i64], lod=[1:i64]} : (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%kernel_bias) {value=[1.:f32, 2.:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
//"phi_dt.print_tensor" (%kernel_bias) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
%engine = "trt.create_engine"(%input_tensor, %kernel_weight, %kernel_bias) ({
%1 = "trt.Activation"(%input_tensor) {activation_type = 1 : si32, alpha = 1.0 : f32, beta = 6.0 : f32} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
%2 = "trt.FullyConnected"(%input_tensor, %kernel_weight, %kernel_bias) {out_channel_num = 2 : si32} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
"infrt.return"(%1, %2) : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
}) : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !trt.engine
%res = "trt.compute"(%engine, %ctx) {} : (!trt.engine, !phi.context<GPU>) -> (!infrt.tensor_list)
%size = "dt.tensor_list_get_size"(%res) {} : (!infrt.tensor_list) -> (i32)
"infrt.print.i32"(%size) {} : (i32) -> ()
%ts0 = "dt.tensor_list_get_tensor"(%res) {id = 0 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
"phi_dt.print_tensor" (%ts0) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
%ts1 = "dt.tensor_list_get_tensor"(%res) {id = 1 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
"phi_dt.print_tensor" (%ts1) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
infrt.return
}
}
paddle/infrt/tests/dialect/tensorrt/disabled_trt_pool.mlir
0 → 100644
浏览文件 @
2012aeb6
module {
func @main_graph(%arg0: !infrt.dense_tensor<CPU, FP32, ANY>) -> !infrt.dense_tensor<CPU, FP32, ANY> {
%0 = "phi_dt.create_context.gpu"() : () -> !phi.context<GPU>
%1 = "phi_dt.memcpy.gpu"(%arg0, %0) {d2h = false} : (!infrt.dense_tensor<CPU, FP32, ANY>, !phi.context<GPU>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
%2 = "trt.create_engine"(%1) ( {
%6 = "trt.Pooling"(%1) {padding_mode = 0 : i32, paddings = [1 : i32, 1 : i32], pool_type = 0 : i32, strides = [2 : i32, 2 : i32], window_size = [3 : i32, 3 : i32], exclusive = false, adaptive = false, padding_algorithm = "EXPLICIT"} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
infrt.return %6 : !infrt.dense_tensor<GPU, FP32, NCHW>
}) {run_once = true} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> !trt.engine
%3 = "trt.compute"(%2, %0) : (!trt.engine, !phi.context<GPU>) -> !infrt.tensor_list
%4 = "dt.tensor_list_get_tensor"(%3) {id = 0 : i32} : (!infrt.tensor_list) -> !infrt.dense_tensor<GPU, FP32, NCHW>
%5 = "phi_dt.memcpy.gpu"(%4, %0) {d2h = true} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !phi.context<GPU>) -> !infrt.dense_tensor<CPU, FP32, ANY>
infrt.return %5 : !infrt.dense_tensor<CPU, FP32, ANY>
}
func @main() {
%0 = "phi_dt.create_context.cpu"() : () -> !phi.context<CPU>
%1 = "phi_dt.create_inited_dense_tensor.cpu.f32"(%0) {dims = [1, 3, 10, 10], layout = #infrt.layout<NCHW>, lod = [0], value = 1.500000e+00 : f32} : (!phi.context<CPU>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
%2 = infrt.call @main_graph(%1) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
phi_dt.print_tensor(%2 : !infrt.dense_tensor<CPU, FP32, NCHW>)
infrt.return
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录