Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d8b8c2d8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d8b8c2d8
编写于
3月 09, 2023
作者:
X
xiaoxiaohehe001
提交者:
GitHub
3月 09, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle Inference] Support split sectionslist and axis = 0 input of trt . (#50957)
* split_list
上级
ccfe7681
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
195 addition
and
37 deletion
+195
-37
paddle/fluid/inference/tensorrt/convert/split_op.cc
paddle/fluid/inference/tensorrt/convert/split_op.cc
+38
-22
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+5
-3
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_split.py
...id/tests/unittests/ir/inference/test_trt_convert_split.py
+152
-12
未找到文件。
paddle/fluid/inference/tensorrt/convert/split_op.cc
浏览文件 @
d8b8c2d8
...
@@ -29,15 +29,15 @@ class SplitOpConverter : public OpConverter {
...
@@ -29,15 +29,15 @@ class SplitOpConverter : public OpConverter {
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
// Declare inputs
auto
*
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
*
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
inputs
=
op_desc
.
Inputs
();
auto
input_dims
=
input
->
getDimensions
();
auto
input_dims
=
input
->
getDimensions
();
size_
t
output_num
=
op_desc
.
Output
(
"Out"
).
size
();
in
t
output_num
=
op_desc
.
Output
(
"Out"
).
size
();
// Get Attrs
// Get Attrs
int
axis
=
PADDLE_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"axis"
));
int
axis
=
PADDLE_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"axis"
));
int
num
=
0
;
std
::
vector
<
int
>
output_lengths
=
std
::
vector
<
int
>
output_lengths
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"sections"
));
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"sections"
));
int
num
=
0
;
if
(
op_desc
.
HasAttr
(
"num"
))
{
if
(
op_desc
.
HasAttr
(
"num"
))
{
num
=
PADDLE_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"num"
));
num
=
PADDLE_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"num"
));
}
}
...
@@ -50,19 +50,34 @@ class SplitOpConverter : public OpConverter {
...
@@ -50,19 +50,34 @@ class SplitOpConverter : public OpConverter {
axis
+=
(
axis
<
0
)
?
input_dims
.
nbDims
:
-
1
;
axis
+=
(
axis
<
0
)
?
input_dims
.
nbDims
:
-
1
;
}
}
bool
in_axis_dim_dynamic
=
false
;
bool
in_axis_dim_dynamic
=
false
;
nvinfer1
::
ITensor
*
avg_len_tensor
=
nullptr
;
bool
sections_tensor_list
=
false
;
nvinfer1
::
ITensor
*
sections_tensor
=
nullptr
;
// need infer output_lengths
// need infer output_lengths
if
(
num
>
0
&&
output_lengths
.
empty
())
{
if
(
inputs
.
find
(
"SectionsTensorList"
)
!=
inputs
.
end
()
&&
op_desc
.
Input
(
"SectionsTensorList"
).
size
()
>=
1
)
{
int32_t
sections_size
=
op_desc
.
Input
(
"SectionsTensorList"
).
size
();
std
::
vector
<
nvinfer1
::
ITensor
*>
sections_tensors
;
for
(
int32_t
i
=
0
;
i
<
sections_size
;
++
i
)
{
sections_tensors
.
push_back
(
engine_
->
GetITensor
(
op_desc
.
Input
(
"SectionsTensorList"
)[
i
]));
}
sections_tensor
=
Concat
(
sections_tensors
);
sections_tensor_list
=
true
;
}
else
if
(
!
output_lengths
.
empty
())
{
sections_tensor
=
Add1DConstantLayer
(
output_lengths
);
}
else
if
(
num
>
0
&&
output_lengths
.
empty
())
{
if
(
input_dims
.
d
[
axis
]
>
0
)
{
if
(
input_dims
.
d
[
axis
]
>
0
)
{
int64_t
in_axis_dim
=
input_dims
.
d
[
axis
];
int64_t
in_axis_dim
=
input_dims
.
d
[
axis
];
size_t
out_axis_dim
=
in_axis_dim
/
num
;
size_t
out_axis_dim
=
in_axis_dim
/
num
;
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
output_lengths
.
push_back
(
out_axis_dim
);
output_lengths
.
push_back
(
out_axis_dim
);
}
}
sections_tensor
=
Add1DConstantLayer
(
output_lengths
);
}
else
{
}
else
{
in_axis_dim_dynamic
=
true
;
in_axis_dim_dynamic
=
true
;
auto
*
num_tensor
=
Add1DConstantLayer
(
num
);
auto
*
num_tensor
=
Add1DConstantLayer
(
num
);
avg_len
_tensor
=
sections
_tensor
=
Div
(
GetEleTensorOfShape
(
shape_tensor
,
axis
),
num_tensor
);
Div
(
GetEleTensorOfShape
(
shape_tensor
,
axis
),
num_tensor
);
}
}
}
}
...
@@ -79,20 +94,20 @@ class SplitOpConverter : public OpConverter {
...
@@ -79,20 +94,20 @@ class SplitOpConverter : public OpConverter {
std
::
iota
(
gather_indices
.
begin
(),
gather_indices
.
end
(),
0
);
std
::
iota
(
gather_indices
.
begin
(),
gather_indices
.
end
(),
0
);
gather_indices
[
axis
]
=
gather_indices
.
size
();
gather_indices
[
axis
]
=
gather_indices
.
size
();
std
::
vector
<
int32_t
>
zeros
(
trt_step_dims
.
nbDims
,
0
);
std
::
vector
<
int32_t
>
zeros
(
trt_step_dims
.
nbDims
,
0
);
auto
*
zeros_tensor
=
Add1DConstantLayer
(
zeros
);
std
::
vector
<
int32_t
>
stride
(
trt_step_dims
.
nbDims
,
1
);
auto
zeros_tensor
=
Add1DConstantLayer
(
zeros
);
auto
stride_tensor
=
Add1DConstantLayer
(
stride
);
// input : [N,C,H,W]
// input : [N,C,H,W]
int
start_point
=
0
;
nvinfer1
::
ITensor
*
start_point_tensor
=
zeros_tensor
;
for
(
size_t
i
=
0
;
i
<
output_num
;
i
++
)
{
nvinfer1
::
ITensor
*
this_len_tensor
=
zeros_tensor
;
nvinfer1
::
ITensor
*
this_len_tensor
=
nullptr
;
for
(
int
i
=
0
;
i
<
output_num
;
i
++
)
{
nvinfer1
::
ITensor
*
start_point_tensor
=
nullptr
;
if
(
sections_tensor_list
||
!
in_axis_dim_dynamic
)
{
if
(
!
in_axis_dim_dynamic
)
{
start_point_tensor
=
Sum
(
start_point_tensor
,
this_len_tensor
);
this_len_tensor
=
Add1DConstantLayer
(
output_lengths
[
i
]);
this_len_tensor
=
Gather
(
sections_tensor
,
std
::
vector
<
int32_t
>
{
i
});
start_point_tensor
=
Add1DConstantLayer
(
start_point
);
start_point
+=
output_lengths
[
i
];
}
else
{
}
else
{
this_len_tensor
=
avg_len
_tensor
;
this_len_tensor
=
sections
_tensor
;
auto
*
i_tensor
=
Add1DConstantLayer
(
static_cast
<
int
>
(
i
));
auto
*
i_tensor
=
Add1DConstantLayer
(
static_cast
<
int
>
(
i
));
start_point_tensor
=
Prod
(
i_tensor
,
avg_len
_tensor
);
start_point_tensor
=
Prod
(
i_tensor
,
sections
_tensor
);
}
}
std
::
vector
<
nvinfer1
::
ITensor
*>
concat_inputs1
=
{
zeros_tensor
,
std
::
vector
<
nvinfer1
::
ITensor
*>
concat_inputs1
=
{
zeros_tensor
,
...
@@ -104,11 +119,12 @@ class SplitOpConverter : public OpConverter {
...
@@ -104,11 +119,12 @@ class SplitOpConverter : public OpConverter {
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Slice
,
Slice
,
*
input
,
*
input
,
trt_step_dims
,
nvinfer1
::
Dims
{}
,
trt_step_dims
,
nvinfer1
::
Dims
{}
,
trt_step_dims
);
nvinfer1
::
Dims
{}
);
layer
->
setInput
(
1
,
*
start_tensor
);
layer
->
setInput
(
1
,
*
start_tensor
);
layer
->
setInput
(
2
,
*
size_tensor
);
layer
->
setInput
(
2
,
*
size_tensor
);
layer
->
setInput
(
3
,
*
stride_tensor
);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
i
];
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
i
];
RreplenishLayerAndOutput
(
layer
,
"split"
,
{
output_name
},
test_mode
);
RreplenishLayerAndOutput
(
layer
,
"split"
,
{
output_name
},
test_mode
);
...
@@ -124,7 +140,7 @@ class SplitOpConverter : public OpConverter {
...
@@ -124,7 +140,7 @@ class SplitOpConverter : public OpConverter {
for
(
int
i
=
0
;
i
<
trt_step_dims
.
nbDims
;
i
++
)
trt_step_dims
.
d
[
i
]
=
1
;
for
(
int
i
=
0
;
i
<
trt_step_dims
.
nbDims
;
i
++
)
trt_step_dims
.
d
[
i
]
=
1
;
// input : [C,H,W]
// input : [C,H,W]
for
(
size_
t
i
=
0
;
i
<
output_num
;
i
++
)
{
for
(
in
t
i
=
0
;
i
<
output_num
;
i
++
)
{
trt_start_dims
.
d
[
axis
]
=
std
::
accumulate
(
trt_start_dims
.
d
[
axis
]
=
std
::
accumulate
(
output_lengths
.
begin
(),
output_lengths
.
begin
()
+
i
,
0
);
output_lengths
.
begin
(),
output_lengths
.
begin
()
+
i
,
0
);
trt_size_dims
.
d
[
axis
]
=
output_lengths
[
i
];
trt_size_dims
.
d
[
axis
]
=
output_lengths
[
i
];
...
@@ -153,7 +169,7 @@ class SplitOpConverter : public OpConverter {
...
@@ -153,7 +169,7 @@ class SplitOpConverter : public OpConverter {
layer
=
engine_
->
AddPluginV2Ext
(
&
input
,
1
,
plugin
);
layer
=
engine_
->
AddPluginV2Ext
(
&
input
,
1
,
plugin
);
}
}
std
::
vector
<
std
::
string
>
output_names
;
std
::
vector
<
std
::
string
>
output_names
;
for
(
size_
t
i
=
0
;
i
<
output_num
;
i
++
)
{
for
(
in
t
i
=
0
;
i
<
output_num
;
i
++
)
{
output_names
.
push_back
(
op_desc
.
Output
(
"Out"
)[
i
]);
output_names
.
push_back
(
op_desc
.
Output
(
"Out"
)[
i
]);
}
}
RreplenishLayerAndOutput
(
layer
,
"split"
,
output_names
,
test_mode
);
RreplenishLayerAndOutput
(
layer
,
"split"
,
output_names
,
test_mode
);
...
...
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
d8b8c2d8
...
@@ -1079,7 +1079,9 @@ struct SimpleOpTypeSetTeller : public Teller {
...
@@ -1079,7 +1079,9 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}
if
(
split_inputs
.
find
(
"SectionsTensorList"
)
!=
split_inputs
.
end
())
{
if
(
split_inputs
.
find
(
"SectionsTensorList"
)
!=
split_inputs
.
end
())
{
if
(
desc
.
Input
(
"SectionsTensorList"
).
size
()
>=
1
)
{
if
(
desc
.
Input
(
"SectionsTensorList"
).
size
()
>=
1
)
{
return
false
;
if
(
!
with_dynamic_shape
)
{
return
false
;
}
}
}
}
}
if
(
!
desc
.
HasAttr
(
"axis"
))
{
if
(
!
desc
.
HasAttr
(
"axis"
))
{
...
@@ -1087,9 +1089,9 @@ struct SimpleOpTypeSetTeller : public Teller {
...
@@ -1087,9 +1089,9 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}
int
axis
=
PADDLE_GET_CONST
(
int
,
desc
.
GetAttr
(
"axis"
));
int
axis
=
PADDLE_GET_CONST
(
int
,
desc
.
GetAttr
(
"axis"
));
if
(
axis
==
0
)
{
if
(
!
with_dynamic_shape
&&
axis
==
0
)
{
VLOG
(
3
)
<<
"Invalid split axis. Split on batch is not supported in "
VLOG
(
3
)
<<
"Invalid split axis. Split on batch is not supported in "
"TensorRT"
;
"TensorRT
with static shape
"
;
return
false
;
return
false
;
}
}
auto
*
block
=
desc
.
Block
();
auto
*
block
=
desc
.
Block
();
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_split.py
浏览文件 @
d8b8c2d8
...
@@ -70,6 +70,14 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
...
@@ -70,6 +70,14 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
else
:
else
:
return
False
return
False
if
self
.
dims
==
2
:
if
self
.
batch
!=
3
:
return
False
if
len
(
attrs
[
0
][
'sections'
])
!=
0
and
attrs
[
0
][
'axis'
]
==
0
:
if
self
.
dims
!=
2
or
self
.
batch
!=
3
:
return
False
return
True
return
True
def
sample_program_configs
(
self
):
def
sample_program_configs
(
self
):
...
@@ -81,7 +89,7 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
...
@@ -81,7 +89,7 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
elif
self
.
dims
==
2
:
elif
self
.
dims
==
2
:
return
np
.
random
.
random
([
batch
,
24
]).
astype
(
np
.
float32
)
return
np
.
random
.
random
([
batch
,
24
]).
astype
(
np
.
float32
)
elif
self
.
dims
==
1
:
elif
self
.
dims
==
1
:
return
np
.
random
.
random
([
24
]).
astype
(
np
.
floa
t32
)
return
np
.
random
.
random
([
24
]).
astype
(
np
.
in
t32
)
def
generate_AxisTensor
(
attrs
:
List
[
Dict
[
str
,
Any
]]):
def
generate_AxisTensor
(
attrs
:
List
[
Dict
[
str
,
Any
]]):
return
np
.
ones
([
1
]).
astype
(
np
.
int32
)
return
np
.
ones
([
1
]).
astype
(
np
.
int32
)
...
@@ -204,13 +212,9 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
...
@@ -204,13 +212,9 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
}
}
self
.
dynamic_shape
.
opt_input_shape
=
{
"split_input"
:
[
1
,
3
,
24
]}
self
.
dynamic_shape
.
opt_input_shape
=
{
"split_input"
:
[
1
,
3
,
24
]}
elif
self
.
dims
==
2
:
elif
self
.
dims
==
2
:
self
.
dynamic_shape
.
min_input_shape
=
{
self
.
dynamic_shape
.
min_input_shape
=
{
"split_input"
:
[
3
,
24
]}
"split_input"
:
[
1
,
24
-
1
]
self
.
dynamic_shape
.
max_input_shape
=
{
"split_input"
:
[
3
,
24
]}
}
self
.
dynamic_shape
.
opt_input_shape
=
{
"split_input"
:
[
3
,
24
]}
self
.
dynamic_shape
.
max_input_shape
=
{
"split_input"
:
[
9
,
24
+
1
]
}
self
.
dynamic_shape
.
opt_input_shape
=
{
"split_input"
:
[
1
,
24
]}
elif
self
.
dims
==
1
:
elif
self
.
dims
==
1
:
self
.
dynamic_shape
.
min_input_shape
=
{
"split_input"
:
[
24
-
1
]}
self
.
dynamic_shape
.
min_input_shape
=
{
"split_input"
:
[
24
-
1
]}
self
.
dynamic_shape
.
max_input_shape
=
{
"split_input"
:
[
24
+
1
]}
self
.
dynamic_shape
.
max_input_shape
=
{
"split_input"
:
[
24
+
1
]}
...
@@ -223,15 +227,21 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
...
@@ -223,15 +227,21 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
def
generate_trt_nodes_num
(
attrs
,
dynamic_shape
):
def
generate_trt_nodes_num
(
attrs
,
dynamic_shape
):
if
len
(
program_config
.
outputs
)
==
2
:
if
len
(
program_config
.
outputs
)
==
2
:
if
attrs
[
0
][
'axis'
]
!=
0
:
if
dynamic_shape
:
return
1
,
3
return
1
,
3
else
:
else
:
return
0
,
4
if
attrs
[
0
][
'axis'
]
!=
0
:
return
1
,
3
else
:
return
0
,
4
else
:
else
:
if
attrs
[
0
][
'axis'
]
!=
0
:
if
dynamic_shape
:
return
1
,
4
return
1
,
4
else
:
else
:
return
0
,
5
if
attrs
[
0
][
'axis'
]
!=
0
:
return
1
,
4
else
:
return
0
,
5
attrs
=
[
attrs
=
[
program_config
.
ops
[
i
].
attrs
for
i
in
range
(
len
(
program_config
.
ops
))
program_config
.
ops
[
i
].
attrs
for
i
in
range
(
len
(
program_config
.
ops
))
...
@@ -276,5 +286,135 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
...
@@ -276,5 +286,135 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
self
.
run_test
()
self
.
run_test
()
class
TrtConvertSplitTest2
(
TrtLayerAutoScanTest
):
def
is_program_valid
(
self
,
program_config
:
ProgramConfig
)
->
bool
:
return
True
def
sample_program_configs
(
self
):
def
generate_input1
(
attrs
:
List
[
Dict
[
str
,
Any
]]):
return
np
.
random
.
random
([
3
,
3
,
3
,
24
]).
astype
(
np
.
float32
)
for
sections
in
[
[
-
1
,
-
1
,
-
1
],
[
1
,
1
,
1
],
]:
for
num
in
[
0
]:
for
axis
in
[
0
,
1
]:
dics
=
[
{
"sections"
:
sections
,
"num"
:
num
,
"axis"
:
axis
,
}
]
dics_intput
=
[
{
"X"
:
[
"split_input"
],
"SectionsTensorList"
:
[
"shapeT1_data"
,
"shapeT2_data"
,
"shapeT3_data"
,
],
},
]
ops_config
=
[
{
"op_type"
:
"fill_constant"
,
"op_inputs"
:
{},
"op_outputs"
:
{
"Out"
:
[
"shapeT1_data"
]},
"op_attrs"
:
{
"dtype"
:
2
,
"str_value"
:
"1"
,
"shape"
:
[
1
],
},
},
{
"op_type"
:
"fill_constant"
,
"op_inputs"
:
{},
"op_outputs"
:
{
"Out"
:
[
"shapeT2_data"
]},
"op_attrs"
:
{
"dtype"
:
2
,
"str_value"
:
"1"
,
"shape"
:
[
1
],
},
},
{
"op_type"
:
"fill_constant"
,
"op_inputs"
:
{},
"op_outputs"
:
{
"Out"
:
[
"shapeT3_data"
]},
"op_attrs"
:
{
"dtype"
:
2
,
"str_value"
:
"1"
,
"shape"
:
[
1
],
},
},
{
"op_type"
:
"split"
,
"op_inputs"
:
dics_intput
[
0
],
"op_outputs"
:
{
"Out"
:
[
"output_var0"
,
"output_var1"
,
"output_var2"
,
]
},
"op_attrs"
:
dics
[
0
],
},
]
ops
=
self
.
generate_op_config
(
ops_config
)
program_config
=
ProgramConfig
(
ops
=
ops
,
weights
=
{},
inputs
=
{
"split_input"
:
TensorConfig
(
data_gen
=
partial
(
generate_input1
,
dics
)
)
},
outputs
=
[
"output_var0"
,
"output_var1"
,
"output_var2"
],
)
yield
program_config
def
sample_predictor_configs
(
self
,
program_config
)
->
(
paddle_infer
.
Config
,
List
[
int
],
float
):
def
generate_dynamic_shape
(
attrs
):
self
.
dynamic_shape
.
min_input_shape
=
{
"split_input"
:
[
1
,
3
,
3
,
24
]}
self
.
dynamic_shape
.
max_input_shape
=
{
"split_input"
:
[
9
,
3
,
3
,
24
]}
self
.
dynamic_shape
.
opt_input_shape
=
{
"split_input"
:
[
3
,
3
,
3
,
24
]}
def
clear_dynamic_shape
():
self
.
dynamic_shape
.
min_input_shape
=
{}
self
.
dynamic_shape
.
max_input_shape
=
{}
self
.
dynamic_shape
.
opt_input_shape
=
{}
def
generate_trt_nodes_num
(
attrs
,
dynamic_shape
):
if
dynamic_shape
:
return
1
,
4
return
0
,
5
attrs
=
[
program_config
.
ops
[
i
].
attrs
for
i
in
range
(
len
(
program_config
.
ops
))
]
self
.
trt_param
.
max_batch_size
=
9
# for dynamic_shape
generate_dynamic_shape
(
attrs
)
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Float32
yield
self
.
create_inference_config
(),
generate_trt_nodes_num
(
attrs
,
True
),
1e-5
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Half
yield
self
.
create_inference_config
(),
generate_trt_nodes_num
(
attrs
,
True
),
1e-3
def
add_skip_trt_case
(
self
):
pass
def
test
(
self
):
self
.
add_skip_trt_case
()
self
.
run_test
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录