Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
8c171902
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8c171902
编写于
2月 14, 2019
作者:
N
nhzlx
提交者:
ceci3
3月 08, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
2. TRTEngine using stream only when execute.
上级
88c24baa
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
31 addition
and
73 deletion
+31
-73
paddle/fluid/inference/tensorrt/convert/ut_helper.h
paddle/fluid/inference/tensorrt/convert/ut_helper.h
+2
-4
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+4
-29
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+8
-13
paddle/fluid/inference/tensorrt/test_engine.cc
paddle/fluid/inference/tensorrt/test_engine.cc
+5
-5
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+12
-22
未找到文件。
paddle/fluid/inference/tensorrt/convert/ut_helper.h
浏览文件 @
8c171902
...
@@ -79,7 +79,7 @@ class TRTConvertValidation {
...
@@ -79,7 +79,7 @@ class TRTConvertValidation {
if_add_batch_
(
if_add_batch
),
if_add_batch_
(
if_add_batch
),
max_batch_size_
(
max_batch_size
)
{
max_batch_size_
(
max_batch_size
)
{
PADDLE_ENFORCE_EQ
(
cudaStreamCreate
(
&
stream_
),
0
);
PADDLE_ENFORCE_EQ
(
cudaStreamCreate
(
&
stream_
),
0
);
engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size
,
workspace_size
,
stream_
));
engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size
,
workspace_size
));
engine_
->
InitNetwork
();
engine_
->
InitNetwork
();
}
}
...
@@ -192,9 +192,7 @@ class TRTConvertValidation {
...
@@ -192,9 +192,7 @@ class TRTConvertValidation {
}
}
// Execute TRT.
// Execute TRT.
engine_
->
Execute
(
batch_size
,
buffers
);
engine_
->
Execute
(
batch_size
,
&
buffers
,
stream_
);
cudaStreamSynchronize
(
engine_
->
stream
());
ASSERT_FALSE
(
op_desc_
->
OutputArgumentNames
().
empty
());
ASSERT_FALSE
(
op_desc_
->
OutputArgumentNames
().
empty
());
int
index
=
0
;
int
index
=
0
;
...
...
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
8c171902
...
@@ -32,39 +32,14 @@ void TensorRTEngine::Build(const DescType &paddle_model) {
...
@@ -32,39 +32,14 @@ void TensorRTEngine::Build(const DescType &paddle_model) {
PADDLE_ENFORCE
(
false
,
"not implemented"
);
PADDLE_ENFORCE
(
false
,
"not implemented"
);
}
}
void
TensorRTEngine
::
Execute
(
int
batch_size
,
std
::
vector
<
void
*>
&
buffers
)
{
void
TensorRTEngine
::
Execute
(
int
batch_size
,
std
::
vector
<
void
*>
*
buffers
,
cudaStream_t
stream
)
{
batch_size_
=
batch_size
;
batch_size_
=
batch_size
;
infer_context_
->
enqueue
(
batch_size
,
buffers
.
data
(),
stream_
,
nullptr
);
infer_context_
->
enqueue
(
batch_size
,
buffers
->
data
(),
stream
,
nullptr
);
cudaStreamSynchronize
(
stream
_
);
cudaStreamSynchronize
(
stream
);
SetRuntimeBatch
(
batch_size
);
SetRuntimeBatch
(
batch_size
);
}
}
void
TensorRTEngine
::
Execute
(
int
batch_size
)
{
batch_size_
=
batch_size
;
std
::
vector
<
void
*>
buffers
;
for
(
auto
&
buf
:
buffers_
)
{
PADDLE_ENFORCE_NOT_NULL
(
buf
.
buffer
,
"buffer should be allocated"
);
PADDLE_ENFORCE_GT
(
buf
.
max_size
,
0
);
PADDLE_ENFORCE
(
buf
.
device
==
DeviceType
::
GPU
);
buffers
.
push_back
(
buf
.
buffer
);
}
infer_context_
->
enqueue
(
batch_size
,
buffers
.
data
(),
stream_
,
nullptr
);
cudaStreamSynchronize
(
stream_
);
SetRuntimeBatch
(
batch_size
);
}
TensorRTEngine
::~
TensorRTEngine
()
{
cudaStreamSynchronize
(
stream_
);
// clean buffer
for
(
auto
&
buf
:
buffers_
)
{
if
(
buf
.
device
==
DeviceType
::
GPU
&&
buf
.
buffer
!=
nullptr
)
{
PADDLE_ENFORCE_EQ
(
0
,
cudaFree
(
buf
.
buffer
));
buf
.
buffer
=
nullptr
;
buf
.
max_size
=
0
;
}
}
}
void
TensorRTEngine
::
FreezeNetwork
()
{
void
TensorRTEngine
::
FreezeNetwork
()
{
VLOG
(
3
)
<<
"TRT to freeze network"
;
VLOG
(
3
)
<<
"TRT to freeze network"
;
PADDLE_ENFORCE
(
infer_builder_
!=
nullptr
,
PADDLE_ENFORCE
(
infer_builder_
!=
nullptr
,
...
...
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
8c171902
...
@@ -37,7 +37,9 @@ class TRTInt8Calibrator;
...
@@ -37,7 +37,9 @@ class TRTInt8Calibrator;
* There are two alternative ways to use it, one is to build from a paddle
* There are two alternative ways to use it, one is to build from a paddle
* protobuf model, another way is to manully construct the network.
* protobuf model, another way is to manully construct the network.
*/
*/
class
TensorRTEngine
:
public
EngineBase
{
class
TensorRTEngine
{
using
DescType
=
::
paddle
::
framework
::
proto
::
BlockDesc
;
public:
public:
// Weight is model parameter.
// Weight is model parameter.
class
Weight
{
class
Weight
{
...
@@ -56,24 +58,22 @@ class TensorRTEngine : public EngineBase {
...
@@ -56,24 +58,22 @@ class TensorRTEngine : public EngineBase {
nvinfer1
::
Weights
w_
;
nvinfer1
::
Weights
w_
;
};
};
TensorRTEngine
(
int
max_batch
,
int
max_workspace
,
cudaStream_t
stream
,
TensorRTEngine
(
int
max_batch
,
int
max_workspace
,
bool
enable_int8
=
false
,
bool
enable_int8
=
false
,
TRTInt8Calibrator
*
calibrator
=
nullptr
,
TRTInt8Calibrator
*
calibrator
=
nullptr
,
nvinfer1
::
ILogger
&
logger
=
NaiveLogger
::
Global
())
nvinfer1
::
ILogger
&
logger
=
NaiveLogger
::
Global
())
:
max_batch_
(
max_batch
),
:
max_batch_
(
max_batch
),
max_workspace_
(
max_workspace
),
max_workspace_
(
max_workspace
),
stream_
(
stream
),
enable_int8_
(
enable_int8
),
enable_int8_
(
enable_int8
),
calibrator_
(
calibrator
),
calibrator_
(
calibrator
),
logger_
(
logger
)
{}
logger_
(
logger
)
{}
virtual
~
TensorRTEngine
();
~
TensorRTEngine
()
{}
// TODO(Superjomn) implement it later when graph segmentation is supported.
// TODO(Superjomn) implement it later when graph segmentation is supported.
void
Build
(
const
DescType
&
paddle_model
)
override
;
void
Build
(
const
DescType
&
paddle_model
);
void
Execute
(
int
batch_size
)
override
;
void
Execute
(
int
batch_size
,
std
::
vector
<
void
*>*
buffers
,
void
Execute
(
int
batch_size
,
std
::
vector
<
void
*>&
buffers
);
cudaStream_t
stream
);
// Initialize the inference network, so that TensorRT layers can add to this
// Initialize the inference network, so that TensorRT layers can add to this
// network.
// network.
...
@@ -98,8 +98,6 @@ class TensorRTEngine : public EngineBase {
...
@@ -98,8 +98,6 @@ class TensorRTEngine : public EngineBase {
// Check if the ITensor has been declared
// Check if the ITensor has been declared
bool
HasDeclared
(
const
std
::
string
&
name
);
bool
HasDeclared
(
const
std
::
string
&
name
);
cudaStream_t
stream
()
{
return
stream_
;
}
void
SetITensor
(
const
std
::
string
&
name
,
nvinfer1
::
ITensor
*
tensor
);
void
SetITensor
(
const
std
::
string
&
name
,
nvinfer1
::
ITensor
*
tensor
);
// Get an ITensor called name.
// Get an ITensor called name.
nvinfer1
::
ITensor
*
GetITensor
(
const
std
::
string
&
name
);
nvinfer1
::
ITensor
*
GetITensor
(
const
std
::
string
&
name
);
...
@@ -127,8 +125,6 @@ class TensorRTEngine : public EngineBase {
...
@@ -127,8 +125,6 @@ class TensorRTEngine : public EngineBase {
// the max memory size the engine uses
// the max memory size the engine uses
int
max_workspace_
;
int
max_workspace_
;
cudaStream_t
stream_
;
bool
enable_int8_
;
bool
enable_int8_
;
TRTInt8Calibrator
*
calibrator_
;
TRTInt8Calibrator
*
calibrator_
;
// batch size of the current data, will be updated each Executation.
// batch size of the current data, will be updated each Executation.
...
@@ -136,7 +132,6 @@ class TensorRTEngine : public EngineBase {
...
@@ -136,7 +132,6 @@ class TensorRTEngine : public EngineBase {
nvinfer1
::
ILogger
&
logger_
;
nvinfer1
::
ILogger
&
logger_
;
std
::
vector
<
Buffer
>
buffers_
;
// max data size for the buffers.
// max data size for the buffers.
std
::
unordered_map
<
std
::
string
/*name*/
,
size_t
/*max size*/
>
buffer_sizes_
;
std
::
unordered_map
<
std
::
string
/*name*/
,
size_t
/*max size*/
>
buffer_sizes_
;
std
::
unordered_map
<
std
::
string
/*name*/
,
nvinfer1
::
ITensor
*
/*ITensor*/
>
std
::
unordered_map
<
std
::
string
/*name*/
,
nvinfer1
::
ITensor
*
/*ITensor*/
>
...
...
paddle/fluid/inference/tensorrt/test_engine.cc
浏览文件 @
8c171902
...
@@ -31,7 +31,7 @@ class TensorRTEngineTest : public ::testing::Test {
...
@@ -31,7 +31,7 @@ class TensorRTEngineTest : public ::testing::Test {
void
SetUp
()
override
{
void
SetUp
()
override
{
ctx_
=
new
platform
::
CUDADeviceContext
(
platform
::
CUDAPlace
(
0
));
ctx_
=
new
platform
::
CUDADeviceContext
(
platform
::
CUDAPlace
(
0
));
engine_
=
new
TensorRTEngine
(
10
,
1
<<
10
,
ctx_
->
stream
()
);
engine_
=
new
TensorRTEngine
(
10
,
1
<<
10
);
engine_
->
InitNetwork
();
engine_
->
InitNetwork
();
}
}
...
@@ -88,7 +88,7 @@ TEST_F(TensorRTEngineTest, add_layer) {
...
@@ -88,7 +88,7 @@ TEST_F(TensorRTEngineTest, add_layer) {
buffers
[
1
]
=
reinterpret_cast
<
void
*>
(
y_gpu_data
);
buffers
[
1
]
=
reinterpret_cast
<
void
*>
(
y_gpu_data
);
LOG
(
INFO
)
<<
"to execute"
;
LOG
(
INFO
)
<<
"to execute"
;
engine_
->
Execute
(
1
,
buffers
);
engine_
->
Execute
(
1
,
&
buffers
,
ctx_
->
stream
()
);
LOG
(
INFO
)
<<
"to get output"
;
LOG
(
INFO
)
<<
"to get output"
;
GetOutput
(
&
y_cpu
);
GetOutput
(
&
y_cpu
);
...
@@ -128,7 +128,7 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
...
@@ -128,7 +128,7 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
buffers
[
0
]
=
reinterpret_cast
<
void
*>
(
x_v_gpu_data
);
buffers
[
0
]
=
reinterpret_cast
<
void
*>
(
x_v_gpu_data
);
buffers
[
1
]
=
reinterpret_cast
<
void
*>
(
y_gpu_data
);
buffers
[
1
]
=
reinterpret_cast
<
void
*>
(
y_gpu_data
);
engine_
->
Execute
(
1
,
buffers
);
engine_
->
Execute
(
1
,
&
buffers
,
ctx_
->
stream
()
);
LOG
(
INFO
)
<<
"to get output"
;
LOG
(
INFO
)
<<
"to get output"
;
GetOutput
(
&
y_cpu
);
GetOutput
(
&
y_cpu
);
...
@@ -175,7 +175,7 @@ TEST_F(TensorRTEngineTest, test_conv2d) {
...
@@ -175,7 +175,7 @@ TEST_F(TensorRTEngineTest, test_conv2d) {
buffers
[
0
]
=
reinterpret_cast
<
void
*>
(
x_v_gpu_data
);
buffers
[
0
]
=
reinterpret_cast
<
void
*>
(
x_v_gpu_data
);
buffers
[
1
]
=
reinterpret_cast
<
void
*>
(
y_gpu_data
);
buffers
[
1
]
=
reinterpret_cast
<
void
*>
(
y_gpu_data
);
engine_
->
Execute
(
2
,
buffers
);
engine_
->
Execute
(
2
,
&
buffers
,
ctx_
->
stream
()
);
LOG
(
INFO
)
<<
"to get output"
;
LOG
(
INFO
)
<<
"to get output"
;
GetOutput
(
&
y_cpu
);
GetOutput
(
&
y_cpu
);
...
@@ -214,7 +214,7 @@ TEST_F(TensorRTEngineTest, test_pool2d) {
...
@@ -214,7 +214,7 @@ TEST_F(TensorRTEngineTest, test_pool2d) {
buffers
[
0
]
=
reinterpret_cast
<
void
*>
(
x_v_gpu_data
);
buffers
[
0
]
=
reinterpret_cast
<
void
*>
(
x_v_gpu_data
);
buffers
[
1
]
=
reinterpret_cast
<
void
*>
(
y_gpu_data
);
buffers
[
1
]
=
reinterpret_cast
<
void
*>
(
y_gpu_data
);
engine_
->
Execute
(
2
,
buffers
);
engine_
->
Execute
(
2
,
&
buffers
,
ctx_
->
stream
()
);
LOG
(
INFO
)
<<
"to get output"
;
LOG
(
INFO
)
<<
"to get output"
;
GetOutput
(
&
y_cpu
);
GetOutput
(
&
y_cpu
);
...
...
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
浏览文件 @
8c171902
...
@@ -142,10 +142,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
...
@@ -142,10 +142,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
LOG_FIRST_N
(
INFO
,
1
)
<<
"The TRT engine: "
<<
engine_key_
LOG_FIRST_N
(
INFO
,
1
)
<<
"The TRT engine: "
<<
engine_key_
<<
" is running calibration trt int8... "
;
<<
" is running calibration trt int8... "
;
int
runtime_batch
=
1
;
int
runtime_batch
=
1
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
dev_ctx
).
stream
();
if
(
!
Singleton
<
TRTCalibratorEngineManager
>::
Global
().
Has
(
engine_key_
))
{
if
(
!
Singleton
<
TRTCalibratorEngineManager
>::
Global
().
Has
(
engine_key_
))
{
TRTCalibratorEngine
*
calib_res
=
TRTCalibratorEngine
*
calib_res
=
Singleton
<
TRTCalibratorEngineManager
>::
Global
().
Create
(
engine_key_
);
Singleton
<
TRTCalibratorEngineManager
>::
Global
().
Create
(
engine_key_
);
...
@@ -162,10 +158,10 @@ class TensorRTEngineOp : public framework::OperatorBase {
...
@@ -162,10 +158,10 @@ class TensorRTEngineOp : public framework::OperatorBase {
calib_buffers
,
runtime_batch
,
engine_key_
,
dev_place
));
calib_buffers
,
runtime_batch
,
engine_key_
,
dev_place
));
calib_res
->
thr_
.
reset
(
new
std
::
thread
([
&
]()
{
calib_res
->
thr_
.
reset
(
new
std
::
thread
([
&
]()
{
calib_res
->
engine_
.
reset
(
calib_res
->
engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
stream
,
new
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
enable_int8_
,
enable_int8_
,
calib_res
->
calib_
.
get
()));
calib_res
->
calib_
.
get
()));
VLOG
(
3
)
<<
"start the calib trt engine thread"
;
VLOG
(
3
)
<<
"start the calib trt engine thread"
;
Prepare
(
scope
,
dev_place
,
calib_res
->
engine_
.
get
());
Prepare
(
scope
,
calib_res
->
engine_
.
get
());
}));
}));
}
}
...
@@ -253,22 +249,17 @@ class TensorRTEngineOp : public framework::OperatorBase {
...
@@ -253,22 +249,17 @@ class TensorRTEngineOp : public framework::OperatorBase {
PADDLE_ENFORCE_LE
(
runtime_batch
,
max_batch_size_
);
PADDLE_ENFORCE_LE
(
runtime_batch
,
max_batch_size_
);
// Execute the engine.
// Execute the engine.
engine
->
Execute
(
runtime_batch
,
buffers
);
engine
->
Execute
(
runtime_batch
,
&
buffers
,
stream
);
cudaStreamSynchronize
(
stream
);
cudaStreamSynchronize
(
stream
);
}
}
TensorRTEngine
*
GetEngine
(
const
framework
::
Scope
&
scope
,
TensorRTEngine
*
GetEngine
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
const
platform
::
Place
&
dev_place
)
const
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
dev_ctx
).
stream
();
if
(
trt_engine_
.
get
()
==
nullptr
)
{
if
(
trt_engine_
.
get
()
==
nullptr
)
{
trt_engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
trt_engine_
.
reset
(
new
TensorRTEngine
(
max_batch_size_
,
workspace_size_
,
stream
,
enable_int8_
,
enable_int8_
,
calibrator_
.
get
()));
calibrator_
.
get
()));
if
(
true
)
{
if
(
true
)
{
Prepare
(
scope
,
dev_place
,
trt_engine_
.
get
());
Prepare
(
scope
,
trt_engine_
.
get
());
}
else
{
}
else
{
// create static engine
// create static engine
}
}
...
@@ -276,20 +267,19 @@ class TensorRTEngineOp : public framework::OperatorBase {
...
@@ -276,20 +267,19 @@ class TensorRTEngineOp : public framework::OperatorBase {
return
trt_engine_
.
get
();
return
trt_engine_
.
get
();
}
}
void
Prepare
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
,
void
Prepare
(
const
framework
::
Scope
&
scope
,
TensorRTEngine
*
engine
)
const
{
TensorRTEngine
*
engine
)
const
{
LOG
(
INFO
)
<<
"Prepare TRT engine (Optimize model structure, Select OP "
LOG
(
INFO
)
<<
"Prepare TRT engine (Optimize model structure, Select OP "
"kernel etc). This process may cost a lot of time."
;
"kernel etc). This process may cost a lot of time."
;
framework
::
proto
::
BlockDesc
block_desc
;
framework
::
proto
::
BlockDesc
block_desc
;
block_desc
.
ParseFromString
(
Attr
<
std
::
string
>
(
"subgraph"
));
block_desc
.
ParseFromString
(
Attr
<
std
::
string
>
(
"subgraph"
));
framework
::
BlockDesc
block
(
nullptr
/*programdesc*/
,
&
block_desc
);
std
::
vector
<
std
::
string
>
output_maps
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"output_name_mapping"
);
engine
->
InitNetwork
();
engine
->
InitNetwork
();
framework
::
BlockDesc
block
(
nullptr
/*programdesc*/
,
&
block_desc
);
VLOG
(
4
)
<<
"parsed var size "
<<
block
.
AllVars
().
size
();
VLOG
(
4
)
<<
"parsed var size "
<<
block
.
AllVars
().
size
();
std
::
vector
<
std
::
string
>
output_maps
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"output_name_mapping"
);
// Add inputs
// Add inputs
VLOG
(
4
)
<<
"declare inputs"
;
VLOG
(
4
)
<<
"declare inputs"
;
for
(
auto
&
input
:
Inputs
(
"Xs"
))
{
for
(
auto
&
input
:
Inputs
(
"Xs"
))
{
...
@@ -306,12 +296,12 @@ class TensorRTEngineOp : public framework::OperatorBase {
...
@@ -306,12 +296,12 @@ class TensorRTEngineOp : public framework::OperatorBase {
PADDLE_ENFORCE
(
var
,
"no variable called %s"
,
input
);
PADDLE_ENFORCE
(
var
,
"no variable called %s"
,
input
);
PADDLE_ENFORCE_EQ
(
var
->
GetType
(),
FluidDT
::
VarType_Type_LOD_TENSOR
,
PADDLE_ENFORCE_EQ
(
var
->
GetType
(),
FluidDT
::
VarType_Type_LOD_TENSOR
,
"TensorRT engine only takes LoDTensor as input"
);
"TensorRT engine only takes LoDTensor as input"
);
engine
->
DeclareInput
(
engine
->
DeclareInput
(
input
,
FluidDataType2TRT
(
input
,
FluidDataType2TRT
(
var
->
Proto
()
->
type
().
lod_tensor
().
tensor
().
data_type
()),
var
->
Proto
()
->
type
().
lod_tensor
().
tensor
().
data_type
()),
Vec2TRT_Dims
(
t_shape
));
Vec2TRT_Dims
(
t_shape
));
}
}
inference
::
Singleton
<
inference
::
tensorrt
::
OpConverter
>::
Global
()
inference
::
Singleton
<
inference
::
tensorrt
::
OpConverter
>::
Global
()
.
ConvertBlock
(
block_desc
,
param_names_
,
scope
,
engine
);
.
ConvertBlock
(
block_desc
,
param_names_
,
scope
,
engine
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录