Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
758fccfe
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
758fccfe
编写于
12月 01, 2022
作者:
Z
Zhang Jun
提交者:
GitHub
12月 01, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[inference][trt] dynamic shape support for Instance norm (#47998)
* instance norm support dynamic shape * update unittest
上级
1b1d6d3f
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
258 addition
and
11 deletion
+258
-11
paddle/fluid/inference/tensorrt/convert/instance_norm_op.cc
paddle/fluid/inference/tensorrt/convert/instance_norm_op.cc
+10
-4
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+0
-4
paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.cu
...luid/inference/tensorrt/plugin/instance_norm_op_plugin.cu
+109
-0
paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h
...fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h
+132
-2
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_instance_norm.py
.../unittests/ir/inference/test_trt_convert_instance_norm.py
+7
-1
未找到文件。
paddle/fluid/inference/tensorrt/convert/instance_norm_op.cc
浏览文件 @
758fccfe
...
@@ -74,10 +74,16 @@ class InstanceNormOpConverter : public OpConverter {
...
@@ -74,10 +74,16 @@ class InstanceNormOpConverter : public OpConverter {
bias_v
.
push_back
(
bias_d
[
i
]);
bias_v
.
push_back
(
bias_d
[
i
]);
}
}
plugin
::
InstanceNormPlugin
*
plugin
=
nvinfer1
::
IPluginV2
*
plugin
=
nullptr
;
new
plugin
::
InstanceNormPlugin
(
eps
,
scale_v
,
bias_v
);
if
(
engine_
->
with_dynamic_shape
())
{
plugin
->
getPluginType
();
plugin
=
new
plugin
::
InstanceNormPluginDynamic
(
eps
,
scale_v
,
bias_v
);
auto
*
layer
=
engine_
->
AddPlugin
(
&
input
,
1
,
plugin
);
}
else
{
plugin
=
new
plugin
::
InstanceNormPlugin
(
eps
,
scale_v
,
bias_v
);
}
std
::
vector
<
nvinfer1
::
ITensor
*>
instance_norm_inputs
{
input
};
auto
*
layer
=
engine_
->
network
()
->
addPluginV2
(
instance_norm_inputs
.
data
(),
instance_norm_inputs
.
size
(),
*
plugin
);
auto
output_name
=
op_desc
.
Output
(
"Y"
)[
0
];
auto
output_name
=
op_desc
.
Output
(
"Y"
)[
0
];
RreplenishLayerAndOutput
(
layer
,
"instance_norm"
,
{
output_name
},
test_mode
);
RreplenishLayerAndOutput
(
layer
,
"instance_norm"
,
{
output_name
},
test_mode
);
...
...
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
758fccfe
...
@@ -1501,10 +1501,6 @@ struct SimpleOpTypeSetTeller : public Teller {
...
@@ -1501,10 +1501,6 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}
if
(
op_type
==
"instance_norm"
)
{
if
(
op_type
==
"instance_norm"
)
{
if
(
with_dynamic_shape
)
{
VLOG
(
3
)
<<
"trt instance_norm op does not support dynamic shape "
;
return
false
;
}
if
(
desc
.
Input
(
"X"
).
size
()
!=
1
)
{
if
(
desc
.
Input
(
"X"
).
size
()
!=
1
)
{
VLOG
(
3
)
<<
"input of instance_norm op converter should be 1, got "
VLOG
(
3
)
<<
"input of instance_norm op converter should be 1, got "
<<
desc
.
Input
(
"X"
).
size
();
<<
desc
.
Input
(
"X"
).
size
();
...
...
paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.cu
浏览文件 @
758fccfe
...
@@ -131,6 +131,115 @@ int InstanceNormPlugin::enqueue(int batch_size,
...
@@ -131,6 +131,115 @@ int InstanceNormPlugin::enqueue(int batch_size,
return
cudaGetLastError
()
!=
cudaSuccess
;
return
cudaGetLastError
()
!=
cudaSuccess
;
}
}
int
InstanceNormPluginDynamic
::
initialize
()
TRT_NOEXCEPT
{
return
0
;
}
nvinfer1
::
DimsExprs
InstanceNormPluginDynamic
::
getOutputDimensions
(
int
index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nbInputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
TRT_NOEXCEPT
{
assert
(
nbInputs
==
1
);
assert
(
index
<
this
->
getNbOutputs
());
nvinfer1
::
DimsExprs
output
(
inputs
[
0
]);
return
output
;
}
bool
InstanceNormPluginDynamic
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
inOut
,
int
nbInputs
,
int
nbOutputs
)
TRT_NOEXCEPT
{
assert
(
inOut
&&
pos
<
(
nbInputs
+
nbOutputs
));
assert
(
pos
==
0
||
pos
==
1
);
return
((
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kFLOAT
||
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
(
inOut
[
pos
].
format
==
nvinfer1
::
PluginFormat
::
kLINEAR
)
&&
inOut
[
pos
].
type
==
inOut
[
0
].
type
);
}
int
InstanceNormPluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
nvinfer1
::
Dims
input_dims
=
inputDesc
[
0
].
dims
;
int
n
=
input_dims
.
d
[
0
];
int
c
=
input_dims
.
d
[
1
];
int
h
=
input_dims
.
d
[
2
];
int
w
=
input_dims
.
d
[
3
];
scale_t
.
Resize
(
phi
::
make_ddim
({
n
,
c
}));
bias_t
.
Resize
(
phi
::
make_ddim
({
n
,
c
}));
int
device_id
;
cudaGetDevice
(
&
device_id
);
float
*
scale_d
=
scale_t
.
mutable_data
<
float
>
(
platform
::
CUDAPlace
(
device_id
));
float
*
bias_d
=
bias_t
.
mutable_data
<
float
>
(
platform
::
CUDAPlace
(
device_id
));
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
cudaMemcpyAsync
(
scale_d
+
i
*
c
,
scale_
.
data
(),
sizeof
(
float
)
*
c
,
cudaMemcpyHostToDevice
,
stream
);
cudaMemcpyAsync
(
bias_d
+
i
*
c
,
bias_
.
data
(),
sizeof
(
float
)
*
c
,
cudaMemcpyHostToDevice
,
stream
);
}
platform
::
dynload
::
cudnnSetTensor4dDescriptor
(
b_desc_
,
CUDNN_TENSOR_NCHW
,
CUDNN_DATA_FLOAT
,
1
,
n
*
c
,
1
,
1
);
cudnnDataType_t
cudnn_dtype
;
auto
data_type
=
inputDesc
[
0
].
type
;
convert_trt2cudnn_dtype
(
data_type
,
&
cudnn_dtype
);
platform
::
dynload
::
cudnnSetTensor4dDescriptor
(
x_desc_
,
CUDNN_TENSOR_NCHW
,
cudnn_dtype
,
1
,
n
*
c
,
h
,
w
);
platform
::
dynload
::
cudnnSetTensor4dDescriptor
(
y_desc_
,
CUDNN_TENSOR_NCHW
,
cudnn_dtype
,
1
,
n
*
c
,
h
,
w
);
float
alpha
=
1
;
float
beta
=
0
;
platform
::
dynload
::
cudnnSetStream
(
handle_
,
stream
);
void
const
*
x_ptr
=
inputs
[
0
];
void
*
y_ptr
=
outputs
[
0
];
platform
::
dynload
::
cudnnBatchNormalizationForwardTraining
(
handle_
,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT
,
&
alpha
,
&
beta
,
x_desc_
,
x_ptr
,
y_desc_
,
y_ptr
,
b_desc_
,
scale_d
,
bias_d
,
1.
,
nullptr
,
nullptr
,
eps_
,
nullptr
,
nullptr
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
nvinfer1
::
DataType
InstanceNormPluginDynamic
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
inputTypes
,
int
nbInputs
)
const
TRT_NOEXCEPT
{
assert
(
inputTypes
&&
nbInputs
>
0
&&
index
==
0
);
return
inputTypes
[
0
];
}
void
InstanceNormPluginDynamic
::
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
TRT_NOEXCEPT
{}
}
// namespace plugin
}
// namespace plugin
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
...
...
paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h
浏览文件 @
758fccfe
...
@@ -99,7 +99,7 @@ class InstanceNormPlugin : public PluginTensorRT {
...
@@ -99,7 +99,7 @@ class InstanceNormPlugin : public PluginTensorRT {
}
}
const
char
*
getPluginType
()
const
TRT_NOEXCEPT
override
{
const
char
*
getPluginType
()
const
TRT_NOEXCEPT
override
{
return
"instance_norm
_plugin
"
;
return
"instance_norm"
;
}
}
int
getNbOutputs
()
const
TRT_NOEXCEPT
override
{
return
1
;
}
int
getNbOutputs
()
const
TRT_NOEXCEPT
override
{
return
1
;
}
nvinfer1
::
Dims
getOutputDimensions
(
int
index
,
nvinfer1
::
Dims
getOutputDimensions
(
int
index
,
...
@@ -125,7 +125,7 @@ class InstanceNormPlugin : public PluginTensorRT {
...
@@ -125,7 +125,7 @@ class InstanceNormPlugin : public PluginTensorRT {
class
InstanceNormPluginCreator
:
public
TensorRTPluginCreator
{
class
InstanceNormPluginCreator
:
public
TensorRTPluginCreator
{
public:
public:
const
char
*
getPluginName
()
const
TRT_NOEXCEPT
override
{
const
char
*
getPluginName
()
const
TRT_NOEXCEPT
override
{
return
"instance_norm
_plugin
"
;
return
"instance_norm"
;
}
}
const
char
*
getPluginVersion
()
const
TRT_NOEXCEPT
override
{
return
"1"
;
}
const
char
*
getPluginVersion
()
const
TRT_NOEXCEPT
override
{
return
"1"
;
}
...
@@ -137,7 +137,137 @@ class InstanceNormPluginCreator : public TensorRTPluginCreator {
...
@@ -137,7 +137,137 @@ class InstanceNormPluginCreator : public TensorRTPluginCreator {
return
new
InstanceNormPlugin
(
serial_data
,
serial_length
);
return
new
InstanceNormPlugin
(
serial_data
,
serial_length
);
}
}
};
};
class
InstanceNormPluginDynamic
:
public
DynamicPluginTensorRT
{
private:
float
eps_
;
std
::
vector
<
float
>
scale_
;
std
::
vector
<
float
>
bias_
;
phi
::
DenseTensor
scale_t
;
phi
::
DenseTensor
bias_t
;
cudnnHandle_t
handle_
;
cudnnTensorDescriptor_t
x_desc_
,
y_desc_
,
b_desc_
;
public:
size_t
getSerializationSize
()
const
TRT_NOEXCEPT
override
{
return
SerializedSize
(
eps_
)
+
SerializedSize
(
scale_
)
+
SerializedSize
(
bias_
);
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
override
{
SerializeValue
(
&
buffer
,
eps_
);
SerializeValue
(
&
buffer
,
scale_
);
SerializeValue
(
&
buffer
,
bias_
);
}
explicit
InstanceNormPluginDynamic
(
const
float
eps
,
const
std
::
vector
<
float
>
scale
,
const
std
::
vector
<
float
>
bias
)
:
eps_
(
eps
),
scale_
(
scale
),
bias_
(
bias
)
{
PADDLE_ENFORCE_EQ
(
scale
.
size
(),
bias
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The instanceNorm's scale and bias should be the "
"same size. Got scale size = %d, but bias size = %d"
,
scale
.
size
(),
bias
.
size
()));
platform
::
dynload
::
cudnnCreate
(
&
handle_
);
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
x_desc_
);
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
y_desc_
);
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
b_desc_
);
}
// It was used for tensorrt deserialization.
// It should not be called by users.
InstanceNormPluginDynamic
(
void
const
*
serialData
,
size_t
serialLength
)
{
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
eps_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
scale_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
bias_
);
platform
::
dynload
::
cudnnCreate
(
&
handle_
);
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
x_desc_
);
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
y_desc_
);
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
b_desc_
);
}
~
InstanceNormPluginDynamic
()
{
platform
::
dynload
::
cudnnDestroy
(
handle_
);
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
x_desc_
);
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
y_desc_
);
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
b_desc_
);
}
int
initialize
()
TRT_NOEXCEPT
override
;
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
TRT_NOEXCEPT
override
{
return
new
InstanceNormPluginDynamic
(
eps_
,
scale_
,
bias_
);
}
const
char
*
getPluginType
()
const
TRT_NOEXCEPT
override
{
return
"instance_norm_dynamic"
;
}
int
getNbOutputs
()
const
TRT_NOEXCEPT
override
{
return
1
;
}
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
// NOLINT
TRT_NOEXCEPT
override
;
bool
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
inOut
,
int
nbInputs
,
int
nbOutputs
)
TRT_NOEXCEPT
override
;
void
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
TRT_NOEXCEPT
override
;
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
TRT_NOEXCEPT
override
{
return
0
;
}
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
override
;
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
inputTypes
,
int
nbInputs
)
const
TRT_NOEXCEPT
override
;
void
destroy
()
TRT_NOEXCEPT
override
{
delete
this
;
}
};
class
InstanceNormPluginDynamicCreator
:
public
TensorRTPluginCreator
{
public:
const
char
*
getPluginName
()
const
TRT_NOEXCEPT
override
{
return
"instance_norm_dynamic"
;
}
const
char
*
getPluginVersion
()
const
TRT_NOEXCEPT
override
{
return
"1"
;
}
nvinfer1
::
IPluginV2
*
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
TRT_NOEXCEPT
override
{
return
new
InstanceNormPluginDynamic
(
serial_data
,
serial_length
);
}
};
REGISTER_TRT_PLUGIN_V2
(
InstanceNormPluginCreator
);
REGISTER_TRT_PLUGIN_V2
(
InstanceNormPluginCreator
);
REGISTER_TRT_PLUGIN_V2
(
InstanceNormPluginDynamicCreator
);
}
// namespace plugin
}
// namespace plugin
}
// namespace tensorrt
}
// namespace tensorrt
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_instance_norm.py
浏览文件 @
758fccfe
...
@@ -50,7 +50,13 @@ class TrtConvertInstanceNormTest(TrtLayerAutoScanTest):
...
@@ -50,7 +50,13 @@ class TrtConvertInstanceNormTest(TrtLayerAutoScanTest):
[
batch
,
16
,
32
,
64
],
[
batch
,
16
,
32
,
64
],
]:
]:
self
.
in_dim
=
len
(
shape_input
)
self
.
in_dim
=
len
(
shape_input
)
for
epsilon
in
[
0.0005
,
-
1
,
1
]:
for
epsilon
in
[
0.0005
,
-
1
,
1
,
0.000009999999747378752
,
0.00001
,
]:
dics
=
[{
"epsilon"
:
epsilon
}]
dics
=
[{
"epsilon"
:
epsilon
}]
ops_config
=
[
ops_config
=
[
{
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录