Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
d194bd3a
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d194bd3a
编写于
6月 05, 2021
作者:
W
Wilber
提交者:
GitHub
6月 05, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle-TRT] Add gather_nd and reduce_sum trt op. (#33324)
上级
dd181238
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
935 addition
and
27 deletion
+935
-27
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+2
-0
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+2
-0
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
...fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
+13
-4
paddle/fluid/inference/tensorrt/convert/gather_nd_op.cc
paddle/fluid/inference/tensorrt/convert/gather_nd_op.cc
+58
-0
paddle/fluid/inference/tensorrt/convert/reduce_op.cc
paddle/fluid/inference/tensorrt/convert/reduce_op.cc
+90
-0
paddle/fluid/inference/tensorrt/convert/reshape_op.cc
paddle/fluid/inference/tensorrt/convert/reshape_op.cc
+1
-1
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+42
-0
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+1
-0
paddle/fluid/inference/tensorrt/plugin/gather_nd_op_plugin.cu
...le/fluid/inference/tensorrt/plugin/gather_nd_op_plugin.cu
+229
-0
paddle/fluid/inference/tensorrt/plugin/gather_nd_op_plugin.h
paddle/fluid/inference/tensorrt/plugin/gather_nd_op_plugin.h
+132
-0
paddle/fluid/operators/math/bert_encoder_functor.cu
paddle/fluid/operators/math/bert_encoder_functor.cu
+190
-22
python/paddle/fluid/tests/unittests/ir/inference/test_trt_gather_nd_op.py
...uid/tests/unittests/ir/inference/test_trt_gather_nd_op.py
+93
-0
python/paddle/fluid/tests/unittests/ir/inference/test_trt_reduce_sum_op.py
...id/tests/unittests/ir/inference/test_trt_reduce_sum_op.py
+82
-0
未找到文件。
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
d194bd3a
...
...
@@ -1237,6 +1237,8 @@ USE_TRT_CONVERTER(affine_channel);
USE_TRT_CONVERTER
(
multiclass_nms
);
USE_TRT_CONVERTER
(
nearest_interp
);
USE_TRT_CONVERTER
(
reshape
);
USE_TRT_CONVERTER
(
reduce_sum
);
USE_TRT_CONVERTER
(
gather_nd
);
#endif
namespace
paddle_infer
{
...
...
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
d194bd3a
...
...
@@ -13,6 +13,8 @@ nv_library(tensorrt_converter
multiclass_nms_op.cc
nearest_interp_op.cc
reshape_op.cc
reduce_op.cc
gather_nd_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry
)
nv_test
(
test_op_converter SRCS test_op_converter.cc DEPS
...
...
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
浏览文件 @
d194bd3a
...
...
@@ -40,10 +40,19 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
auto
word_emb_name
=
op_desc
.
Input
(
"WordEmbedding"
).
front
();
auto
pos_emb_name
=
op_desc
.
Input
(
"PosEmbedding"
).
front
();
auto
sent_emb_name
=
op_desc
.
Input
(
"SentEmbedding"
).
front
();
std
::
vector
<
std
::
string
>
id_names
=
{
word_id_name
,
pos_id_name
,
sent_id_name
};
std
::
vector
<
std
::
string
>
emb_names
=
{
word_emb_name
,
pos_emb_name
,
sent_emb_name
};
std
::
vector
<
std
::
string
>
id_names
;
std
::
vector
<
std
::
string
>
emb_names
;
if
(
engine_
->
use_oss
())
{
id_names
=
std
::
vector
<
std
::
string
>
{
word_id_name
,
pos_id_name
,
sent_id_name
};
emb_names
=
std
::
vector
<
std
::
string
>
{
word_emb_name
,
pos_emb_name
,
sent_emb_name
};
}
else
{
id_names
=
op_desc
.
Input
(
"Ids"
);
emb_names
=
op_desc
.
Input
(
"Embs"
);
}
int
input_num
=
id_names
.
size
();
...
...
paddle/fluid/inference/tensorrt/convert/gather_nd_op.cc
0 → 100644
浏览文件 @
d194bd3a
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/gather_nd_op_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
GatherNdOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
4
)
<<
"convert a paddle gather_nd op to tensorrt gather_nd plugin"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
std
::
vector
<
nvinfer1
::
ITensor
*>
inputs
;
auto
*
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
*
index
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Index"
)[
0
]);
inputs
.
emplace_back
(
input
);
inputs
.
emplace_back
(
index
);
nvinfer1
::
ILayer
*
layer
=
nullptr
;
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
plugin
::
GatherNdPluginDynamic
*
plugin
=
new
plugin
::
GatherNdPluginDynamic
(
with_fp16
);
layer
=
engine_
->
AddDynamicPlugin
(
inputs
.
data
(),
inputs
.
size
(),
plugin
);
std
::
string
layer_name
=
"gather_nd (Output: "
;
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
layer_name
+=
output_name
;
if
(
test_mode
)
{
engine_
->
DeclareOutput
(
output_name
);
}
layer
->
setName
((
layer_name
+
")"
).
c_str
());
}
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
gather_nd
,
GatherNdOpConverter
);
paddle/fluid/inference/tensorrt/convert/reduce_op.cc
0 → 100644
浏览文件 @
d194bd3a
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <NvInfer.h>
#include <sys/types.h>
#include <cstddef>
#include <cstdint>
#include <vector>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
namespace
proto
{
class
OpDesc
;
}
// namespace proto
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
ReduceSumOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
4
)
<<
"convert a paddle reduce_sum op to tensorrt reduce layer"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
auto
*
x
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
).
front
());
nvinfer1
::
Dims
input_shape
=
x
->
getDimensions
();
int
input_dims
=
input_shape
.
nbDims
;
bool
keep_dim
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"keep_dim"
));
std
::
vector
<
int32_t
>
dim
=
BOOST_GET_CONST
(
std
::
vector
<
int32_t
>
,
op_desc
.
GetAttr
(
"dim"
));
bool
reduce_all
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"reduce_all"
));
// Now we only support dynamic_shape mode.
nvinfer1
::
IReduceLayer
*
layer
=
nullptr
;
if
(
reduce_all
)
{
uint32_t
reduce_dim
=
0
;
for
(
int
i
=
0
;
i
<
input_dims
;
++
i
)
{
reduce_dim
|=
1
<<
i
;
}
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Reduce
,
*
x
,
nvinfer1
::
ReduceOperation
::
kSUM
,
reduce_dim
,
keep_dim
);
}
else
{
auto
CvtToBitMask
=
[
&
](
const
std
::
vector
<
int32_t
>&
dims
)
->
uint32_t
{
uint32_t
res
=
0
;
for
(
auto
x
:
dims
)
{
if
(
x
<
0
)
{
res
|=
1
<<
(
x
+
input_dims
);
}
else
{
res
|=
1
<<
x
;
}
}
return
res
;
};
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Reduce
,
*
x
,
nvinfer1
::
ReduceOperation
::
kSUM
,
CvtToBitMask
(
dim
),
keep_dim
);
}
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
RreplenishLayerAndOutput
(
layer
,
"reduce_sum"
,
{
output_name
},
test_mode
);
}
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
reduce_sum
,
ReduceSumOpConverter
);
paddle/fluid/inference/tensorrt/convert/reshape_op.cc
浏览文件 @
d194bd3a
...
...
@@ -34,7 +34,7 @@ class ReshapeOpConverter : public OpConverter {
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
auto
*
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
const
std
::
vector
<
int
>&
shape
=
std
::
vector
<
int
>
shape
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"shape"
));
int
nbDims_num
=
shape
.
size
();
nvinfer1
::
Dims
reshape_dim
;
...
...
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
d194bd3a
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/op_teller.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/data_layout.h"
...
...
@@ -122,11 +123,13 @@ struct SimpleOpTypeSetTeller : public Teller {
"flatten2"
,
"flatten"
,
"gather"
,
"gather_nd"
,
"yolo_box"
,
"roi_align"
,
"affine_channel"
,
"nearest_interp"
,
"anchor_generator"
,
"reduce_sum"
,
};
};
...
...
@@ -324,6 +327,30 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if
(
!
with_dynamic_shape
||
desc
.
Input
(
"Axis"
).
size
()
>
0
)
return
false
;
}
if
(
op_type
==
"gather_nd"
)
{
auto
*
block
=
desc
.
Block
();
auto
x_var_name
=
desc
.
Input
(
"X"
)[
0
];
auto
index_var_name
=
desc
.
Input
(
"Index"
)[
0
];
auto
*
x_var_desc
=
block
->
FindVar
(
x_var_name
);
auto
*
index_var_desc
=
block
->
FindVar
(
index_var_name
);
// The index input must be int32 datatype.
if
(
index_var_desc
->
GetDataType
()
!=
paddle
::
framework
::
proto
::
VarType_Type
::
VarType_Type_INT32
)
{
VLOG
(
3
)
<<
"gather_nd op Index input data type must be int32"
;
return
false
;
}
const
auto
index_shape
=
index_var_desc
->
GetShape
();
const
auto
x_shape
=
x_var_desc
->
GetShape
();
if
(
x_shape
.
size
()
!=
index_shape
.
size
())
{
VLOG
(
3
)
<<
"gather_nd op Index input dims size ["
<<
index_shape
.
size
()
<<
" ] not equal to x dims size ["
<<
x_shape
.
size
()
<<
"]"
;
return
false
;
}
if
(
!
with_dynamic_shape
)
return
false
;
}
if
(
op_type
==
"yolo_box"
)
{
if
(
with_dynamic_shape
)
return
false
;
bool
has_attrs
=
...
...
@@ -684,6 +711,21 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if
(
shape
.
size
()
>=
nvinfer1
::
Dims
::
MAX_DIMS
)
return
false
;
}
}
if
(
op_type
==
"reduce_sum"
)
{
if
(
!
with_dynamic_shape
)
{
VLOG
(
3
)
<<
"the reduce_sum does not support static shape yet"
;
return
false
;
}
if
(
!
(
desc
.
HasAttr
(
"keep_dim"
)
&&
desc
.
HasAttr
(
"dim"
)
&&
desc
.
HasAttr
(
"reduce_all"
)))
{
VLOG
(
3
)
<<
"the reduce_sum does not have attr (keep_dim or dim or "
"reduce_all)"
;
return
false
;
}
}
if
((
*
teller
)(
op_type
,
desc
,
use_no_calib_int8
))
return
true
;
}
return
false
;
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
d194bd3a
...
...
@@ -8,6 +8,7 @@ nv_library(tensorrt_plugin
anchor_generator_op_plugin.cu
yolo_box_op_plugin.cu
roi_align_op_plugin.cu
gather_nd_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor
)
nv_test
(
test_split_plugin SRCS test_split_plugin.cc DEPS
...
...
paddle/fluid/inference/tensorrt/plugin/gather_nd_op_plugin.cu
0 → 100644
浏览文件 @
d194bd3a
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_fp16.h>
#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
#include <sstream>
#include "NvInferRuntimeCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/gather_nd_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
#if IS_TRT_VERSION_GE(6000)
template
<
typename
T
,
typename
IndexT
=
int
>
__global__
void
GatherNdCUDAKernel
(
const
T
*
input
,
const
int32_t
*
input_dims
,
const
IndexT
*
indices
,
T
*
output
,
int32_t
remain_size
,
int32_t
slice_size
,
int32_t
end_size
)
{
CUDA_KERNEL_LOOP
(
i
,
remain_size
*
slice_size
)
{
int
indices_i
=
i
/
slice_size
;
int
slice_i
=
i
-
indices_i
*
slice_size
;
// offset inside the slice
IndexT
gather_i
=
0
;
int32_t
temp
=
slice_size
;
for
(
int32_t
j
=
end_size
-
1
;
j
>=
0
;
--
j
)
{
auto
index_value
=
indices
[
indices_i
*
end_size
+
j
];
PADDLE_ENFORCE
(
index_value
>=
0
&&
index_value
<
input_dims
[
j
],
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d] and greater or equal to 0, but received [%d]"
,
input_dims
[
j
],
index_value
);
gather_i
+=
(
index_value
*
temp
);
temp
*=
input_dims
[
j
];
}
IndexT
input_i
=
gather_i
+
slice_i
;
*
(
output
+
i
)
=
*
(
input
+
input_i
);
}
}
int
GatherNdPluginDynamic
::
initialize
()
{
return
0
;
}
size_t
GatherNdPluginDynamic
::
getSerializationSize
()
const
{
return
SerializedSize
(
with_fp16_
);
}
void
GatherNdPluginDynamic
::
serialize
(
void
*
buffer
)
const
{
SerializeValue
(
&
buffer
,
with_fp16_
);
}
nvinfer1
::
DimsExprs
GatherNdPluginDynamic
::
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
{
PADDLE_ENFORCE_EQ
(
nb_inputs
,
2
,
platform
::
errors
::
InvalidArgument
(
"The gather_nd plugin should have 2 input, but got %d."
,
nb_inputs
));
PADDLE_ENFORCE_EQ
(
output_index
,
0
,
platform
::
errors
::
InvalidArgument
(
"When GetOutputDimensions in gather_nd "
"plugin, the output_index should be 0."
));
nvinfer1
::
DimsExprs
x_dims
=
inputs
[
0
];
nvinfer1
::
DimsExprs
index_dims
=
inputs
[
1
];
int32_t
x_dims_size
=
x_dims
.
nbDims
;
int32_t
index_dims_size
=
index_dims
.
nbDims
;
// TODO(wilber): The result dims shoule be Index.shape[:-1] +
// X.shape[Index.shape[-1]:], but the trt DimsExprs is an expression we can't
// get the actual value. So we only support one scenario: input_dims.size ==
// index_dims.size.
nvinfer1
::
DimsExprs
ret
(
x_dims
);
for
(
int
i
=
0
;
i
<
index_dims_size
-
1
;
++
i
)
{
ret
.
d
[
i
]
=
index_dims
.
d
[
i
];
}
return
ret
;
}
bool
GatherNdPluginDynamic
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
in_out
,
int
nb_inputs
,
int
nb_outputs
)
{
PADDLE_ENFORCE_NOT_NULL
(
in_out
,
platform
::
errors
::
InvalidArgument
(
"The input of gather_nd plugin should not be nullptr."
));
PADDLE_ENFORCE_LT
(
pos
,
nb_inputs
+
nb_outputs
,
platform
::
errors
::
InvalidArgument
(
"The pos(%d) should be less than the "
"num(%d) of the input and the output."
,
pos
,
nb_inputs
+
nb_outputs
));
(
in_out
&&
pos
<
(
nb_inputs
+
nb_outputs
));
const
nvinfer1
::
PluginTensorDesc
&
in
=
in_out
[
pos
];
if
(
pos
==
0
)
{
if
(
with_fp16_
)
{
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
||
in
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
}
else
{
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
}
}
else
if
(
pos
==
1
)
{
return
in
.
type
==
nvinfer1
::
DataType
::
kINT32
&&
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
}
else
if
(
pos
==
2
)
{
return
in
.
type
==
in_out
[
0
].
type
&&
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
}
return
true
;
}
nvinfer1
::
DataType
GatherNdPluginDynamic
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
{
return
input_types
[
0
];
}
int
GatherNdPluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
auto
input_dims
=
input_desc
[
0
].
dims
;
auto
index_dims
=
input_desc
[
1
].
dims
;
auto
input_dims_size
=
input_dims
.
nbDims
;
auto
index_dims_size
=
index_dims
.
nbDims
;
std
::
vector
<
int32_t
>
input_shape
,
index_shape
,
out_shape
;
for
(
int
i
=
0
;
i
<
input_dims
.
nbDims
;
i
++
)
input_shape
.
push_back
(
input_dims
.
d
[
i
]);
for
(
int
i
=
0
;
i
<
index_dims
.
nbDims
;
i
++
)
index_shape
.
push_back
(
index_dims
.
d
[
i
]);
// The out_shape is
// Index.shape[:-1] + X.shape[Index.shape[-1]:]
for
(
int
i
=
0
;
i
<
index_dims_size
-
1
;
++
i
)
{
out_shape
.
emplace_back
(
index_shape
[
i
]);
}
for
(
int
i
=
index_shape
[
index_dims_size
-
1
];
i
<
input_dims_size
;
++
i
)
{
out_shape
.
emplace_back
(
input_shape
[
i
]);
}
// final dim
int
end_size
=
index_shape
[
index_dims_size
-
1
];
// remain dim
std
::
vector
<
int
>
remain_ddim
(
index_shape
.
begin
(),
index_shape
.
end
()
-
1
);
int
remain_numel
=
std
::
accumulate
(
remain_ddim
.
begin
(),
remain_ddim
.
end
(),
1
,
std
::
multiplies
<
int
>
());
// slice size
int
slice_size
=
1
;
for
(
int
i
=
end_size
;
i
<
input_dims_size
;
++
i
)
{
slice_size
*=
input_shape
[
i
];
}
auto
input_type
=
input_desc
[
0
].
type
;
if
(
input_type
==
nvinfer1
::
DataType
::
kFLOAT
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. gather_nd-->fp32"
;
const
float
*
p_input
=
static_cast
<
const
float
*>
(
inputs
[
0
]);
const
int32_t
*
p_index
=
static_cast
<
const
int32_t
*>
(
inputs
[
1
]);
float
*
p_output
=
static_cast
<
float
*>
(
outputs
[
0
]);
if
(
input_dims_data_
==
nullptr
)
{
cudaMalloc
(
&
input_dims_data_
,
input_shape
.
size
()
*
sizeof
(
int
));
}
cudaMemcpyAsync
(
input_dims_data_
,
input_shape
.
data
(),
sizeof
(
int
)
*
input_shape
.
size
(),
cudaMemcpyHostToDevice
,
stream
);
int
block
=
512
;
int
n
=
slice_size
*
remain_numel
;
int
grid
=
(
n
+
block
-
1
)
/
block
;
GatherNdCUDAKernel
<
float
,
int32_t
><<<
grid
,
block
,
0
,
stream
>>>
(
p_input
,
input_dims_data_
,
p_index
,
p_output
,
remain_numel
,
slice_size
,
end_size
);
}
else
if
(
input_type
==
nvinfer1
::
DataType
::
kHALF
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. gather_nd-->fp16"
;
const
half
*
p_input
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
const
int32_t
*
p_index
=
static_cast
<
const
int32_t
*>
(
inputs
[
1
]);
half
*
p_output
=
static_cast
<
half
*>
(
outputs
[
0
]);
if
(
input_dims_data_
==
nullptr
)
{
cudaMalloc
(
&
input_dims_data_
,
input_shape
.
size
()
*
sizeof
(
int
));
}
cudaMemcpyAsync
(
input_dims_data_
,
input_shape
.
data
(),
sizeof
(
int
)
*
input_shape
.
size
(),
cudaMemcpyHostToDevice
,
stream
);
int
block
=
512
;
int
n
=
slice_size
*
remain_numel
;
int
grid
=
(
n
+
block
-
1
)
/
block
;
GatherNdCUDAKernel
<
half
,
int32_t
><<<
grid
,
block
,
0
,
stream
>>>
(
p_input
,
input_dims_data_
,
p_index
,
p_output
,
remain_numel
,
slice_size
,
end_size
);
}
return
cudaGetLastError
()
!=
cudaSuccess
;
}
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/gather_nd_op_plugin.h
0 → 100644
浏览文件 @
d194bd3a
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <thrust/device_vector.h>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
#if IS_TRT_VERSION_GE(6000)
class
GatherNdPluginDynamic
:
public
DynamicPluginTensorRT
{
public:
explicit
GatherNdPluginDynamic
(
bool
with_fp16
)
{
with_fp16_
=
with_fp16
;
}
GatherNdPluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
)
{
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
with_fp16_
);
}
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
override
{
return
new
GatherNdPluginDynamic
(
with_fp16_
);
}
const
char
*
getPluginType
()
const
override
{
return
"gather_nd_plugin"
;
}
int
getNbOutputs
()
const
override
{
return
1
;
}
int
initialize
()
override
;
size_t
getSerializationSize
()
const
override
;
void
serialize
(
void
*
buffer
)
const
override
;
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
outputIndex
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
override
;
bool
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
inOut
,
int
nbInputs
,
int
nbOutputs
)
override
;
void
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
override
{}
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
override
{
return
0
;
}
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
override
;
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
inputTypes
,
int
nbInputs
)
const
override
;
void
destroy
()
override
{
if
(
input_dims_data_
)
{
cudaFree
(
input_dims_data_
);
}
delete
this
;
}
private:
int32_t
*
input_dims_data_
{
nullptr
};
};
class
GatherNdPluginDynamicCreator
:
public
nvinfer1
::
IPluginCreator
{
public:
GatherNdPluginDynamicCreator
()
{}
const
char
*
getPluginName
()
const
override
{
return
"gather_nd_plugin"
;
}
const
char
*
getPluginVersion
()
const
override
{
return
"1"
;
}
const
nvinfer1
::
PluginFieldCollection
*
getFieldNames
()
override
{
return
&
field_collection_
;
}
nvinfer1
::
IPluginV2
*
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
override
{
return
nullptr
;
}
nvinfer1
::
IPluginV2
*
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
override
{
auto
plugin
=
new
GatherNdPluginDynamic
(
serial_data
,
serial_length
);
return
plugin
;
}
void
setPluginNamespace
(
const
char
*
lib_namespace
)
override
{
plugin_namespace_
=
lib_namespace
;
}
const
char
*
getPluginNamespace
()
const
override
{
return
plugin_namespace_
.
c_str
();
}
private:
std
::
string
plugin_namespace_
;
std
::
string
plugin_name_
;
nvinfer1
::
PluginFieldCollection
field_collection_
{
0
,
nullptr
};
std
::
vector
<
nvinfer1
::
PluginField
>
plugin_attributes_
;
};
REGISTER_TRT_PLUGIN_V2
(
GatherNdPluginDynamicCreator
);
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/operators/math/bert_encoder_functor.cu
浏览文件 @
d194bd3a
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
...
...
@@ -311,6 +312,156 @@ __global__ void SoftmaxKernelWithEltadd2<half2>(
#endif
}
template
<
typename
T
>
__global__
void
SoftmaxKernelWithEltaddForLarge
(
T
*
qk_buf
,
const
T
*
bias_qk
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
unsigned
mask
)
{
int
qk_offset
=
blockIdx
.
x
*
seq_len
;
assert
(
blockDim
.
x
%
32
==
0
);
T
stride_max
=
-
1e20
f
;
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
stride_max
=
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
]
>
stride_max
?
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
]
:
stride_max
;
}
T
max_val
=
blockReduceMax
<
T
>
(
stride_max
,
mask
);
T
stride_sum
=
0.
f
;
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
stride_sum
+=
__expf
(
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
]
-
max_val
);
}
T
sum_val
=
blockReduceSum
<
T
>
(
stride_sum
,
mask
);
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
=
(
T
)(
__expf
(
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
]
-
max_val
)
/
sum_val
);
}
}
// HIP defined __HIP_NO_HALF_CONVERSIONS__
#ifndef __HIPCC__ // @{ Half kernel: SoftmaxKernelWithEltadd
template
<
>
__global__
void
SoftmaxKernelWithEltaddForLarge
(
half
*
qk_buf
,
const
half
*
bias_qk
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
unsigned
mask
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int
qk_offset
=
blockIdx
.
x
*
seq_len
;
assert
(
blockDim
.
x
%
32
==
0
);
float
stride_max
=
-
1e20
f
;
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float
tmp
=
static_cast
<
float
>
(
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
]);
stride_max
=
tmp
>
stride_max
?
tmp
:
stride_max
;
}
float
max_val
=
blockReduceMax
<
float
>
(
stride_max
,
mask
);
float
stride_sum
=
0.
f
;
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float
tmp
=
static_cast
<
float
>
(
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
]);
stride_sum
+=
__expf
(
tmp
-
max_val
);
}
float
sum_val
=
blockReduceSum
<
float
>
(
stride_sum
,
mask
);
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float
tmp
=
__expf
(
static_cast
<
float
>
(
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
])
-
max_val
);
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
=
(
half
)(
tmp
/
sum_val
);
}
#endif
}
#endif // @} End Half kernel: SoftmaxKernelWithEltadd
template
<
typename
T
>
__global__
void
SoftmaxKernelWithEltaddForLarge2
(
T
*
qk_buf_
,
const
T
*
bias_qk_
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
unsigned
mask
)
{
int
qk_offset
=
blockIdx
.
x
*
seq_len
;
assert
(
blockDim
.
x
%
32
==
0
);
float2
stride_max
=
make_float2
(
-
1e20
f
,
-
1e20
f
);
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float2
cur
=
ToFloat2
<
T
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
stride_max
.
x
=
max
(
stride_max
.
x
,
cur
.
x
);
stride_max
.
y
=
max
(
stride_max
.
y
,
cur
.
y
);
}
float
max_val
=
blockReduceMax
<
float
>
(
max
(
stride_max
.
x
,
stride_max
.
y
),
mask
);
float2
stride_sum
=
make_float2
(
0.
f
,
0.
f
);
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float2
cur
=
ToFloat2
<
T
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
stride_sum
.
x
+=
__expf
(
cur
.
x
-
max_val
);
stride_sum
.
y
+=
__expf
(
cur
.
y
-
max_val
);
}
float
sum_val
=
blockReduceSum
<
float
>
(
stride_sum
.
x
+
stride_sum
.
y
,
mask
)
+
1e-6
f
;
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float2
cur
=
ToFloat2
<
T
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
=
FloatsToPair
<
T
>
(
__expf
(
cur
.
x
-
max_val
)
/
sum_val
,
__expf
(
cur
.
y
-
max_val
)
/
sum_val
);
}
}
template
<
>
__global__
void
SoftmaxKernelWithEltaddForLarge2
(
half2
*
qk_buf_
,
const
half2
*
bias_qk_
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
unsigned
mask
)
{
// operator "+" of half only suppotted after cuda version 10.0
// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#if defined(PADDLE_WITH_CUDA) && \
(CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000)
int
qk_offset
=
blockIdx
.
x
*
seq_len
;
assert
(
blockDim
.
x
%
32
==
0
);
float2
stride_max
=
make_float2
(
-
1e20
f
,
-
1e20
f
);
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float2
cur
=
ToFloat2
<
half2
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
stride_max
.
x
=
max
(
stride_max
.
x
,
cur
.
x
);
stride_max
.
y
=
max
(
stride_max
.
y
,
cur
.
y
);
}
float
max_val
=
blockReduceMax
<
float
>
(
max
(
stride_max
.
x
,
stride_max
.
y
),
mask
);
float2
stride_sum
=
make_float2
(
0.
f
,
0.
f
);
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float2
cur
=
ToFloat2
<
half2
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
stride_sum
.
x
+=
__expf
(
cur
.
x
-
max_val
);
stride_sum
.
y
+=
__expf
(
cur
.
y
-
max_val
);
}
float
sum_val
=
blockReduceSum
<
float
>
(
stride_sum
.
x
+
stride_sum
.
y
,
mask
)
+
1e-6
f
;
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float2
cur
=
ToFloat2
<
half2
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
=
FloatsToPair
<
half2
>
(
__expf
(
cur
.
x
-
max_val
)
/
sum_val
,
__expf
(
cur
.
y
-
max_val
)
/
sum_val
);
}
#endif
}
template
<
typename
T
>
inline
void
MatMulWithHeadQK
(
const
platform
::
CUDADeviceContext
&
context
,
int
head_num
,
int
seq_len
,
int
size_per_head
,
...
...
@@ -332,14 +483,11 @@ inline void MatMulWithHeadQK(const platform::CUDADeviceContext &context,
reinterpret_cast
<
run_type
*>
(
qk_buf_
),
batch_size
*
head_num
,
seq_len
*
size_per_head
,
seq_len
*
size_per_head
);
if
(
seq_len
<=
1024
)
{
int
grid
=
batch_size
*
head_num
*
seq_len
;
int
block
=
seq_len
;
// Align block to 32, also limit seq_len to max block size.
PADDLE_ENFORCE_LE
(
seq_len
,
1024
,
platform
::
errors
::
InvalidArgument
(
"seq_len should <= 1024, "
"but received seq_len is:%d"
,
seq_len
));
if
(
seq_len
%
2
==
0
)
{
block
=
(
seq_len
<=
64
)
?
32
:
((
seq_len
+
63
)
/
64
)
*
32
;
if
(
std
::
is_same
<
T
,
float
>::
value
)
{
...
...
@@ -358,6 +506,26 @@ inline void MatMulWithHeadQK(const platform::CUDADeviceContext &context,
SoftmaxKernelWithEltadd
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
qk_buf_
,
bias_qk
,
batch_size
,
head_num
,
seq_len
,
FINAL_MASK
);
}
}
else
{
int
grid
=
batch_size
*
head_num
*
seq_len
;
int
block
=
512
;
if
(
seq_len
%
2
==
0
)
{
if
(
std
::
is_same
<
T
,
float
>::
value
)
{
SoftmaxKernelWithEltaddForLarge2
<
float2
><<<
grid
,
block
,
0
,
stream
>>>
(
reinterpret_cast
<
float2
*>
(
qk_buf_
),
reinterpret_cast
<
const
float2
*>
(
bias_qk
),
batch_size
,
head_num
,
seq_len
/
2
,
FINAL_MASK
);
}
else
{
SoftmaxKernelWithEltaddForLarge2
<
__half2
><<<
grid
,
block
,
0
,
stream
>>>
(
reinterpret_cast
<
__half2
*>
(
qk_buf_
),
reinterpret_cast
<
const
__half2
*>
(
bias_qk
),
batch_size
,
head_num
,
seq_len
/
2
,
FINAL_MASK
);
}
}
else
{
SoftmaxKernelWithEltaddForLarge
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
qk_buf_
,
bias_qk
,
batch_size
,
head_num
,
seq_len
,
FINAL_MASK
);
}
}
}
template
<
typename
T
>
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_gather_nd_op.py
0 → 100644
浏览文件 @
d194bd3a
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
inference_pass_test
import
InferencePassTest
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.core
import
PassVersionChecker
from
paddle.fluid.core
import
AnalysisConfig
class
TRTGatherNdTest
(
InferencePassTest
):
def
setUp
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
3
,
4
],
dtype
=
"float32"
)
index
=
fluid
.
data
(
name
=
"index"
,
shape
=
[
-
1
,
2
,
2
],
dtype
=
"int32"
)
gather_nd
=
fluid
.
layers
.
gather_nd
(
data
,
index
)
out
=
fluid
.
layers
.
batch_norm
(
gather_nd
,
is_test
=
True
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
([
2
,
3
,
4
]).
astype
(
"float32"
),
"index"
:
np
.
array
([[[
0
,
1
],
[
1
,
0
]],
[[
1
,
2
],
[
0
,
1
]]]).
astype
(
"int32"
),
}
self
.
enable_trt
=
True
self
.
trt_parameters
=
TRTGatherNdTest
.
TensorRTParam
(
1
<<
30
,
32
,
1
,
AnalysisConfig
.
Precision
.
Float32
,
False
,
False
)
self
.
fetch_list
=
[
out
]
self
.
dynamic_shape_params
=
TRTGatherNdTest
.
DynamicShapeParam
({
'data'
:
[
1
,
3
,
4
],
'index'
:
[
1
,
2
,
2
]
},
{
'data'
:
[
3
,
3
,
4
],
'index'
:
[
3
,
2
,
2
]},
{
'data'
:
[
3
,
3
,
4
],
'index'
:
[
3
,
2
,
2
]},
False
)
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
use_gpu
=
True
self
.
check_output_with_option
(
use_gpu
,
flatten
=
True
)
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
'tensorrt_subgraph_pass'
))
class
TRTGatherNdFp16Test
(
InferencePassTest
):
def
setUp
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
5120
,
768
],
dtype
=
"float32"
)
index
=
fluid
.
data
(
name
=
"index"
,
shape
=
[
-
1
,
4096
,
2
],
dtype
=
"int32"
)
gather_nd
=
fluid
.
layers
.
gather_nd
(
data
,
index
)
out
=
fluid
.
layers
.
batch_norm
(
gather_nd
,
is_test
=
True
)
index_data
=
np
.
zeros
((
1
,
4096
,
2
),
dtype
=
'int32'
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
([
1
,
5120
,
768
]).
astype
(
"float32"
),
"index"
:
index_data
,
}
self
.
enable_trt
=
True
self
.
trt_parameters
=
TRTGatherNdFp16Test
.
TensorRTParam
(
1
<<
30
,
32
,
1
,
AnalysisConfig
.
Precision
.
Half
,
False
,
False
)
self
.
fetch_list
=
[
out
]
self
.
dynamic_shape_params
=
TRTGatherNdFp16Test
.
DynamicShapeParam
({
'data'
:
[
1
,
5120
,
768
],
'index'
:
[
1
,
4096
,
2
]
},
{
'data'
:
[
3
,
5120
,
768
],
'index'
:
[
3
,
4096
,
2
]},
{
'data'
:
[
3
,
5120
,
768
],
'index'
:
[
3
,
4096
,
2
]},
False
)
def
test_check_output
(
self
,
atol
=
1e-3
):
if
core
.
is_compiled_with_cuda
():
use_gpu
=
True
self
.
check_output_with_option
(
use_gpu
,
flatten
=
True
)
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
'tensorrt_subgraph_pass'
))
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/ir/inference/test_trt_reduce_sum_op.py
0 → 100644
浏览文件 @
d194bd3a
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
inference_pass_test
import
InferencePassTest
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.core
import
PassVersionChecker
from
paddle.fluid.core
import
AnalysisConfig
class
TRTReduceSumTest
(
InferencePassTest
):
def
setUp
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
3
,
10
,
768
],
dtype
=
"float32"
)
reduce_sum
=
fluid
.
layers
.
reduce_sum
(
data
,
dim
=
[
2
,
-
1
],
keep_dim
=
True
)
out
=
fluid
.
layers
.
batch_norm
(
reduce_sum
,
is_test
=
True
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
([
3
,
3
,
10
,
768
]).
astype
(
"float32"
),
}
self
.
enable_trt
=
True
self
.
trt_parameters
=
TRTReduceSumTest
.
TensorRTParam
(
1
<<
30
,
32
,
1
,
AnalysisConfig
.
Precision
.
Float32
,
False
,
False
)
self
.
fetch_list
=
[
out
]
self
.
dynamic_shape_params
=
TRTReduceSumTest
.
DynamicShapeParam
({
'data'
:
[
1
,
3
,
8
,
8
]
},
{
'data'
:
[
3
,
3
,
10
,
768
]},
{
'data'
:
[
3
,
3
,
10
,
768
]},
False
)
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
use_gpu
=
True
self
.
check_output_with_option
(
use_gpu
,
flatten
=
True
)
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
'tensorrt_subgraph_pass'
))
class
TRTReduceSumAllTest
(
InferencePassTest
):
def
setUp
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
3
,
10
,
768
],
dtype
=
"float32"
)
reduce_sum
=
fluid
.
layers
.
reduce_sum
(
data
,
keep_dim
=
True
)
out
=
fluid
.
layers
.
batch_norm
(
reduce_sum
,
is_test
=
True
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
([
3
,
3
,
10
,
768
]).
astype
(
"float32"
),
}
self
.
enable_trt
=
True
self
.
trt_parameters
=
TRTReduceSumAllTest
.
TensorRTParam
(
1
<<
30
,
32
,
1
,
AnalysisConfig
.
Precision
.
Float32
,
False
,
False
)
self
.
fetch_list
=
[
out
]
self
.
dynamic_shape_params
=
TRTReduceSumAllTest
.
DynamicShapeParam
({
'data'
:
[
1
,
3
,
8
,
8
]
},
{
'data'
:
[
3
,
3
,
10
,
768
]},
{
'data'
:
[
3
,
3
,
10
,
768
]},
False
)
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
use_gpu
=
True
self
.
check_output_with_option
(
use_gpu
,
flatten
=
True
)
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
'tensorrt_subgraph_pass'
))
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录