Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
babd26ee
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
babd26ee
编写于
12月 20, 2022
作者:
W
wenbin
提交者:
GitHub
12月 20, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
groupnorm nhwc8 (#49160)
* gn nhwc8 * remove error
上级
6439e91d
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
405 addition
and
18 deletion
+405
-18
paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h
.../inference/tensorrt/plugin/common/groupNormPluginCommon.h
+75
-0
paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu
...e/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu
+327
-18
paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h
...le/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h
+3
-0
未找到文件。
paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h
0 → 100644
浏览文件 @
babd26ee
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdint.h>
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
struct
GroupNormNHWCParams
{
// The output buffer. Layout NHWC.
__half
*
dst
;
// The input buffer. Layout NHWC.
__half
const
*
srcX
;
// The input buffer. Layout NHWC.
__half
const
*
srcY
;
// The gamma scaling factor.
void
const
*
gamma
;
// The beta term to add in GN.
void
const
*
beta
;
// The temporary buffer to do the global parallel reduction. Size:
// BLOCKS_PER_BATCH x C x 2.
float
*
redBuffer
;
// The number of instances in the batch.
int32_t
n
;
// The height and width of each activation map.
int32_t
h
,
w
;
// The number of channels.
int32_t
c
;
// The number of groups.
int32_t
groups
;
// Do we apply the Swish activation function?
bool
withSwish
;
// Precomputed values and parameters to control the execution of the kernels.
// The number of activations per instance (h * w) and the number of
// activations per block.
int32_t
hw
,
hwPerBlock
;
// The number of channels per group and blocks per activation in the C
// dimension.
int32_t
cPerBlock
,
cPerGroup
;
// The precomputed stride between instances.
int32_t
hwc
;
// The inverse of hwc in floats (to compute mean/var).
float
invHWC
;
// The precomputed number of groups per block.
int32_t
groupsPerBlock
;
// epsilon, Constant for numerical stability
float
eps
;
};
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu
浏览文件 @
babd26ee
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
...
@@ -15,6 +17,7 @@ limitations under the License. */
...
@@ -15,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h"
#include "paddle/phi/kernels/group_norm_kernel.h"
#include "paddle/phi/kernels/group_norm_kernel.h"
#include <cub/cub.cuh>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
@@ -25,6 +28,262 @@ namespace tensorrt {
...
@@ -25,6 +28,262 @@ namespace tensorrt {
namespace
plugin
{
namespace
plugin
{
using
DataLayout
=
phi
::
DataLayout
;
using
DataLayout
=
phi
::
DataLayout
;
static
inline
int32_t
divUp
(
int32_t
m
,
int32_t
n
)
{
return
(
m
+
n
-
1
)
/
n
;
}
static
inline
__device__
__host__
float
sigmoid
(
float
x
)
{
return
1.
F
/
(
1.
F
+
expf
(
-
x
));
}
struct
GroupSums
{
// Is it the 1st element of the group?
int32_t
flag
;
// The sum.
float
sum
;
// The sum of squares.
float
sumSq
;
};
struct
GroupSumsOp
{
inline
__device__
GroupSums
operator
()(
GroupSums
const
&
a
,
GroupSums
const
&
b
)
{
GroupSums
dst
;
dst
.
sum
=
b
.
flag
?
b
.
sum
:
(
a
.
sum
+
b
.
sum
);
dst
.
sumSq
=
b
.
flag
?
b
.
sumSq
:
(
a
.
sumSq
+
b
.
sumSq
);
dst
.
flag
=
a
.
flag
+
b
.
flag
;
return
dst
;
}
};
static
int32_t
findMaxDivisor
(
int32_t
n
,
int32_t
maxAllowedDivisor
)
{
int32_t
maxDivisor
=
-
1
;
for
(
int32_t
i
=
1
;
i
<=
std
::
sqrt
(
n
);
i
++
)
{
if
(
n
%
i
==
0
)
{
int32_t
divisor1
=
n
/
i
;
int32_t
divisor2
=
i
;
if
(
divisor1
>
maxDivisor
&&
divisor1
<
maxAllowedDivisor
)
{
maxDivisor
=
divisor1
;
}
if
(
divisor2
>
maxDivisor
&&
divisor2
<
maxAllowedDivisor
)
{
maxDivisor
=
divisor2
;
}
}
}
return
maxDivisor
;
}
template
<
int
tTHREADS_PER_BLOCK
>
__global__
void
groupNormNHWCSumKernel
(
const
GroupNormNHWCParams
params
)
{
// The object in charge of doing the sums for the different blocks.
typedef
cub
::
BlockScan
<
GroupSums
,
tTHREADS_PER_BLOCK
>
BlockScan
;
// Allocate shared memory for BlockScan.
__shared__
typename
BlockScan
::
TempStorage
tempStorage
;
// Allocate shared memory for the groups. We could reduce the amount of shared
// memory reserved.
__shared__
float2
smem
[
tTHREADS_PER_BLOCK
];
// The instance in the batch.
int32_t
ni
=
blockIdx
.
z
;
// The channel loaded by that thread (2 channels per thread for F16x2).
int32_t
ci
=
blockIdx
.
x
*
params
.
cPerBlock
+
threadIdx
.
x
*
2
;
// The first activation loaded by that block.
int32_t
hwBegin
=
blockIdx
.
y
*
params
.
hwPerBlock
;
// The last activation loaded by that block.
int32_t
hwEnd
=
min
(
hwBegin
+
params
.
hwPerBlock
,
params
.
hw
);
// The sums.
float
sum
=
0.
F
;
float
sumSq
=
0.
F
;
// Iterate over the activations to compute the sums.
for
(
int32_t
hwi
=
hwBegin
;
hwi
<
hwEnd
;
++
hwi
)
{
// The offset.
int64_t
offset
=
static_cast
<
int64_t
>
(
ni
)
*
params
.
hwc
+
static_cast
<
int64_t
>
(
hwi
)
*
params
.
c
+
ci
;
// Fetch two channels per thread.
__half2
h2
(
0
,
0
);
if
(
ci
<
params
.
c
)
{
h2
=
*
reinterpret_cast
<
__half2
const
*>
(
&
params
.
srcX
[
offset
]);
}
// Extract the two half values.
float2
f2
=
__half22float2
(
h2
);
// Update the sum.
sum
+=
f2
.
x
+
f2
.
y
;
// Update the sum of squares.
sumSq
+=
f2
.
x
*
f2
.
x
+
f2
.
y
*
f2
.
y
;
}
// The group that thread works on and the channel in the group (modulus).
int32_t
gi
=
threadIdx
.
x
*
2
/
params
.
cPerGroup
;
int32_t
cj
=
threadIdx
.
x
*
2
-
params
.
cPerGroup
*
gi
;
// The data for the summations.
GroupSums
inp
{
cj
==
0
?
1
:
0
,
sum
,
sumSq
};
// Do the segmented scan.
GroupSums
out
;
BlockScan
(
tempStorage
).
InclusiveScan
(
inp
,
out
,
GroupSumsOp
());
// Store the results for the groups in shared memory (to produce coalesced
// stores later).
// 2 channels per thread
if
(
cj
==
params
.
cPerGroup
-
2
)
{
smem
[
gi
]
=
make_float2
(
out
.
sum
,
out
.
sumSq
);
}
// Make sure the data is in shared memory.
__syncthreads
();
// The global group index.
int32_t
gj
=
blockIdx
.
x
*
params
.
groupsPerBlock
+
threadIdx
.
x
;
// Threads that have nothing left to do, exit.
if
(
threadIdx
.
x
>=
params
.
groupsPerBlock
||
gj
>=
params
.
groups
)
{
return
;
}
// The first threads (those storing to global memory, load the values).
float2
sums
=
smem
[
threadIdx
.
x
];
// Store to global memory.
atomicAdd
(
&
params
.
redBuffer
[(
2
*
ni
+
0
)
*
params
.
groups
+
gj
],
sums
.
x
);
atomicAdd
(
&
params
.
redBuffer
[(
2
*
ni
+
1
)
*
params
.
groups
+
gj
],
sums
.
y
);
}
void
groupNormNHWCSum
(
const
GroupNormNHWCParams
&
params
,
cudaStream_t
stream
)
{
dim3
grid
;
// The number of blocks to compute all the channels.
grid
.
x
=
params
.
c
/
params
.
cPerBlock
;
// The number of blocks to compute all the activations in a given instance.
grid
.
y
=
divUp
(
params
.
hw
,
params
.
hwPerBlock
);
// The number of instances.
grid
.
z
=
params
.
n
;
switch
(
params
.
cPerBlock
)
{
case
320
:
groupNormNHWCSumKernel
<
160
><<<
grid
,
160
,
0
,
stream
>>>
(
params
);
break
;
case
480
:
groupNormNHWCSumKernel
<
256
><<<
grid
,
256
,
0
,
stream
>>>
(
params
);
break
;
case
256
:
groupNormNHWCSumKernel
<
128
><<<
grid
,
128
,
0
,
stream
>>>
(
params
);
break
;
case
128
:
groupNormNHWCSumKernel
<
64
><<<
grid
,
64
,
0
,
stream
>>>
(
params
);
break
;
}
}
template
<
int
tTHREADS_PER_BLOCK
>
__global__
void
groupNormNHWCScaleKernel
(
const
GroupNormNHWCParams
params
)
{
// The instance in the batch.
int32_t
ni
=
blockIdx
.
z
;
// The channel loaded by that thread (2 channels per thread for F16x2).
int32_t
ci
=
blockIdx
.
x
*
params
.
cPerBlock
+
threadIdx
.
x
*
2
;
// The group that thread works on and the channel in the group (modulus).
int32_t
gi
=
ci
/
params
.
cPerGroup
;
// Load the sum and sum of squares for the group.
float
sum
=
0.
F
,
sumSq
=
0.
F
;
if
(
gi
<
params
.
groups
)
{
sum
=
params
.
redBuffer
[(
2
*
ni
+
0
)
*
params
.
groups
+
gi
];
sumSq
=
params
.
redBuffer
[(
2
*
ni
+
1
)
*
params
.
groups
+
gi
];
}
// Load gamma/beta.
float2
gammaF2
,
betaF2
;
if
(
ci
<
params
.
c
)
{
gammaF2
=
__half22float2
(
*
reinterpret_cast
<
half2
const
*>
(
reinterpret_cast
<
half
const
*>
(
params
.
gamma
)
+
ci
));
betaF2
=
__half22float2
(
*
reinterpret_cast
<
half2
const
*>
(
reinterpret_cast
<
half
const
*>
(
params
.
beta
)
+
ci
));
}
// Compute the mean.
float
mean
=
sum
*
params
.
invHWC
;
// Compute the variance.
float
var
=
sumSq
*
params
.
invHWC
-
(
mean
*
mean
);
// Compute the inverse of the stddev.
float
invStdDev
=
rsqrtf
(
var
+
params
.
eps
);
// The first activation loaded by that block.
int32_t
hwBegin
=
blockIdx
.
y
*
params
.
hwPerBlock
;
// The last activation loaded by that block.
int32_t
hwEnd
=
min
(
hwBegin
+
params
.
hwPerBlock
,
params
.
hw
);
// Iterate over the activations to compute the sums.
for
(
int32_t
hwi
=
hwBegin
;
hwi
<
hwEnd
;
++
hwi
)
{
// The src/dst offset.
int64_t
offset
=
(
int64_t
)
ni
*
params
.
hwc
+
hwi
*
params
.
c
+
ci
;
// Fetch two channels per thread.
__half2
h2
(
0
,
0
);
if
(
ci
<
params
.
c
)
{
h2
=
*
reinterpret_cast
<
__half2
const
*>
(
&
params
.
srcX
[
offset
]);
}
// Extract the two half values.
float2
f2
=
__half22float2
(
h2
);
// Normalize the channels.
f2
.
x
=
(
f2
.
x
-
mean
)
*
invStdDev
;
f2
.
y
=
(
f2
.
y
-
mean
)
*
invStdDev
;
// Scale by gamma and add beta.
f2
.
x
=
gammaF2
.
x
*
f2
.
x
+
betaF2
.
x
;
f2
.
y
=
gammaF2
.
y
*
f2
.
y
+
betaF2
.
y
;
// Apply Swish if needed.
if
(
params
.
withSwish
)
{
f2
.
x
=
f2
.
x
*
sigmoid
(
f2
.
x
);
f2
.
y
=
f2
.
y
*
sigmoid
(
f2
.
y
);
}
// Store the scaled values.
if
(
ci
<
params
.
c
)
{
*
reinterpret_cast
<
__half2
*>
(
&
params
.
dst
[
offset
])
=
__float22half2_rn
(
f2
);
}
}
}
void
groupNormNHWCScale
(
const
GroupNormNHWCParams
&
params
,
cudaStream_t
stream
)
{
dim3
grid
;
// The number of blocks to compute all the channels.
grid
.
x
=
params
.
c
/
params
.
cPerBlock
;
// The number of blocks to compute all the activations in a given instance.
grid
.
y
=
divUp
(
params
.
hw
,
params
.
hwPerBlock
);
// The number of instances.
grid
.
z
=
params
.
n
;
switch
(
params
.
cPerBlock
)
{
case
320
:
groupNormNHWCScaleKernel
<
160
><<<
grid
,
160
,
0
,
stream
>>>
(
params
);
break
;
case
480
:
groupNormNHWCScaleKernel
<
256
><<<
grid
,
256
,
0
,
stream
>>>
(
params
);
break
;
case
256
:
groupNormNHWCScaleKernel
<
128
><<<
grid
,
128
,
0
,
stream
>>>
(
params
);
break
;
case
128
:
groupNormNHWCScaleKernel
<
64
><<<
grid
,
64
,
0
,
stream
>>>
(
params
);
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The function groupNormNHWCScale of "
"GroupNorm TRT Plugin encounter error"
));
}
}
int
GroupNormPlugin
::
initialize
()
TRT_NOEXCEPT
{
int
GroupNormPlugin
::
initialize
()
TRT_NOEXCEPT
{
if
(
!
with_fp16_
)
{
if
(
!
with_fp16_
)
{
// if use fp32
// if use fp32
...
@@ -188,7 +447,8 @@ bool GroupNormPluginDynamic::supportsFormatCombination(
...
@@ -188,7 +447,8 @@ bool GroupNormPluginDynamic::supportsFormatCombination(
if
(
pos
==
0
)
{
if
(
pos
==
0
)
{
if
(
with_fp16_
)
{
if
(
with_fp16_
)
{
return
((
in
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
return
((
in
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
(
in
.
format
==
nvinfer1
::
PluginFormat
::
kLINEAR
));
(
in
.
format
==
nvinfer1
::
PluginFormat
::
kLINEAR
||
in
.
format
==
nvinfer1
::
PluginFormat
::
kHWC8
));
}
else
{
}
else
{
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
...
@@ -275,9 +535,7 @@ int GroupNormPluginDynamic::enqueue(
...
@@ -275,9 +535,7 @@ int GroupNormPluginDynamic::enqueue(
int
C
=
input_shape
[
1
];
int
C
=
input_shape
[
1
];
int
image_size
=
input_shape
[
2
]
*
input_shape
[
3
];
int
image_size
=
input_shape
[
2
]
*
input_shape
[
3
];
int
batchSize
=
input_shape
[
0
];
int
batchSize
=
input_shape
[
0
];
std
::
vector
<
int64_t
>
batched_mean_shape
=
{
batchSize
*
mean_shape_
[
0
]};
std
::
vector
<
int64_t
>
batched_variance_shape
=
{
batchSize
*
variance_shape_
[
0
]};
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
C
,
C
,
scale_
.
size
(),
scale_
.
size
(),
...
@@ -320,7 +578,7 @@ int GroupNormPluginDynamic::enqueue(
...
@@ -320,7 +578,7 @@ int GroupNormPluginDynamic::enqueue(
VLOG
(
1
)
<<
"TRT Plugin DataType selected. GroupNorm-->fp16"
;
VLOG
(
1
)
<<
"TRT Plugin DataType selected. GroupNorm-->fp16"
;
const
half
*
input
=
reinterpret_cast
<
const
half
*>
(
inputs
[
0
]);
const
half
*
input
=
reinterpret_cast
<
const
half
*>
(
inputs
[
0
]);
half
*
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
half
*
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
if
(
input_desc
[
0
].
format
==
nvinfer1
::
PluginFormat
::
kLINEAR
)
{
phi
::
GroupNormDirectCUDAFunctor
<
half
,
float
>
group_norm
;
phi
::
GroupNormDirectCUDAFunctor
<
half
,
float
>
group_norm
;
group_norm
(
stream
,
group_norm
(
stream
,
input
,
input
,
...
@@ -334,11 +592,62 @@ int GroupNormPluginDynamic::enqueue(
...
@@ -334,11 +592,62 @@ int GroupNormPluginDynamic::enqueue(
mean_d
,
mean_d
,
variance_d
,
variance_d
,
DataLayout
::
kNCHW
);
DataLayout
::
kNCHW
);
}
else
if
(
input_desc
[
0
].
format
==
nvinfer1
::
PluginFormat
::
kHWC8
)
{
int32_t
cPerBlock
=
320
;
int32_t
maxBlocksPerHW
=
1024
;
switch
(
input_desc
[
0
].
dims
.
d
[
1
])
{
case
960
:
case
1920
:
cPerBlock
=
480
;
break
;
case
512
:
case
256
:
cPerBlock
=
256
;
break
;
case
128
:
cPerBlock
=
128
;
break
;
default:
cPerBlock
=
320
;
}
params_
.
withSwish
=
false
;
params_
.
dst
=
static_cast
<
half
*>
(
outputs
[
0
]);
params_
.
srcX
=
static_cast
<
half
const
*>
(
inputs
[
0
]);
params_
.
gamma
=
scale_gpu_
;
params_
.
beta
=
bias_gpu_
;
params_
.
redBuffer
=
static_cast
<
float
*>
(
workspace
);
params_
.
n
=
input_desc
[
0
].
dims
.
d
[
0
];
params_
.
h
=
input_desc
[
0
].
dims
.
d
[
2
];
params_
.
w
=
input_desc
[
0
].
dims
.
d
[
3
];
params_
.
c
=
input_desc
[
0
].
dims
.
d
[
1
];
params_
.
groups
=
groups_
;
params_
.
hw
=
params_
.
h
*
params_
.
w
;
const
int32_t
blocksPerHW
=
findMaxDivisor
(
params_
.
hw
,
maxBlocksPerHW
);
params_
.
hwPerBlock
=
divUp
(
params_
.
hw
,
blocksPerHW
);
params_
.
cPerBlock
=
cPerBlock
;
params_
.
cPerGroup
=
params_
.
c
/
params_
.
groups
;
params_
.
hwc
=
params_
.
hw
*
params_
.
c
;
params_
.
invHWC
=
1.
F
/
static_cast
<
float
>
(
params_
.
hw
*
params_
.
cPerGroup
);
params_
.
groupsPerBlock
=
cPerBlock
/
params_
.
cPerGroup
;
params_
.
eps
=
eps_
;
cudaMemsetAsync
(
params_
.
redBuffer
,
0
,
2
*
sizeof
(
float
)
*
params_
.
n
*
groups_
,
stream
);
groupNormNHWCSum
(
params_
,
stream
);
groupNormNHWCScale
(
params_
,
stream
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The Groupnorm TRT Plugin's only support nchw or nhwc8 input"
));
}
}
else
{
}
else
{
// input not float
// input not float
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The Groupnorm TRT Plugin's only support fp32 input"
));
"The Groupnorm TRT Plugin's only support fp32
or fp16
input"
));
}
}
return
cudaGetLastError
()
!=
cudaSuccess
;
return
cudaGetLastError
()
!=
cudaSuccess
;
}
}
...
...
paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h
浏览文件 @
babd26ee
...
@@ -21,7 +21,9 @@ limitations under the License. */
...
@@ -21,7 +21,9 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
namespace
tensorrt
{
namespace
tensorrt
{
...
@@ -274,6 +276,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
...
@@ -274,6 +276,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
float
eps_
;
float
eps_
;
std
::
vector
<
int64_t
>
mean_shape_
;
std
::
vector
<
int64_t
>
mean_shape_
;
std
::
vector
<
int64_t
>
variance_shape_
;
std
::
vector
<
int64_t
>
variance_shape_
;
GroupNormNHWCParams
params_
;
bool
with_fp16_
;
bool
with_fp16_
;
};
};
class
GroupNormPluginDynamicCreator
:
public
TensorRTPluginCreator
{
class
GroupNormPluginDynamicCreator
:
public
TensorRTPluginCreator
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录