Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
6a87f61b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6a87f61b
编写于
7月 19, 2020
作者:
R
ReeseWang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add trt stack op, test=develop
上级
f1b1c753
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
359 addition
and
28 deletion
+359
-28
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+1
-0
paddle/fluid/inference/tensorrt/convert/stack_op.cc
paddle/fluid/inference/tensorrt/convert/stack_op.cc
+81
-0
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+28
-28
paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu
+154
-0
paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h
paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h
+95
-0
未找到文件。
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
6a87f61b
...
@@ -1035,4 +1035,5 @@ USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
...
@@ -1035,4 +1035,5 @@ USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
USE_TRT_CONVERTER
(
skip_layernorm
);
USE_TRT_CONVERTER
(
skip_layernorm
);
USE_TRT_CONVERTER
(
slice
);
USE_TRT_CONVERTER
(
slice
);
USE_TRT_CONVERTER
(
scale
);
USE_TRT_CONVERTER
(
scale
);
USE_TRT_CONVERTER
(
stack
);
#endif
#endif
paddle/fluid/inference/tensorrt/convert/stack_op.cc
0 → 100644
浏览文件 @
6a87f61b
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h"
#include <iostream>
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
/*
* Stack converter from fluid to tensorRT.
*/
class
StackOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
4
)
<<
"convert fluid stack op to tensorrt stack layer"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
auto
input
=
op_desc
.
Input
(
"X"
);
int
input_num
=
input
.
size
();
nvinfer1
::
ITensor
**
inputs
=
(
nvinfer1
::
ITensor
**
)
malloc
(
input_num
*
sizeof
(
nvinfer1
::
ITensor
*
));
for
(
int
i
=
0
;
i
<
input_num
;
++
i
)
{
inputs
[
i
]
=
engine_
->
GetITensor
(
input
[
i
]);
}
auto
idim
=
inputs
[
0
]
->
getDimensions
();
std
::
cerr
<<
"Stack input: "
<<
idim
.
nbDims
<<
" "
<<
idim
.
d
[
0
]
<<
" "
<<
idim
.
d
[
1
]
<<
" "
<<
idim
.
d
[
2
]
<<
std
::
endl
;
int
axis
=
BOOST_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"axis"
));
if
(
axis
<
0
)
{
axis
=
axis
+
inputs
[
0
]
->
getDimensions
().
nbDims
+
1
;
}
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
engine_
->
with_dynamic_shape
())
{
#if IS_TRT_VERSION_GE(6000)
plugin
::
StackPluginDynamic
*
plugin
=
new
plugin
::
StackPluginDynamic
(
axis
,
input_num
);
layer
=
engine_
->
AddPluginV2
(
inputs
,
input_num
,
plugin
);
assert
(
layer
!=
nullptr
);
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"
));
#endif
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the Ernie(Bert) model in static"
"shape mode, which is not supported for the time being.
\n
"
"You can use the config.SetTRTDynamicShapeInfo(...) interface"
" to set the shape information to run the dynamic shape mode."
));
}
auto
output_name
=
op_desc
.
Output
(
"Y"
).
front
();
RreplenishLayerAndOutput
(
layer
,
"stack"
,
{
output_name
},
test_mode
);
free
(
inputs
);
}
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
stack
,
StackOpConverter
);
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
6a87f61b
...
@@ -55,34 +55,34 @@ struct SimpleOpTypeSetTeller : public Teller {
...
@@ -55,34 +55,34 @@ struct SimpleOpTypeSetTeller : public Teller {
"fc"
,
"fc"
,
"relu6"
,
"relu6"
,
"concat"
};
"concat"
};
std
::
unordered_set
<
std
::
string
>
teller_set
{
std
::
unordered_set
<
std
::
string
>
teller_set
{
"mul"
,
"mul
"
,
"conv2d
"
,
"conv
2d"
,
"pool
2d"
,
"pool2d
"
,
"relu
"
,
"relu
"
,
"softmax
"
,
"softmax
"
,
"sigmoid
"
,
"sigmoid
"
,
"hard_swish
"
,
"hard_swish
"
,
"depthwise_conv2d
"
,
"depthwise_conv2d
"
,
"batch_norm
"
,
"batch_norm
"
,
"concat
"
,
"concat
"
,
"tanh
"
,
"tanh
"
,
"pad
"
,
"pa
d"
,
"elementwise_ad
d"
,
"elementwise_add
"
,
"elementwise_mul
"
,
"elementwise_mul
"
,
"dropout
"
,
"dropout
"
,
"prelu
"
,
"prelu
"
,
"conv2d_transpose
"
,
"conv2d_transpose
"
,
"leaky_relu
"
,
"leaky_relu
"
,
"fc
"
,
"fc
"
,
"shuffle_channel
"
,
"shuffle_channel
"
,
"swish
"
,
"swish
"
,
"split
"
,
"split
"
,
"instance_norm
"
,
"instance_norm
"
,
"gelu
"
,
"gelu
"
,
"layer_norm
"
,
"layer_norm
"
,
"scale
"
,
"scal
e"
,
"slic
e"
,
};
"stack"
};
};
};
bool
OpTeller
::
Tell
(
const
std
::
string
&
op_type
,
const
framework
::
OpDesc
&
desc
,
bool
OpTeller
::
Tell
(
const
std
::
string
&
op_type
,
const
framework
::
OpDesc
&
desc
,
...
...
paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu
0 → 100644
浏览文件 @
6a87f61b
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
size_t
StackPluginDynamic
::
getSerializationSize
()
const
{
return
0
;
}
void
StackPluginDynamic
::
serialize
(
void
*
buffer
)
const
{}
nvinfer1
::
DimsExprs
StackPluginDynamic
::
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
{
nvinfer1
::
DimsExprs
output
(
inputs
[
0
]);
output
.
nbDims
=
inputs
[
0
].
nbDims
+
1
;
for
(
int
i
=
inputs
[
0
].
nbDims
;
i
>
axis_
;
--
i
)
{
output
.
d
[
i
]
=
inputs
[
0
].
d
[
i
-
1
];
}
output
.
d
[
axis_
]
=
expr_builder
.
constant
(
nb_inputs
);
return
output
;
}
bool
StackPluginDynamic
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
in_out
,
int
nb_inputs
,
int
nb_outputs
)
{
PADDLE_ENFORCE_NOT_NULL
(
in_out
,
platform
::
errors
::
InvalidArgument
(
"The input of stack plugin should not be nullptr."
));
PADDLE_ENFORCE_LT
(
pos
,
nb_inputs
+
nb_outputs
,
platform
::
errors
::
InvalidArgument
(
"The pos(%d) should be less than the "
"num(%d) of the input and the output."
,
pos
,
nb_inputs
+
nb_outputs
));
const
nvinfer1
::
PluginTensorDesc
&
in
=
in_out
[
pos
];
if
(
pos
==
0
)
{
#ifdef SUPPORTS_CUDA_FP16
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
||
in
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
#else
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
#endif
}
const
nvinfer1
::
PluginTensorDesc
&
prev
=
in_out
[
pos
-
1
];
// output
return
in
.
type
==
prev
.
type
&&
in
.
format
==
prev
.
format
;
}
nvinfer1
::
DataType
StackPluginDynamic
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
{
PADDLE_ENFORCE_EQ
(
index
,
0
,
platform
::
errors
::
InvalidArgument
(
"The index should be equal to 0"
));
return
input_types
[
0
];
}
template
<
typename
T
>
__global__
void
StackKernel
(
const
T
*
const
*
input
,
T
*
output
,
int
num_stack
,
int
base_unit
)
{
int
stack_id
=
blockIdx
.
x
;
int
lead_id
=
blockIdx
.
y
;
for
(
int
i
=
threadIdx
.
x
;
i
<
base_unit
;
i
+=
blockDim
.
x
)
{
output
[
lead_id
*
num_stack
*
base_unit
+
stack_id
*
base_unit
+
i
]
=
input
[
stack_id
][
lead_id
*
base_unit
+
i
];
}
}
int
StackPluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
auto
input_dims
=
input_desc
[
0
].
dims
;
// (batch, seq, seq)
auto
out_dims
=
output_desc
[
0
].
dims
;
// (batch, num_head, seq, seq)
auto
out_num_dims
=
out_dims
.
nbDims
;
int
base_unit
=
1
;
for
(
int
i
=
axis_
+
1
;
i
<
out_num_dims
;
++
i
)
{
PADDLE_ENFORCE_GT
(
out_dims
.
d
[
i
],
0
,
platform
::
errors
::
InvalidArgument
(
"Input dimensions should be greater than 0"
));
base_unit
*=
out_dims
.
d
[
i
];
}
int
lead_unit
=
1
;
for
(
int
i
=
0
;
i
<
axis_
;
++
i
)
{
PADDLE_ENFORCE_GT
(
out_dims
.
d
[
i
],
0
,
platform
::
errors
::
InvalidArgument
(
"Input dimensions should be greater than 0"
));
lead_unit
*=
out_dims
.
d
[
i
];
}
cudaMemcpyAsync
(
reinterpret_cast
<
void
*>
(
in_ptr_gpu_
),
reinterpret_cast
<
const
void
*
const
>
(
inputs
),
sizeof
(
void
*
)
*
out_dims
.
d
[
axis_
],
cudaMemcpyHostToDevice
,
stream
);
int
num_stacks
=
out_dims
.
d
[
axis_
];
dim3
num_blocks
(
num_stacks
,
lead_unit
);
int
num_threads
=
256
;
auto
infer_type
=
input_desc
[
0
].
type
;
if
(
infer_type
==
nvinfer1
::
DataType
::
kFLOAT
)
{
float
*
output
=
static_cast
<
float
*>
(
outputs
[
0
]);
StackKernel
<
float
><<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
reinterpret_cast
<
const
float
*
const
*>
(
in_ptr_gpu_
),
output
,
num_stacks
,
base_unit
);
}
else
if
(
infer_type
==
nvinfer1
::
DataType
::
kHALF
)
{
#ifdef SUPPORTS_CUDA_FP16
__half
*
output
=
static_cast
<
__half
*>
(
outputs
[
0
]);
StackKernel
<
__half
><<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
reinterpret_cast
<
const
__half
*
const
*>
(
in_ptr_gpu_
),
output
,
num_stacks
,
base_unit
);
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The cuda archs you specific should greater than 600."
));
#endif
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The Stack TRT Plugin's input type only "
"support float or half currently."
));
}
return
cudaGetLastError
()
!=
cudaSuccess
;
}
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h
0 → 100644
浏览文件 @
6a87f61b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
#if IS_TRT_VERSION_GE(6000)
class
StackPluginDynamic
:
public
DynamicPluginTensorRT
{
public:
StackPluginDynamic
(
int
axis
,
int
num_stack
)
:
axis_
(
axis
),
num_stack_
(
num_stack
)
{
int
device_id
;
cudaGetDevice
(
&
device_id
);
in_ptr_tensor_
.
Resize
({
num_stack
});
in_ptr_gpu_
=
in_ptr_tensor_
.
mutable_data
<
int64_t
>
(
platform
::
CUDAPlace
(
device_id
));
}
StackPluginDynamic
(
void
const
*
serialData
,
size_t
serialLength
)
{}
~
StackPluginDynamic
()
{}
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
override
{
return
new
StackPluginDynamic
(
axis_
,
num_stack_
);
}
const
char
*
getPluginType
()
const
override
{
return
"stack_plugin"
;
}
int
getNbOutputs
()
const
override
{
return
1
;
}
int
initialize
()
override
{
return
0
;
}
size_t
getSerializationSize
()
const
override
;
void
serialize
(
void
*
buffer
)
const
override
;
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
outputIndex
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
override
;
bool
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
inOut
,
int
nbInputs
,
int
nbOutputs
)
override
;
void
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
override
{}
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
override
{
return
0
;
}
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
override
;
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
inputTypes
,
int
nbInputs
)
const
override
;
void
destroy
()
override
{
delete
this
;
}
private:
int
axis_
;
int
num_stack_
;
framework
::
Tensor
in_ptr_tensor_
;
int64_t
*
in_ptr_gpu_
;
};
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录