Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
c73977af
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c73977af
编写于
6月 07, 2018
作者:
L
Luo Tao
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' into trt
上级
e8e8ad04
ca2d6d3c
变更
21
显示空白变更内容
内联
并排
Showing
21 changed file
with
442 addition
and
96 deletion
+442
-96
paddle/fluid/framework/scope.cc
paddle/fluid/framework/scope.cc
+17
-13
paddle/fluid/framework/scope.h
paddle/fluid/framework/scope.h
+14
-3
paddle/fluid/inference/analysis/helper.h
paddle/fluid/inference/analysis/helper.h
+9
-0
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+7
-1
paddle/fluid/inference/tensorrt/convert/activation_op.cc
paddle/fluid/inference/tensorrt/convert/activation_op.cc
+3
-2
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
+3
-3
paddle/fluid/inference/tensorrt/convert/fc_op.cc
paddle/fluid/inference/tensorrt/convert/fc_op.cc
+6
-4
paddle/fluid/inference/tensorrt/convert/mul_op.cc
paddle/fluid/inference/tensorrt/convert/mul_op.cc
+10
-5
paddle/fluid/inference/tensorrt/convert/op_converter.h
paddle/fluid/inference/tensorrt/convert/op_converter.h
+29
-11
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
+2
-0
paddle/fluid/inference/tensorrt/convert/ut_helper.h
paddle/fluid/inference/tensorrt/convert/ut_helper.h
+3
-2
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+8
-1
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+27
-1
paddle/fluid/inference/tests/book/test_inference_nlp.cc
paddle/fluid/inference/tests/book/test_inference_nlp.cc
+26
-31
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+2
-0
paddle/fluid/operators/tensorrt_engine_op.cc
paddle/fluid/operators/tensorrt_engine_op.cc
+77
-5
paddle/fluid/operators/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt_engine_op.h
+19
-11
paddle/fluid/operators/tensorrt_engine_op_test.cc
paddle/fluid/operators/tensorrt_engine_op_test.cc
+152
-0
paddle/fluid/platform/assert.h
paddle/fluid/platform/assert.h
+5
-2
paddle/fluid/platform/cudnn_helper.h
paddle/fluid/platform/cudnn_helper.h
+22
-0
paddle/scripts/paddle_build.sh
paddle/scripts/paddle_build.sh
+1
-1
未找到文件。
paddle/fluid/framework/scope.cc
浏览文件 @
c73977af
...
@@ -34,13 +34,7 @@ DEFINE_bool(
...
@@ -34,13 +34,7 @@ DEFINE_bool(
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
Scope
::~
Scope
()
{
Scope
::~
Scope
()
{
DropKids
();
}
DropKids
();
for
(
auto
&
kv
:
vars_
)
{
VLOG
(
3
)
<<
"Destroy variable "
<<
kv
.
first
;
delete
kv
.
second
;
}
}
Scope
&
Scope
::
NewScope
()
const
{
Scope
&
Scope
::
NewScope
()
const
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
...
@@ -49,10 +43,13 @@ Scope& Scope::NewScope() const {
...
@@ -49,10 +43,13 @@ Scope& Scope::NewScope() const {
}
}
Variable
*
Scope
::
Var
(
const
std
::
string
&
name
)
{
Variable
*
Scope
::
Var
(
const
std
::
string
&
name
)
{
// acquire the lock when new var under this scope
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
auto
*
v
=
FindVarLocally
(
name
);
auto
*
v
=
FindVarLocally
(
name
);
if
(
v
!=
nullptr
)
return
v
;
if
(
v
!=
nullptr
)
return
v
;
v
=
new
Variable
();
v
=
new
Variable
();
vars_
[
name
]
=
v
;
vars_
[
name
]
.
reset
(
v
)
;
VLOG
(
3
)
<<
"Create variable "
<<
name
;
VLOG
(
3
)
<<
"Create variable "
<<
name
;
v
->
name_
=
&
(
vars_
.
find
(
name
)
->
first
);
v
->
name_
=
&
(
vars_
.
find
(
name
)
->
first
);
return
v
;
return
v
;
...
@@ -67,22 +64,29 @@ Variable* Scope::Var(std::string* name) {
...
@@ -67,22 +64,29 @@ Variable* Scope::Var(std::string* name) {
}
}
Variable
*
Scope
::
FindVar
(
const
std
::
string
&
name
)
const
{
Variable
*
Scope
::
FindVar
(
const
std
::
string
&
name
)
const
{
// acquire the lock when find var
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
return
FindVarInternal
(
name
);
}
Variable
*
Scope
::
FindVarInternal
(
const
std
::
string
&
name
)
const
{
auto
var
=
FindVarLocally
(
name
);
auto
var
=
FindVarLocally
(
name
);
if
(
var
!=
nullptr
)
{
if
(
var
!=
nullptr
)
{
return
var
;
return
var
;
}
}
return
(
parent_
==
nullptr
)
?
nullptr
:
parent_
->
FindVar
(
name
);
return
(
parent_
==
nullptr
)
?
nullptr
:
parent_
->
FindVar
Internal
(
name
);
}
}
const
Scope
*
Scope
::
FindScope
(
const
Variable
*
var
)
const
{
const
Scope
*
Scope
::
FindScope
(
const
Variable
*
var
)
const
{
for
(
auto
&
kv
:
vars_
)
{
for
(
auto
&
kv
:
vars_
)
{
if
(
kv
.
second
==
var
)
{
if
(
kv
.
second
.
get
()
==
var
)
{
return
this
;
return
this
;
}
}
}
}
return
(
parent_
==
nullptr
)
?
nullptr
:
parent_
->
FindScope
(
var
);
return
(
parent_
==
nullptr
)
?
nullptr
:
parent_
->
FindScope
(
var
);
}
}
void
Scope
::
DropKids
()
{
void
Scope
::
DropKids
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
for
(
Scope
*
s
:
kids_
)
delete
s
;
for
(
Scope
*
s
:
kids_
)
delete
s
;
kids_
.
clear
();
kids_
.
clear
();
}
}
...
@@ -110,10 +114,10 @@ void Scope::DeleteScope(Scope* scope) const {
...
@@ -110,10 +114,10 @@ void Scope::DeleteScope(Scope* scope) const {
}
}
void
Scope
::
EraseVars
(
const
std
::
vector
<
std
::
string
>&
var_names
)
{
void
Scope
::
EraseVars
(
const
std
::
vector
<
std
::
string
>&
var_names
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
set
<
std
::
string
>
var_set
(
var_names
.
begin
(),
var_names
.
end
());
std
::
set
<
std
::
string
>
var_set
(
var_names
.
begin
(),
var_names
.
end
());
for
(
auto
it
=
vars_
.
begin
();
it
!=
vars_
.
end
();)
{
for
(
auto
it
=
vars_
.
begin
();
it
!=
vars_
.
end
();)
{
if
(
var_set
.
find
(
it
->
first
)
!=
var_set
.
end
())
{
if
(
var_set
.
find
(
it
->
first
)
!=
var_set
.
end
())
{
delete
it
->
second
;
it
=
vars_
.
erase
(
it
);
it
=
vars_
.
erase
(
it
);
}
else
{
}
else
{
++
it
;
++
it
;
...
@@ -129,7 +133,7 @@ void Scope::Rename(const std::string& origin_name,
...
@@ -129,7 +133,7 @@ void Scope::Rename(const std::string& origin_name,
auto
new_it
=
vars_
.
find
(
new_name
);
auto
new_it
=
vars_
.
find
(
new_name
);
PADDLE_ENFORCE
(
new_it
==
vars_
.
end
(),
PADDLE_ENFORCE
(
new_it
==
vars_
.
end
(),
"The variable with name %s is already in the scope"
,
new_name
);
"The variable with name %s is already in the scope"
,
new_name
);
vars_
[
new_name
]
=
origin_it
->
second
;
vars_
[
new_name
]
.
reset
(
origin_it
->
second
.
release
())
;
vars_
.
erase
(
origin_it
);
vars_
.
erase
(
origin_it
);
}
}
...
@@ -141,7 +145,7 @@ std::string Scope::Rename(const std::string& origin_name) const {
...
@@ -141,7 +145,7 @@ std::string Scope::Rename(const std::string& origin_name) const {
Variable
*
Scope
::
FindVarLocally
(
const
std
::
string
&
name
)
const
{
Variable
*
Scope
::
FindVarLocally
(
const
std
::
string
&
name
)
const
{
auto
it
=
vars_
.
find
(
name
);
auto
it
=
vars_
.
find
(
name
);
if
(
it
!=
vars_
.
end
())
return
it
->
second
;
if
(
it
!=
vars_
.
end
())
return
it
->
second
.
get
()
;
return
nullptr
;
return
nullptr
;
}
}
...
...
paddle/fluid/framework/scope.h
浏览文件 @
c73977af
...
@@ -47,15 +47,18 @@ class Scope {
...
@@ -47,15 +47,18 @@ class Scope {
Scope
&
NewScope
()
const
;
Scope
&
NewScope
()
const
;
/// Create a variable with given name if it doesn't exist.
/// Create a variable with given name if it doesn't exist.
/// Caller doesn't own the returned Variable.
Variable
*
Var
(
const
std
::
string
&
name
);
Variable
*
Var
(
const
std
::
string
&
name
);
/// Create a variable with a scope-unique name.
/// Create a variable with a scope-unique name.
/// Caller doesn't own the returned Variable.
Variable
*
Var
(
std
::
string
*
name
=
nullptr
);
Variable
*
Var
(
std
::
string
*
name
=
nullptr
);
void
EraseVars
(
const
std
::
vector
<
std
::
string
>&
var_names
);
void
EraseVars
(
const
std
::
vector
<
std
::
string
>&
var_names
);
/// Find a variable in the scope or any of its ancestors. Returns
/// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find.
/// nullptr if cannot find.
/// Caller doesn't own the returned Variable.
Variable
*
FindVar
(
const
std
::
string
&
name
)
const
;
Variable
*
FindVar
(
const
std
::
string
&
name
)
const
;
const
Scope
*
parent
()
const
{
return
parent_
;
}
const
Scope
*
parent
()
const
{
return
parent_
;
}
...
@@ -78,13 +81,21 @@ class Scope {
...
@@ -78,13 +81,21 @@ class Scope {
// Rename variable to a new name and return the new name
// Rename variable to a new name and return the new name
std
::
string
Rename
(
const
std
::
string
&
origin_name
)
const
;
std
::
string
Rename
(
const
std
::
string
&
origin_name
)
const
;
Variable
*
FindVarLocally
(
const
std
::
string
&
name
)
const
;
private:
private:
// Call Scope::NewScope for a sub-scope.
// Call Scope::NewScope for a sub-scope.
explicit
Scope
(
Scope
const
*
parent
)
:
parent_
(
parent
)
{}
explicit
Scope
(
Scope
const
*
parent
)
:
parent_
(
parent
)
{}
mutable
std
::
unordered_map
<
std
::
string
,
Variable
*>
vars_
;
// Called by FindVar recursively.
// Caller doesn't own the returned Variable.
Variable
*
FindVarInternal
(
const
std
::
string
&
name
)
const
;
// Called by FindVarInternal and Var.
// Caller doesn't own the returned Variable.
Variable
*
FindVarLocally
(
const
std
::
string
&
name
)
const
;
mutable
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
Variable
>>
vars_
;
// Scope in `kids_` are owned by this class.
mutable
std
::
list
<
Scope
*>
kids_
;
mutable
std
::
list
<
Scope
*>
kids_
;
Scope
const
*
parent_
{
nullptr
};
Scope
const
*
parent_
{
nullptr
};
...
...
paddle/fluid/inference/analysis/helper.h
浏览文件 @
c73977af
...
@@ -18,6 +18,8 @@ limitations under the License. */
...
@@ -18,6 +18,8 @@ limitations under the License. */
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -107,6 +109,13 @@ class OrderedRegistry {
...
@@ -107,6 +109,13 @@ class OrderedRegistry {
std
::
vector
<
std
::
unique_ptr
<
T
>>
data_
;
std
::
vector
<
std
::
unique_ptr
<
T
>>
data_
;
};
};
template
<
typename
T
>
T
&
GetFromScope
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
name
)
{
framework
::
Variable
*
var
=
scope
.
FindVar
(
name
);
PADDLE_ENFORCE
(
var
!=
nullptr
);
return
*
var
->
GetMutable
<
T
>
();
}
}
// namespace analysis
}
// namespace analysis
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
c73977af
# Add TRT tests
# Add TRT tests
nv_test
(
test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine
)
nv_library
(
tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc
DEPS tensorrt_engine mul_op
)
nv_test
(
test_op_converter SRCS test_op_converter.cc DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine tensorrt_converter
)
nv_test
(
test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor
)
nv_test
(
test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor
)
nv_test
(
test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
nv_test
(
test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine mul_op SERIAL
)
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine mul_op SERIAL
)
...
...
paddle/fluid/inference/tensorrt/convert/activation_op.cc
浏览文件 @
c73977af
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -37,8 +38,8 @@ class ReluOpConverter : public OpConverter {
...
@@ -37,8 +38,8 @@ class ReluOpConverter : public OpConverter {
}
}
};
};
REGISTER_TRT_OP_CONVERTER
(
relu
,
ReluOpConverter
);
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
relu
,
ReluOpConverter
);
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
浏览文件 @
c73977af
...
@@ -22,14 +22,14 @@ class Conv2dOpConverter : public OpConverter {
...
@@ -22,14 +22,14 @@ class Conv2dOpConverter : public OpConverter {
public:
public:
Conv2dOpConverter
()
{}
Conv2dOpConverter
()
{}
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
)
override
{
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
LOG
(
INFO
)
LOG
(
INFO
)
<<
"convert a fluid conv2d op to tensorrt conv layer without bias"
;
<<
"convert a fluid conv2d op to tensorrt conv layer without bias"
;
}
}
};
};
REGISTER_TRT_OP_CONVERTER
(
conv2d
,
Conv2dOpConverter
);
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
conv2d
,
Conv2dOpConverter
);
paddle/fluid/inference/tensorrt/convert/fc_op.cc
浏览文件 @
c73977af
...
@@ -56,7 +56,7 @@ void ReorderCKtoKC(TensorRTEngine::Weight& iweights,
...
@@ -56,7 +56,7 @@ void ReorderCKtoKC(TensorRTEngine::Weight& iweights,
class
FcOpConverter
:
public
OpConverter
{
class
FcOpConverter
:
public
OpConverter
{
public:
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
)
override
{
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
4
)
<<
"convert a fluid fc op to tensorrt fc layer without bias"
;
VLOG
(
4
)
<<
"convert a fluid fc op to tensorrt fc layer without bias"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
...
@@ -106,14 +106,16 @@ class FcOpConverter : public OpConverter {
...
@@ -106,14 +106,16 @@ class FcOpConverter : public OpConverter {
n_output
,
weight
.
get
(),
bias
.
get
());
n_output
,
weight
.
get
(),
bias
.
get
());
auto
output_name
=
op_desc
.
Output
(
"Out"
).
front
();
auto
output_name
=
op_desc
.
Output
(
"Out"
).
front
();
engine_
->
DeclareOutput
(
layer
,
0
,
output_name
);
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
engine_
->
DeclareOutput
(
output_name
);
}
}
}
};
};
REGISTER_TRT_OP_CONVERTER
(
fc
,
FcOpConverter
);
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
fc
,
FcOpConverter
);
USE_OP
(
mul
);
USE_OP
(
mul
);
paddle/fluid/inference/tensorrt/convert/mul_op.cc
浏览文件 @
c73977af
...
@@ -23,9 +23,8 @@ namespace tensorrt {
...
@@ -23,9 +23,8 @@ namespace tensorrt {
*/
*/
class
MulOpConverter
:
public
OpConverter
{
class
MulOpConverter
:
public
OpConverter
{
public:
public:
MulOpConverter
()
{}
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
)
override
{
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
4
)
<<
"convert a fluid mul op to tensorrt mul layer without bias"
;
VLOG
(
4
)
<<
"convert a fluid mul op to tensorrt mul layer without bias"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
...
@@ -37,12 +36,18 @@ class MulOpConverter : public OpConverter {
...
@@ -37,12 +36,18 @@ class MulOpConverter : public OpConverter {
engine_
,
MatrixMultiply
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
false
,
engine_
,
MatrixMultiply
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
false
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input2
),
false
);
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input2
),
false
);
engine_
->
DeclareOutput
(
layer
,
0
,
op_desc
.
Output
(
"Out"
)[
0
]);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
// the test framework can not determine which is the
// output, so place the declaration inside.
engine_
->
DeclareOutput
(
output_name
);
}
}
}
};
};
REGISTER_TRT_OP_CONVERTER
(
mul
,
MulOpConverter
);
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
USE_OP
(
mul
);
REGISTER_TRT_OP_CONVERTER
(
mul
,
MulOpConverter
);
paddle/fluid/inference/tensorrt/convert/op_converter.h
浏览文件 @
c73977af
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/inference/utils/singleton.h"
...
@@ -34,12 +35,15 @@ class OpConverter {
...
@@ -34,12 +35,15 @@ class OpConverter {
// Converter logic for an op.
// Converter logic for an op.
virtual
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
virtual
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
)
{}
const
framework
::
Scope
&
scope
,
bool
test_mode
=
false
)
{}
// Convert a single fluid operaotr and add the corresponding layer to TRT.
// Convert a single fluid operator and add the corresponding layer to TRT.
// test_mode: whether the instance executes in an unit test.
void
ConvertOp
(
const
framework
::
proto
::
OpDesc
&
op
,
void
ConvertOp
(
const
framework
::
proto
::
OpDesc
&
op
,
const
std
::
unordered_set
<
std
::
string
>&
parameters
,
const
std
::
unordered_set
<
std
::
string
>&
parameters
,
const
framework
::
Scope
&
scope
,
TensorRTEngine
*
engine
)
{
const
framework
::
Scope
&
scope
,
TensorRTEngine
*
engine
,
bool
test_mode
=
false
)
{
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
OpConverter
*
it
{
nullptr
};
OpConverter
*
it
{
nullptr
};
...
@@ -57,7 +61,7 @@ class OpConverter {
...
@@ -57,7 +61,7 @@ class OpConverter {
PADDLE_ENFORCE_NOT_NULL
(
it
,
"no OpConverter for optype [%s]"
,
PADDLE_ENFORCE_NOT_NULL
(
it
,
"no OpConverter for optype [%s]"
,
op_desc
.
Type
());
op_desc
.
Type
());
it
->
SetEngine
(
engine
);
it
->
SetEngine
(
engine
);
(
*
it
)(
op
,
scope
);
(
*
it
)(
op
,
scope
,
test_mode
);
}
}
// convert fluid block to tensorrt network
// convert fluid block to tensorrt network
...
@@ -77,6 +81,9 @@ class OpConverter {
...
@@ -77,6 +81,9 @@ class OpConverter {
// TensorRT engine
// TensorRT engine
TensorRTEngine
*
engine_
{
nullptr
};
TensorRTEngine
*
engine_
{
nullptr
};
protected:
bool
test_mode_
;
private:
private:
// registered op converter map, whose key is the fluid op type, and value is
// registered op converter map, whose key is the fluid op type, and value is
// the pointer position of corresponding OpConverter class.
// the pointer position of corresponding OpConverter class.
...
@@ -86,12 +93,23 @@ class OpConverter {
...
@@ -86,12 +93,23 @@ class OpConverter {
};
};
#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \
#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \
struct trt_##op_type__##_converter
{
\
struct trt_##op_type__##_converter
: public ::paddle::framework::Registrar {
\
trt_##op_type__##_converter() { \
trt_##op_type__##_converter() { \
Registry<OpConverter>::Register<Converter__>(#op_type__); \
::paddle::inference:: \
Registry<paddle::inference::tensorrt::OpConverter>::Register< \
::paddle::inference::tensorrt::Converter__>(#op_type__); \
} \
} \
}; \
}; \
trt_##op_type__##_converter trt_##op_type__##_converter__;
trt_##op_type__##_converter trt_##op_type__##_converter__; \
int TouchConverterRegister_##op_type__() { \
trt_##op_type__##_converter__.Touch(); \
return 0; \
}
#define USE_TRT_CONVERTER(op_type__) \
extern int TouchConverterRegister_##op_type__(); \
static int use_op_converter_trt_##op_type__ __attribute__((unused)) = \
TouchConverterRegister_##op_type__();
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
...
...
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
浏览文件 @
c73977af
...
@@ -36,3 +36,5 @@ TEST(OpConverter, ConvertBlock) {
...
@@ -36,3 +36,5 @@ TEST(OpConverter, ConvertBlock) {
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
USE_TRT_CONVERTER
(
conv2d
)
paddle/fluid/inference/tensorrt/convert/ut_helper.h
浏览文件 @
c73977af
...
@@ -27,6 +27,7 @@ limitations under the License. */
...
@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -104,8 +105,8 @@ class TRTConvertValidation {
...
@@ -104,8 +105,8 @@ class TRTConvertValidation {
void
SetOp
(
const
framework
::
proto
::
OpDesc
&
desc
)
{
void
SetOp
(
const
framework
::
proto
::
OpDesc
&
desc
)
{
op_
=
framework
::
OpRegistry
::
CreateOp
(
desc
);
op_
=
framework
::
OpRegistry
::
CreateOp
(
desc
);
OpConverter
op_converter
;
Singleton
<
OpConverter
>::
Global
().
ConvertOp
(
op_converter
.
ConvertOp
(
desc
,
parameters_
,
scope_
,
engine_
.
get
()
);
desc
,
parameters_
,
scope_
,
engine_
.
get
(),
true
/*test_mode*/
);
engine_
->
FreezeNetwork
();
engine_
->
FreezeNetwork
();
...
...
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
c73977af
...
@@ -43,9 +43,10 @@ void TensorRTEngine::Execute(int batch_size) {
...
@@ -43,9 +43,10 @@ void TensorRTEngine::Execute(int batch_size) {
}
}
TensorRTEngine
::~
TensorRTEngine
()
{
TensorRTEngine
::~
TensorRTEngine
()
{
cudaStreamSynchronize
(
*
stream_
);
// clean buffer
// clean buffer
for
(
auto
&
buf
:
buffers_
)
{
for
(
auto
&
buf
:
buffers_
)
{
if
(
buf
.
buffer
!=
nullptr
)
{
if
(
buf
.
device
==
DeviceType
::
GPU
&&
buf
.
buffer
!=
nullptr
)
{
PADDLE_ENFORCE_EQ
(
0
,
cudaFree
(
buf
.
buffer
));
PADDLE_ENFORCE_EQ
(
0
,
cudaFree
(
buf
.
buffer
));
buf
.
buffer
=
nullptr
;
buf
.
buffer
=
nullptr
;
buf
.
max_size
=
0
;
buf
.
max_size
=
0
;
...
@@ -80,6 +81,8 @@ void TensorRTEngine::FreezeNetwork() {
...
@@ -80,6 +81,8 @@ void TensorRTEngine::FreezeNetwork() {
auto
&
buf
=
buffer
(
item
.
first
);
auto
&
buf
=
buffer
(
item
.
first
);
CHECK
(
buf
.
buffer
==
nullptr
);
// buffer should be allocated only once.
CHECK
(
buf
.
buffer
==
nullptr
);
// buffer should be allocated only once.
PADDLE_ENFORCE_EQ
(
0
,
cudaMalloc
(
&
buf
.
buffer
,
item
.
second
));
PADDLE_ENFORCE_EQ
(
0
,
cudaMalloc
(
&
buf
.
buffer
,
item
.
second
));
VLOG
(
4
)
<<
"buffer malloc "
<<
item
.
first
<<
" "
<<
item
.
second
<<
" "
<<
buf
.
buffer
;
buf
.
size
=
buf
.
max_size
=
item
.
second
;
buf
.
size
=
buf
.
max_size
=
item
.
second
;
buf
.
device
=
DeviceType
::
GPU
;
buf
.
device
=
DeviceType
::
GPU
;
}
}
...
@@ -96,6 +99,7 @@ nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
...
@@ -96,6 +99,7 @@ nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
PADDLE_ENFORCE
(
input
,
"infer network add input %s failed"
,
name
);
PADDLE_ENFORCE
(
input
,
"infer network add input %s failed"
,
name
);
buffer_sizes_
[
name
]
=
kDataTypeSize
[
static_cast
<
int
>
(
dtype
)]
*
buffer_sizes_
[
name
]
=
kDataTypeSize
[
static_cast
<
int
>
(
dtype
)]
*
analysis
::
AccuDims
(
dims
.
d
,
dims
.
nbDims
);
analysis
::
AccuDims
(
dims
.
d
,
dims
.
nbDims
);
PADDLE_ENFORCE
(
input
->
isNetworkInput
());
TensorRTEngine
::
SetITensor
(
name
,
input
);
TensorRTEngine
::
SetITensor
(
name
,
input
);
return
input
;
return
input
;
}
}
...
@@ -109,7 +113,9 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
...
@@ -109,7 +113,9 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
SetITensor
(
name
,
output
);
SetITensor
(
name
,
output
);
PADDLE_ENFORCE
(
output
!=
nullptr
);
PADDLE_ENFORCE
(
output
!=
nullptr
);
output
->
setName
(
name
.
c_str
());
output
->
setName
(
name
.
c_str
());
PADDLE_ENFORCE
(
!
output
->
isNetworkInput
());
infer_network_
->
markOutput
(
*
output
);
infer_network_
->
markOutput
(
*
output
);
PADDLE_ENFORCE
(
output
->
isNetworkOutput
());
// output buffers' size can only be decided latter, set zero here to mark this
// output buffers' size can only be decided latter, set zero here to mark this
// and will reset latter.
// and will reset latter.
buffer_sizes_
[
name
]
=
0
;
buffer_sizes_
[
name
]
=
0
;
...
@@ -122,6 +128,7 @@ void TensorRTEngine::DeclareOutput(const std::string& name) {
...
@@ -122,6 +128,7 @@ void TensorRTEngine::DeclareOutput(const std::string& name) {
auto
*
output
=
TensorRTEngine
::
GetITensor
(
name
);
auto
*
output
=
TensorRTEngine
::
GetITensor
(
name
);
PADDLE_ENFORCE
(
output
!=
nullptr
);
PADDLE_ENFORCE
(
output
!=
nullptr
);
output
->
setName
(
name
.
c_str
());
output
->
setName
(
name
.
c_str
());
PADDLE_ENFORCE
(
!
output
->
isNetworkInput
());
infer_network_
->
markOutput
(
*
output
);
infer_network_
->
markOutput
(
*
output
);
// output buffers' size can only be decided latter, set zero here to mark this
// output buffers' size can only be decided latter, set zero here to mark this
// and will reset latter.
// and will reset latter.
...
...
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
c73977af
...
@@ -21,6 +21,7 @@ limitations under the License. */
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -131,7 +132,11 @@ class TensorRTEngine : public EngineBase {
...
@@ -131,7 +132,11 @@ class TensorRTEngine : public EngineBase {
// TensorRT related internal members
// TensorRT related internal members
template
<
typename
T
>
template
<
typename
T
>
struct
Destroyer
{
struct
Destroyer
{
void
operator
()(
T
*
x
)
{
x
->
destroy
();
}
void
operator
()(
T
*
x
)
{
if
(
x
)
{
x
->
destroy
();
}
}
};
};
template
<
typename
T
>
template
<
typename
T
>
using
infer_ptr
=
std
::
unique_ptr
<
T
,
Destroyer
<
T
>>
;
using
infer_ptr
=
std
::
unique_ptr
<
T
,
Destroyer
<
T
>>
;
...
@@ -155,6 +160,27 @@ class TensorRTEngine : public EngineBase {
...
@@ -155,6 +160,27 @@ class TensorRTEngine : public EngineBase {
#define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
#define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
engine__->network()->add##layer__(ARGS);
engine__->network()->add##layer__(ARGS);
/*
* Helper to control the TensorRT engine's creation and deletion.
*/
class
TRT_EngineManager
{
public:
TensorRTEngine
*
Create
(
int
max_batch
,
int
max_workspace
,
cudaStream_t
*
stream
)
{
engines_
.
emplace_back
(
new
TensorRTEngine
(
max_batch
,
max_workspace
,
stream
));
return
engines_
.
back
().
get
();
}
void
DeleteALl
()
{
for
(
auto
&
ptr
:
engines_
)
{
ptr
.
reset
(
nullptr
);
}
}
private:
std
::
vector
<
std
::
unique_ptr
<
TensorRTEngine
>>
engines_
;
};
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
paddle/fluid/inference/tests/book/test_inference_nlp.cc
浏览文件 @
c73977af
...
@@ -101,23 +101,22 @@ void SplitData(
...
@@ -101,23 +101,22 @@ void SplitData(
}
}
void
ThreadRunInfer
(
void
ThreadRunInfer
(
const
int
tid
,
paddle
::
framework
::
Executor
*
executor
,
const
int
tid
,
paddle
::
framework
::
Scope
*
scope
,
paddle
::
framework
::
Scope
*
scope
,
const
std
::
unique_ptr
<
paddle
::
framework
::
ProgramDesc
>&
inference_program
,
const
std
::
vector
<
std
::
vector
<
const
paddle
::
framework
::
LoDTensor
*>>&
jobs
)
{
const
std
::
vector
<
std
::
vector
<
const
paddle
::
framework
::
LoDTensor
*>>&
jobs
)
{
auto
copy_program
=
std
::
unique_ptr
<
paddle
::
framework
::
ProgramDesc
>
(
// maybe framework:ProgramDesc is not thread-safe
new
paddle
::
framework
::
ProgramDesc
(
*
inference_program
));
auto
&
sub_scope
=
scope
->
NewScope
();
auto
&
sub_scope
=
scope
->
NewScope
();
auto
place
=
paddle
::
platform
::
CPUPlace
();
auto
executor
=
paddle
::
framework
::
Executor
(
place
);
auto
inference_program
=
paddle
::
inference
::
Load
(
&
executor
,
scope
,
FLAGS_model_path
);
std
::
string
feed_holder_name
=
"feed_"
+
paddle
::
string
::
to_string
(
tid
);
auto
ctx
=
executor
.
Prepare
(
*
inference_program
,
/*block_id*/
0
);
std
::
string
fetch_holder_name
=
"fetch_"
+
paddle
::
string
::
to_string
(
tid
);
executor
.
CreateVariables
(
*
inference_program
,
&
sub_scope
,
/*block_id*/
0
);
copy_program
->
SetFeedHolderName
(
feed_holder_name
);
copy_program
->
SetFetchHolderName
(
fetch_holder_name
);
const
std
::
vector
<
std
::
string
>&
feed_target_names
=
const
std
::
vector
<
std
::
string
>&
feed_target_names
=
copy
_program
->
GetFeedTargetNames
();
inference
_program
->
GetFeedTargetNames
();
const
std
::
vector
<
std
::
string
>&
fetch_target_names
=
const
std
::
vector
<
std
::
string
>&
fetch_target_names
=
copy
_program
->
GetFetchTargetNames
();
inference
_program
->
GetFetchTargetNames
();
PADDLE_ENFORCE_EQ
(
fetch_target_names
.
size
(),
1UL
);
PADDLE_ENFORCE_EQ
(
fetch_target_names
.
size
(),
1UL
);
std
::
map
<
std
::
string
,
paddle
::
framework
::
LoDTensor
*>
fetch_targets
;
std
::
map
<
std
::
string
,
paddle
::
framework
::
LoDTensor
*>
fetch_targets
;
...
@@ -131,9 +130,8 @@ void ThreadRunInfer(
...
@@ -131,9 +130,8 @@ void ThreadRunInfer(
auto
start_ms
=
GetCurrentMs
();
auto
start_ms
=
GetCurrentMs
();
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
feed_targets
[
feed_target_names
[
0
]]
=
inputs
[
i
];
feed_targets
[
feed_target_names
[
0
]]
=
inputs
[
i
];
executor
->
Run
(
*
copy_program
,
&
sub_scope
,
&
feed_targets
,
&
fetch_targets
,
executor
.
RunPreparedContext
(
ctx
.
get
(),
&
sub_scope
,
&
feed_targets
,
true
/*create_local_scope*/
,
true
/*create_vars*/
,
&
fetch_targets
,
false
/*create_local_scope*/
);
feed_holder_name
,
fetch_holder_name
);
}
}
auto
stop_ms
=
GetCurrentMs
();
auto
stop_ms
=
GetCurrentMs
();
scope
->
DeleteScope
(
&
sub_scope
);
scope
->
DeleteScope
(
&
sub_scope
);
...
@@ -158,22 +156,10 @@ TEST(inference, nlp) {
...
@@ -158,22 +156,10 @@ TEST(inference, nlp) {
LOG
(
INFO
)
<<
"Number of samples (seq_len<1024): "
<<
datasets
.
size
();
LOG
(
INFO
)
<<
"Number of samples (seq_len<1024): "
<<
datasets
.
size
();
LOG
(
INFO
)
<<
"Total number of words: "
<<
num_total_words
;
LOG
(
INFO
)
<<
"Total number of words: "
<<
num_total_words
;
const
bool
model_combined
=
false
;
// 0. Call `paddle::framework::InitDevices()` initialize all the devices
// 0. Call `paddle::framework::InitDevices()` initialize all the devices
// 1. Define place, executor, scope
auto
place
=
paddle
::
platform
::
CPUPlace
();
auto
executor
=
paddle
::
framework
::
Executor
(
place
);
std
::
unique_ptr
<
paddle
::
framework
::
Scope
>
scope
(
std
::
unique_ptr
<
paddle
::
framework
::
Scope
>
scope
(
new
paddle
::
framework
::
Scope
());
new
paddle
::
framework
::
Scope
());
// 2. Initialize the inference_program and load parameters
std
::
unique_ptr
<
paddle
::
framework
::
ProgramDesc
>
inference_program
;
inference_program
=
InitProgram
(
&
executor
,
scope
.
get
(),
FLAGS_model_path
,
model_combined
);
if
(
FLAGS_use_mkldnn
)
{
EnableMKLDNN
(
inference_program
);
}
#ifdef PADDLE_WITH_MKLML
#ifdef PADDLE_WITH_MKLML
// only use 1 thread number per std::thread
// only use 1 thread number per std::thread
omp_set_dynamic
(
0
);
omp_set_dynamic
(
0
);
...
@@ -189,21 +175,30 @@ TEST(inference, nlp) {
...
@@ -189,21 +175,30 @@ TEST(inference, nlp) {
start_ms
=
GetCurrentMs
();
start_ms
=
GetCurrentMs
();
for
(
int
i
=
0
;
i
<
FLAGS_num_threads
;
++
i
)
{
for
(
int
i
=
0
;
i
<
FLAGS_num_threads
;
++
i
)
{
threads
.
emplace_back
(
threads
.
emplace_back
(
new
std
::
thread
(
ThreadRunInfer
,
i
,
&
executor
,
scope
.
get
(),
new
std
::
thread
(
ThreadRunInfer
,
i
,
scope
.
get
(),
std
::
ref
(
jobs
)));
std
::
ref
(
inference_program
),
std
::
ref
(
jobs
)));
}
}
for
(
int
i
=
0
;
i
<
FLAGS_num_threads
;
++
i
)
{
for
(
int
i
=
0
;
i
<
FLAGS_num_threads
;
++
i
)
{
threads
[
i
]
->
join
();
threads
[
i
]
->
join
();
}
}
stop_ms
=
GetCurrentMs
();
stop_ms
=
GetCurrentMs
();
}
else
{
}
else
{
if
(
FLAGS_prepare_vars
)
{
// 1. Define place, executor, scope
executor
.
CreateVariables
(
*
inference_program
,
scope
.
get
(),
0
);
auto
place
=
paddle
::
platform
::
CPUPlace
();
auto
executor
=
paddle
::
framework
::
Executor
(
place
);
// 2. Initialize the inference_program and load parameters
std
::
unique_ptr
<
paddle
::
framework
::
ProgramDesc
>
inference_program
;
inference_program
=
InitProgram
(
&
executor
,
scope
.
get
(),
FLAGS_model_path
,
/*model combined*/
false
);
if
(
FLAGS_use_mkldnn
)
{
EnableMKLDNN
(
inference_program
);
}
}
// always prepare context
// always prepare context
std
::
unique_ptr
<
paddle
::
framework
::
ExecutorPrepareContext
>
ctx
;
std
::
unique_ptr
<
paddle
::
framework
::
ExecutorPrepareContext
>
ctx
;
ctx
=
executor
.
Prepare
(
*
inference_program
,
0
);
ctx
=
executor
.
Prepare
(
*
inference_program
,
0
);
if
(
FLAGS_prepare_vars
)
{
executor
.
CreateVariables
(
*
inference_program
,
scope
.
get
(),
0
);
}
// preapre fetch
// preapre fetch
const
std
::
vector
<
std
::
string
>&
fetch_target_names
=
const
std
::
vector
<
std
::
string
>&
fetch_target_names
=
inference_program
->
GetFetchTargetNames
();
inference_program
->
GetFetchTargetNames
();
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
c73977af
...
@@ -227,6 +227,8 @@ op_library(softmax_op DEPS softmax)
...
@@ -227,6 +227,8 @@ op_library(softmax_op DEPS softmax)
op_library
(
sequence_softmax_op DEPS softmax
)
op_library
(
sequence_softmax_op DEPS softmax
)
if
(
WITH_GPU AND TENSORRT_FOUND
)
if
(
WITH_GPU AND TENSORRT_FOUND
)
op_library
(
tensorrt_engine_op DEPS tensorrt_engine
)
op_library
(
tensorrt_engine_op DEPS tensorrt_engine
)
nv_test
(
test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc
DEPS tensorrt_engine_op tensorrt_engine tensorrt_converter
)
else
()
else
()
set
(
DEPS_OPS
${
DEPS_OPS
}
tensorrt_engine_op
)
set
(
DEPS_OPS
${
DEPS_OPS
}
tensorrt_engine_op
)
endif
()
endif
()
...
...
paddle/fluid/operators/tensorrt_engine_op.cc
浏览文件 @
c73977af
...
@@ -17,23 +17,93 @@
...
@@ -17,23 +17,93 @@
#include "paddle/fluid/operators/tensorrt_engine_op.h"
#include "paddle/fluid/operators/tensorrt_engine_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
inference
::
Singleton
;
using
inference
::
tensorrt
::
TRT_EngineManager
;
using
FluidDT
=
framework
::
proto
::
VarType_Type
;
using
TRT_DT
=
nvinfer1
::
DataType
;
namespace
{
TRT_DT
FluidDataType2TRT
(
FluidDT
type
)
{
switch
(
type
)
{
case
FluidDT
::
VarType_Type_FP32
:
return
TRT_DT
::
kFLOAT
;
case
FluidDT
::
VarType_Type_INT32
:
return
TRT_DT
::
kINT32
;
default:
return
TRT_DT
::
kINT32
;
}
PADDLE_THROW
(
"unkown type"
);
return
TRT_DT
::
kINT32
;
}
nvinfer1
::
Dims
Vec2TRT_Dims
(
const
std
::
vector
<
int64_t
>
&
shape
)
{
PADDLE_ENFORCE_GT
(
shape
.
size
(),
1UL
,
"TensorRT' tensor input requires at least 2 dimensions"
);
PADDLE_ENFORCE_LE
(
shape
.
size
(),
4UL
,
"TensorRT' tensor input requires at most 4 dimensions"
);
switch
(
shape
.
size
())
{
case
2
:
return
nvinfer1
::
Dims2
(
shape
[
0
],
shape
[
1
]);
case
3
:
return
nvinfer1
::
Dims3
(
shape
[
0
],
shape
[
1
],
shape
[
2
]);
case
4
:
return
nvinfer1
::
Dims4
(
shape
[
0
],
shape
[
1
],
shape
[
2
],
shape
[
3
]);
default:
return
nvinfer1
::
Dims
();
}
return
nvinfer1
::
Dims
();
}
}
// namespace
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
void
paddle
::
operators
::
TensorRTEngineKernel
<
DeviceContext
,
T
>::
Prepare
(
void
paddle
::
operators
::
TensorRTEngineKernel
<
DeviceContext
,
T
>::
Prepare
(
const
framework
::
ExecutionContext
&
context
)
const
{
const
framework
::
ExecutionContext
&
context
)
const
{
VLOG
(
4
)
<<
"Prepare engine"
;
// Get the ProgramDesc and pass to convert.
// Get the ProgramDesc and pass to convert.
const
auto
&
block
=
context
.
Attr
<
framework
::
proto
::
BlockDesc
>
(
"subgraph"
);
framework
::
proto
::
BlockDesc
block_desc
;
block_desc
.
ParseFromString
(
context
.
Attr
<
std
::
string
>
(
"subgraph"
));
max_batch_
=
context
.
Attr
<
int
>
(
"max_batch"
);
max_batch_
=
context
.
Attr
<
int
>
(
"max_batch"
);
auto
max_workspace
=
context
.
Attr
<
int
>
(
"max_workspace"
);
auto
max_workspace
=
context
.
Attr
<
int
>
(
"max_workspace"
);
engine_
.
reset
(
new
inference
::
tensorrt
::
TensorRTEngine
(
engine_
=
Singleton
<
TRT_EngineManager
>::
Global
().
Create
(
max_batch_
,
max_workspace
,
nullptr
));
max_batch_
,
max_workspace
,
&
stream_
);
engine_
->
InitNetwork
();
framework
::
BlockDesc
block
(
nullptr
/*programdesc*/
,
&
block_desc
);
// Add inputs
VLOG
(
4
)
<<
"declare inputs"
;
for
(
auto
&
input
:
context
.
Inputs
(
"Xs"
))
{
VLOG
(
4
)
<<
"declare input "
<<
input
;
auto
*
var
=
block
.
FindVar
(
input
);
PADDLE_ENFORCE_EQ
(
var
->
GetType
(),
FluidDT
::
VarType_Type_LOD_TENSOR
,
"TensorRT engine only takes LoDTensor as input"
);
auto
shape
=
var
->
GetShape
();
engine_
->
DeclareInput
(
input
,
FluidDataType2TRT
(
var
->
Proto
()
->
type
().
lod_tensor
().
tensor
().
data_type
()),
Vec2TRT_Dims
(
var
->
GetShape
()));
}
// TODO(Superjomn) parameters should be passed after analysised from outside.
// TODO(Superjomn) parameters should be passed after analysised from outside.
inference
::
Singleton
<
inference
::
tensorrt
::
OpConverter
>::
Global
().
ConvertBlock
(
inference
::
Singleton
<
inference
::
tensorrt
::
OpConverter
>::
Global
().
ConvertBlock
(
block
,
{},
context
.
scope
(),
engine_
.
get
());
block_desc
,
{},
context
.
scope
(),
engine_
);
// Add outputs
VLOG
(
4
)
<<
"declare outputs"
;
for
(
auto
&
output
:
context
.
Outputs
(
"Ys"
))
{
VLOG
(
4
)
<<
"declare output "
<<
output
;
engine_
->
DeclareOutput
(
output
);
}
engine_
->
FreezeNetwork
();
engine_
->
FreezeNetwork
();
}
}
...
@@ -42,7 +112,9 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -42,7 +112,9 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
void
Make
()
override
{
void
Make
()
override
{
AddInput
(
"Xs"
,
"A list of inputs."
).
AsDuplicable
();
AddInput
(
"Xs"
,
"A list of inputs."
).
AsDuplicable
();
AddOutput
(
"Ys"
,
"A list of outputs"
).
AsDuplicable
();
AddOutput
(
"Ys"
,
"A list of outputs"
).
AsDuplicable
();
AddAttr
<
std
::
string
>
(
"subgraph"
,
"the subgraph"
);
AddAttr
<
std
::
string
>
(
"subgraph"
,
"the subgraph."
);
AddAttr
<
int
>
(
"max_batch"
,
"the maximum batch size."
);
AddAttr
<
int
>
(
"max_workspace"
,
"the maximum batch size."
);
AddComment
(
"TensorRT engine operator."
);
AddComment
(
"TensorRT engine operator."
);
}
}
};
};
...
...
paddle/fluid/operators/tensorrt_engine_op.h
浏览文件 @
c73977af
...
@@ -32,9 +32,12 @@ class TensorRTEngineOp : public framework::OperatorWithKernel {
...
@@ -32,9 +32,12 @@ class TensorRTEngineOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input0
=
ctx
.
Inputs
(
"Xs"
).
front
();
framework
::
OpKernelType
kt
=
framework
::
OpKernelType
(
framework
::
OpKernelType
kt
=
framework
::
OpKernelType
(
framework
::
ToDataType
(
framework
::
ToDataType
(
ctx
.
scope
()
ctx
.
Input
<
framework
::
LoDTensor
>
(
"pre_ids"
)
->
type
()),
.
FindVar
(
input0
)
->
GetMutable
<
framework
::
LoDTensor
>
()
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
return
kt
;
return
kt
;
}
}
...
@@ -50,17 +53,16 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
...
@@ -50,17 +53,16 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
auto
input_names
=
context
.
op
().
Inputs
(
"Xs"
);
auto
input_names
=
context
.
op
().
Inputs
(
"Xs"
);
PADDLE_ENFORCE
(
!
input_names
.
empty
(),
"should pass more than one inputs"
);
PADDLE_ENFORCE
(
!
input_names
.
empty
(),
"should pass more than one inputs"
);
// Try to determine a batch_size
// Try to determine a batch_size
auto
*
tensor0
=
context
.
Input
<
framework
::
LoDTensor
>
(
input_names
.
front
());
auto
&
tensor0
=
inference
::
analysis
::
GetFromScope
<
framework
::
LoDTensor
>
(
PADDLE_ENFORCE_NOT_NULL
(
tensor0
);
context
.
scope
(),
input_names
.
front
()
);
int
batch_size
=
tensor0
->
dims
()[
0
];
int
batch_size
=
tensor0
.
dims
()[
0
];
PADDLE_ENFORCE_LE
(
batch_size
,
max_batch_
);
PADDLE_ENFORCE_LE
(
batch_size
,
max_batch_
);
// Convert input tensor from fluid to engine.
// Convert input tensor from fluid to engine.
for
(
const
auto
&
x
:
context
.
Inputs
(
"Xs"
))
{
for
(
const
auto
&
x
:
context
.
Inputs
(
"Xs"
))
{
// convert input and copy to TRT engine's buffer
// convert input and copy to TRT engine's buffer
auto
*
v
=
context
.
scope
().
FindVar
(
x
);
auto
&
t
=
inference
::
analysis
::
GetFromScope
<
framework
::
LoDTensor
>
(
PADDLE_ENFORCE_NOT_NULL
(
v
,
"no variable called %s"
,
x
);
context
.
scope
(),
x
);
auto
&
t
=
v
->
Get
<
framework
::
LoDTensor
>
();
if
(
platform
::
is_cpu_place
(
t
.
place
()))
{
if
(
platform
::
is_cpu_place
(
t
.
place
()))
{
engine_
->
SetInputFromCPU
(
x
,
static_cast
<
const
void
*>
(
t
.
data
<
void
>
()),
engine_
->
SetInputFromCPU
(
x
,
static_cast
<
const
void
*>
(
t
.
data
<
void
>
()),
t
.
memory_size
());
t
.
memory_size
());
...
@@ -86,13 +88,18 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
...
@@ -86,13 +88,18 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
fluid_t
->
Resize
(
framework
::
make_ddim
(
ddim
));
fluid_t
->
Resize
(
framework
::
make_ddim
(
ddim
));
auto
size
=
inference
::
analysis
::
AccuDims
(
dims
.
d
,
dims
.
nbDims
);
auto
size
=
inference
::
analysis
::
AccuDims
(
dims
.
d
,
dims
.
nbDims
);
if
(
platform
::
is_cpu_place
(
fluid_t
->
place
()))
{
if
(
platform
::
is_cpu_place
(
fluid_t
->
place
()))
{
// TODO(Superjomn) change this float to dtype size.
engine_
->
GetOutputInCPU
(
engine_
->
GetOutputInCPU
(
y
,
fluid_t
->
mutable_data
<
float
>
(
platform
::
CPUPlace
()),
size
);
y
,
fluid_t
->
mutable_data
<
float
>
(
platform
::
CPUPlace
()),
size
*
sizeof
(
float
));
}
else
{
}
else
{
engine_
->
GetOutputInGPU
(
engine_
->
GetOutputInGPU
(
y
,
fluid_t
->
mutable_data
<
float
>
(
platform
::
CUDAPlace
()),
size
);
y
,
fluid_t
->
mutable_data
<
float
>
(
platform
::
CUDAPlace
()),
size
*
sizeof
(
float
));
}
}
}
}
cudaStreamSynchronize
(
stream_
);
}
}
protected:
protected:
...
@@ -100,7 +107,8 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
...
@@ -100,7 +107,8 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
void
Prepare
(
const
framework
::
ExecutionContext
&
context
)
const
;
void
Prepare
(
const
framework
::
ExecutionContext
&
context
)
const
;
private:
private:
mutable
std
::
unique_ptr
<
inference
::
tensorrt
::
TensorRTEngine
>
engine_
;
mutable
cudaStream_t
stream_
;
mutable
inference
::
tensorrt
::
TensorRTEngine
*
engine_
{
nullptr
};
mutable
int
max_batch_
{
0
};
mutable
int
max_batch_
{
0
};
};
};
...
...
paddle/fluid/operators/tensorrt_engine_op_test.cc
0 → 100644
浏览文件 @
c73977af
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
USE_CPU_ONLY_OP
(
tensorrt_engine
);
namespace
paddle
{
namespace
operators
{
namespace
{
void
CreateCPUTensor
(
framework
::
Scope
*
scope
,
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
shape
)
{
auto
*
var
=
scope
->
Var
(
name
);
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
dims
=
framework
::
make_ddim
(
shape
);
tensor
->
Resize
(
dims
);
platform
::
CPUPlace
place
;
platform
::
CPUDeviceContext
ctx
(
place
);
inference
::
tensorrt
::
RandomizeTensor
(
tensor
,
place
,
ctx
);
}
void
AddTensorToBlockDesc
(
framework
::
proto
::
BlockDesc
*
block
,
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
shape
)
{
using
framework
::
proto
::
VarType
;
auto
*
var
=
block
->
add_vars
();
framework
::
VarDesc
desc
(
name
);
desc
.
SetType
(
VarType
::
LOD_TENSOR
);
desc
.
SetDataType
(
VarType
::
FP32
);
desc
.
SetShape
(
shape
);
*
var
=
*
desc
.
Proto
();
}
template
<
typename
T
>
void
SetAttr
(
framework
::
proto
::
OpDesc
*
op
,
const
std
::
string
&
name
,
const
T
&
data
);
template
<
>
void
SetAttr
<
std
::
string
>
(
framework
::
proto
::
OpDesc
*
op
,
const
std
::
string
&
name
,
const
std
::
string
&
data
)
{
auto
*
attr
=
op
->
add_attrs
();
attr
->
set_name
(
name
);
attr
->
set_type
(
paddle
::
framework
::
proto
::
AttrType
::
STRING
);
attr
->
set_s
(
data
);
}
template
<
>
void
SetAttr
<
int
>
(
framework
::
proto
::
OpDesc
*
op
,
const
std
::
string
&
name
,
const
int
&
data
)
{
auto
*
attr
=
op
->
add_attrs
();
attr
->
set_name
(
name
);
attr
->
set_type
(
paddle
::
framework
::
proto
::
AttrType
::
INT
);
attr
->
set_i
(
data
);
}
template
<
>
void
SetAttr
<
int64_t
>
(
framework
::
proto
::
OpDesc
*
op
,
const
std
::
string
&
name
,
const
int64_t
&
data
)
{
auto
*
attr
=
op
->
add_attrs
();
attr
->
set_name
(
name
);
attr
->
set_type
(
paddle
::
framework
::
proto
::
AttrType
::
LONG
);
attr
->
set_l
(
data
);
}
}
// namespace
TEST
(
TensorRTEngineOp
,
manual
)
{
framework
::
ProgramDesc
program
;
auto
*
block_
=
program
.
Proto
()
->
add_blocks
();
block_
->
set_idx
(
0
);
block_
->
set_parent_idx
(
-
1
);
LOG
(
INFO
)
<<
"create block desc"
;
framework
::
BlockDesc
block_desc
(
&
program
,
block_
);
LOG
(
INFO
)
<<
"create mul op"
;
auto
*
mul
=
block_desc
.
AppendOp
();
mul
->
SetType
(
"mul"
);
mul
->
SetInput
(
"X"
,
std
::
vector
<
std
::
string
>
({
"x"
}));
// 2 x 4
mul
->
SetInput
(
"Y"
,
std
::
vector
<
std
::
string
>
({
"y"
}));
// 4 x 6
mul
->
SetOutput
(
"Out"
,
std
::
vector
<
std
::
string
>
({
"z"
}));
// 2 x 6
LOG
(
INFO
)
<<
"create fc op"
;
auto
*
fc
=
block_desc
.
AppendOp
();
fc
->
SetType
(
"mul"
);
fc
->
SetInput
(
"X"
,
std
::
vector
<
std
::
string
>
({
"z"
}));
fc
->
SetInput
(
"Y"
,
std
::
vector
<
std
::
string
>
({
"y0"
}));
// 6 x 8
fc
->
SetOutput
(
"Out"
,
std
::
vector
<
std
::
string
>
({
"z0"
}));
// 2 x 8
// Set inputs' variable shape in BlockDesc
AddTensorToBlockDesc
(
block_
,
"x"
,
std
::
vector
<
int64_t
>
({
2
,
4
}));
AddTensorToBlockDesc
(
block_
,
"y"
,
std
::
vector
<
int64_t
>
({
4
,
6
}));
AddTensorToBlockDesc
(
block_
,
"y0"
,
std
::
vector
<
int64_t
>
({
6
,
8
}));
AddTensorToBlockDesc
(
block_
,
"z"
,
std
::
vector
<
int64_t
>
({
2
,
6
}));
// It is wired, need to copy manually.
*
block_
->
add_ops
()
=
*
mul
->
Proto
();
*
block_
->
add_ops
()
=
*
fc
->
Proto
();
ASSERT_EQ
(
block_
->
ops_size
(),
2
);
LOG
(
INFO
)
<<
"create tensorrt desc"
;
framework
::
OpDesc
engine_op_desc
(
nullptr
);
engine_op_desc
.
SetType
(
"tensorrt_engine"
);
engine_op_desc
.
SetInput
(
"Xs"
,
std
::
vector
<
std
::
string
>
({
"x"
,
"y"
,
"y0"
}));
engine_op_desc
.
SetOutput
(
"Ys"
,
std
::
vector
<
std
::
string
>
({
"z0"
}));
SetAttr
<
std
::
string
>
(
engine_op_desc
.
Proto
(),
"subgraph"
,
block_
->
SerializeAsString
());
SetAttr
<
int
>
(
engine_op_desc
.
Proto
(),
"max_batch"
,
30
);
SetAttr
<
int
>
(
engine_op_desc
.
Proto
(),
"max_workspace"
,
1
<<
10
);
LOG
(
INFO
)
<<
"create engine op"
;
auto
engine_op
=
framework
::
OpRegistry
::
CreateOp
(
*
engine_op_desc
.
Proto
());
framework
::
Scope
scope
;
platform
::
CPUPlace
place
;
platform
::
CPUDeviceContext
ctx
(
place
);
// Prepare variables.
CreateCPUTensor
(
&
scope
,
"x"
,
std
::
vector
<
int64_t
>
({
2
,
4
}));
CreateCPUTensor
(
&
scope
,
"y"
,
std
::
vector
<
int64_t
>
({
4
,
6
}));
CreateCPUTensor
(
&
scope
,
"z"
,
std
::
vector
<
int64_t
>
({
2
,
6
}));
CreateCPUTensor
(
&
scope
,
"y0"
,
std
::
vector
<
int64_t
>
({
6
,
8
}));
CreateCPUTensor
(
&
scope
,
"z0"
,
std
::
vector
<
int64_t
>
({
2
,
8
}));
// Execute them.
LOG
(
INFO
)
<<
"engine_op run"
;
engine_op
->
Run
(
scope
,
place
);
}
}
// namespace operators
}
// namespace paddle
USE_TRT_CONVERTER
(
mul
)
USE_TRT_CONVERTER
(
fc
)
paddle/fluid/platform/assert.h
浏览文件 @
c73977af
...
@@ -17,7 +17,7 @@ limitations under the License. */
...
@@ -17,7 +17,7 @@ limitations under the License. */
#define STRINGIFY(x) #x
#define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x)
#define TOSTRING(x) STRINGIFY(x)
#if defined(__
APPLE__) && defined(__CUDA_ARCH__) && !defined(NDEBUG
)
#if defined(__
CUDA_ARCH__
)
#include <stdio.h>
#include <stdio.h>
#define PADDLE_ASSERT(e) \
#define PADDLE_ASSERT(e) \
do { \
do { \
...
@@ -38,6 +38,9 @@ limitations under the License. */
...
@@ -38,6 +38,9 @@ limitations under the License. */
} while (0)
} while (0)
#else
#else
#include <assert.h>
#include <assert.h>
#define PADDLE_ASSERT(e) assert(e)
// For cuda, the assertions can affect performance and it is therefore
// recommended to disable them in production code
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion
#define PADDLE_ASSERT(e) assert((e))
#define PADDLE_ASSERT_MSG(e, m) assert((e) && (m))
#define PADDLE_ASSERT_MSG(e, m) assert((e) && (m))
#endif
#endif
paddle/fluid/platform/cudnn_helper.h
浏览文件 @
c73977af
...
@@ -81,6 +81,27 @@ enum class PoolingMode {
...
@@ -81,6 +81,27 @@ enum class PoolingMode {
kMaximumDeterministic
,
kMaximumDeterministic
,
};
};
#if CUDNN_VERSION < 6000
#pragma message "CUDNN version under 6.0 is supported at best effort."
#pragma message "We strongly encourage you to move to 6.0 and above."
#pragma message "This message is intended to annoy you enough to update."
#pragma message \
"please see https://docs.nvidia.com/deeplearning/sdk/cudnn-release-notes/"
inline
cudnnPoolingMode_t
GetPoolingMode
(
const
PoolingMode
&
mode
)
{
switch
(
mode
)
{
case
PoolingMode
::
kMaximumDeterministic
:
return
CUDNN_POOLING_MAX
;
case
PoolingMode
::
kAverage
:
return
CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING
;
case
PoolingMode
::
kMaximum
:
return
CUDNN_POOLING_MAX
;
default:
PADDLE_THROW
(
"Unexpected pooling mode."
);
}
}
#else
inline
cudnnPoolingMode_t
GetPoolingMode
(
const
PoolingMode
&
mode
)
{
inline
cudnnPoolingMode_t
GetPoolingMode
(
const
PoolingMode
&
mode
)
{
switch
(
mode
)
{
switch
(
mode
)
{
case
PoolingMode
::
kMaximumDeterministic
:
case
PoolingMode
::
kMaximumDeterministic
:
...
@@ -93,6 +114,7 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
...
@@ -93,6 +114,7 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
PADDLE_THROW
(
"Unexpected pooling mode."
);
PADDLE_THROW
(
"Unexpected pooling mode."
);
}
}
}
}
#endif // CUDNN_VERSION < 6000
template
<
typename
T
>
template
<
typename
T
>
class
CudnnDataType
;
class
CudnnDataType
;
...
...
paddle/scripts/paddle_build.sh
浏览文件 @
c73977af
...
@@ -447,7 +447,7 @@ EOF
...
@@ -447,7 +447,7 @@ EOF
# run paddle version to install python packages first
# run paddle version to install python packages first
RUN apt-get update &&
\
RUN apt-get update &&
\
${
NCCL_DEPS
}
\
${
NCCL_DEPS
}
\
apt-get install -y wget python-pip dmidecode python-tk && easy_install -U pip &&
\
apt-get install -y wget python-pip
python-opencv
dmidecode python-tk && easy_install -U pip &&
\
pip install /*.whl; apt-get install -f -y &&
\
pip install /*.whl; apt-get install -f -y &&
\
apt-get clean -y &&
\
apt-get clean -y &&
\
rm -f /*.whl &&
\
rm -f /*.whl &&
\
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录