Commit a8078bbd (unverified)
Repository: BaiXuePrincess/Paddle, forked from PaddlePaddle/Paddle
Authored by LiYuRio on Jan 31, 2023; committed via GitHub on Jan 31, 2023
Parent: db83b53a

add multi fetch (#50070)
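In short, this commit lets the fleet executor fetch results from several micro-batch scopes in a single run: a new micro_scope_list argument is threaded from the Python Executor through FleetExecutor::Init down to Carrier::Init, and when fleet_opt carries 'inference_generation' and 'fetch_var', the Python side creates one scope per micro batch and returns every fetched tensor. A minimal sketch of the new knobs, based on the updated unit test at the end of this diff (task-node wiring and program construction elided; 'x' is the test's own variable name):

# Hedged usage sketch, not a complete program: how fleet_opt is attached to
# the program is not shown in the visible hunks.
fleet_opt = {
    'task_id_to_rank': {},            # elided here; built as in the test below
    'num_micro_batches': 3,
    'inference_generation': True,     # new: one scope per micro batch
    'fetch_var': ['x'],               # new: fetched from every micro scope
}
# With 'fetch_var' present, Executor._run_using_fleet_executor returns a flat
# list of numpy arrays, one per (micro scope, variable) pair, instead of the
# regular fetch-list path.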
Showing 6 changed files with 120 additions and 33 deletions (+120 −33)
paddle/fluid/distributed/fleet_executor/carrier.cc  +19 −6
paddle/fluid/distributed/fleet_executor/carrier.h  +3 −1
paddle/fluid/distributed/fleet_executor/fleet_executor.cc  +10 −4
paddle/fluid/distributed/fleet_executor/fleet_executor.h  +5 −2
python/paddle/fluid/executor.py  +24 −0
python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py  +59 −20
paddle/fluid/distributed/fleet_executor/carrier.cc

@@ -15,6 +15,7 @@
 #include "paddle/fluid/distributed/fleet_executor/carrier.h"
 #include <algorithm>
+#include <vector>
 #include "paddle/fluid/distributed/fleet_executor/global.h"
 #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
@@ -24,6 +25,7 @@
 #include "paddle/fluid/framework/garbage_collector.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/framework/variable_helper.h"
 
 namespace paddle {
@@ -55,23 +57,34 @@ void Carrier::Init(
     framework::Scope* scope,
     int64_t num_micro_batches,
     const platform::Place& place,
-    const std::vector<std::string>& inference_root_scope_vars) {
+    const std::vector<std::string>& inference_root_scope_vars,
+    const std::vector<framework::Scope*>& micro_scope_list) {
   rank_ = rank;
   interceptor_id_to_rank_ = interceptor_id_to_rank;
   interceptor_id_to_node_ = interceptor_id_to_node;
   place_ = place;
   root_scope_ = scope;
   dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
+  bool need_create_scope = micro_scope_list.empty();
 
   PADDLE_ENFORCE_NOT_NULL(
       root_scope_,
       platform::errors::InvalidArgument("root_scope can not be nullptr"));
-  minibatch_scope_ = &root_scope_->NewScope();
-  microbatch_scopes_.resize(num_micro_batches);
-  for (int i = 0; i < num_micro_batches; ++i) {
-    microbatch_scopes_[i] = &minibatch_scope_->NewScope();
-    CopyParameters(i, program, inference_root_scope_vars);
+
+  if (need_create_scope) {
+    minibatch_scope_ = &root_scope_->NewScope();
+    microbatch_scopes_.resize(num_micro_batches);
+    for (int i = 0; i < num_micro_batches; ++i) {
+      microbatch_scopes_[i] = &minibatch_scope_->NewScope();
+      CopyParameters(i, program, inference_root_scope_vars);
+    }
+  } else {
+    microbatch_scopes_ = micro_scope_list;
+    for (int i = 0; i < num_micro_batches; ++i) {
+      CopyParameters(i, program, inference_root_scope_vars);
+    }
   }
 
   // Add source and sink interceptor id to rank
   interceptor_id_to_rank_.emplace(SOURCE_ID, rank);
   interceptor_id_to_rank_.emplace(SINK_ID, rank);
...
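Read as a whole, the Carrier::Init change makes scope ownership conditional: an empty micro_scope_list preserves the old behavior (the carrier builds a minibatch scope plus one child scope per micro batch), while a non-empty list is adopted directly. A hedged Python paraphrase of that branch (illustrative names, not Paddle API; CopyParameters runs in both branches in the real code):

class _Scope:
    # stand-in for framework::Scope, which can spawn child scopes
    def new_scope(self):
        return _Scope()

def init_micro_scopes(root_scope, num_micro_batches, micro_scope_list):
    if not micro_scope_list:  # need_create_scope == true
        minibatch = root_scope.new_scope()
        return [minibatch.new_scope() for _ in range(num_micro_batches)]
    return list(micro_scope_list)  # caller-owned scopes are reused as-is

scopes = init_micro_scopes(_Scope(), 3, [])
assert len(scopes) == 3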
paddle/fluid/distributed/fleet_executor/carrier.h

@@ -25,6 +25,7 @@
 #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
 #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
 #include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
+#include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/errors.h"
@@ -60,7 +61,8 @@ class Carrier final {
       framework::Scope* scope,
       int64_t num_micro_batches,
       const platform::Place& place,
-      const std::vector<std::string>& inference_root_scope_vars = {});
+      const std::vector<std::string>& inference_root_scope_vars = {},
+      const std::vector<framework::Scope*>& micro_scope_list = {});
 
   void CopyParameters(
       int microbatch_id,
...
paddle/fluid/distributed/fleet_executor/fleet_executor.cc

@@ -14,6 +14,7 @@
 #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
 #include <algorithm>
+#include <vector>
 #include "paddle/fluid/distributed/fleet_executor/global.h"
 #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
@@ -24,6 +25,7 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/variable.h"
 
 namespace paddle {
 namespace distributed {
@@ -59,7 +61,8 @@ void FleetExecutor::Init(
     int64_t num_micro_batches,
     const std::vector<TaskNode*>& task_nodes,
     const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
-    const std::vector<std::string>& inference_root_scope_vars) {
+    const std::vector<std::string>& inference_root_scope_vars,
+    const std::vector<framework::Scope*>& micro_scope_list) {
   PADDLE_ENFORCE_GT(task_nodes.size(),
                     0,
                     platform::errors::InvalidArgument(
@@ -144,7 +147,8 @@ void FleetExecutor::Init(
               place,
               num_micro_batches,
               program_desc,
-              inference_root_scope_vars);
+              inference_root_scope_vars,
+              micro_scope_list);
   GlobalVal<MessageBus>::Get()->Barrier();
 }
@@ -154,7 +158,8 @@ void FleetExecutor::InitCarrier(
     const platform::Place& place,
     int64_t num_micro_batches,
     const framework::ProgramDesc& program_desc,
-    const std::vector<std::string>& inference_root_scope_vars) {
+    const std::vector<std::string>& inference_root_scope_vars,
+    const std::vector<framework::Scope*>& micro_scope_list) {
   carrier->Init(exe_desc_.cur_rank(),
                 runtime_graph_->interceptor_id_to_rank(),
                 runtime_graph_->interceptor_id_to_node(),
@@ -162,7 +167,8 @@ void FleetExecutor::InitCarrier(
                 scope,
                 num_micro_batches,
                 place,
-                inference_root_scope_vars);
+                inference_root_scope_vars,
+                micro_scope_list);
 }
 
 void FleetExecutor::InitMessageBus() {
...
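The rest of this file is plumbing: micro_scope_list rides along unchanged from FleetExecutor::Init through InitCarrier into Carrier::Init, and the defaulted arguments in the headers keep existing call sites compiling. A hedged sketch of that forwarding pattern (hypothetical Python names; the real signatures are the C++ ones above):

def fleet_executor_init(task_nodes, micro_scope_list=()):
    # validates task_nodes, then forwards, as FleetExecutor::Init does
    assert len(task_nodes) > 0, "task_nodes can not be empty"
    return init_carrier(task_nodes, micro_scope_list)

def init_carrier(task_nodes, micro_scope_list):
    # corresponds to carrier->Init(..., inference_root_scope_vars,
    # micro_scope_list) in the last two hunks
    return ("carrier", list(task_nodes), list(micro_scope_list))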
paddle/fluid/distributed/fleet_executor/fleet_executor.h

@@ -18,6 +18,7 @@
 #include "paddle/fluid/distributed/fleet_executor/carrier.h"
 #include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
+#include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/platform/macros.h"
 #include "paddle/fluid/platform/place.h"
@@ -45,7 +46,8 @@ class FleetExecutor final {
       int64_t num_micro_batches,
       const std::vector<TaskNode*>& task_nodes,
       const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
-      const std::vector<std::string>& inference_root_scope_vars = {});
+      const std::vector<std::string>& inference_root_scope_vars = {},
+      const std::vector<framework::Scope*>& micro_scope_list = {});
   void Run(const std::string& carrier_id);
 
  private:
@@ -57,7 +59,8 @@ class FleetExecutor final {
       const platform::Place& place,
       int64_t num_micro_batches,
       const framework::ProgramDesc& program_desc,
-      const std::vector<std::string>& inference_root_scope_vars = {});
+      const std::vector<std::string>& inference_root_scope_vars = {},
+      const std::vector<framework::Scope*>& micro_scope_list = {});
   FleetExecutorDesc exe_desc_;
   std::shared_ptr<RuntimeGraph> runtime_graph_;
   std::unordered_set<std::string> carrier_ids_;
...
python/paddle/fluid/executor.py

@@ -2464,6 +2464,7 @@ class Executor:
         program=None,
         scope=None,
         fleet_opt=None,
+        micro_scope_list=[],
         with_standalone_executor=False,
     ):
         num_micro_batches = (
@@ -2532,6 +2533,7 @@ class Executor:
         fleet_opt['task_id_to_rank'] = task_id_to_rank
         place = core.Place()
         place.set_place(self.place)
+
         # NOTE: the last argument is used to force create some vars in root scope,
         # won't be used during train.
         self._fleet_executor.init(
@@ -2543,6 +2545,7 @@ class Executor:
             tasks,
             task_id_to_rank,
             [],
+            micro_scope_list,
         )
 
     def _run_using_fleet_executor(
@@ -2624,11 +2627,20 @@ class Executor:
             )
             fetch_task.set_program(fetch_program)
 
+        micro_scope_list = []
+        if (
+            "inference_generation" in fleet_opt
+            and fleet_opt["inference_generation"]
+        ):
+            for i in range(int(fleet_opt["num_micro_batches"])):
+                micro_scope_list.append(cached_scope.new_scope())
+
         self._prepare_fleet_executor_carrier(
             cache_key,
             program=cached_program,
             scope=cached_scope,
             fleet_opt=fleet_opt,
+            micro_scope_list=micro_scope_list,
             with_standalone_executor=with_standalone_executor,
         )
@@ -2653,6 +2665,18 @@ class Executor:
         self._fleet_executor.run(cache_key)
 
+        if "fetch_var" in fleet_opt:
+            # If we speed up generation in evaluation, we generate multiple
+            # queries at the same time. Each query runs in a separate scope so
+            # results do not mix; the final results therefore live in multiple
+            # scopes and must be fetched from each one.
+            result_list = []
+            for scope in micro_scope_list:
+                for var in fleet_opt["fetch_var"]:
+                    tensor = core.get_variable_tensor(scope, var)
+                    result_list.append(as_numpy(tensor))
+            return result_list
+
         if fetch_list:
             arr = cached_scope.find_var(fetch_var_name).get_fetch_list()
             tensors = arr._move_to_list()
...
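One consequence of the nested loop above is the layout of the return value: results are ordered scope-major, one entry per (micro scope, fetch variable) pair. A small self-contained illustration (plain numpy, not Paddle code; the zero arrays merely stand in for as_numpy(core.get_variable_tensor(...))):

import numpy as np

num_micro_batches, fetch_vars = 3, ["x"]
result_list = [
    np.zeros([1, 1])  # placeholder for one fetched tensor
    for _scope in range(num_micro_batches)
    for _var in fetch_vars
]
assert len(result_list) == num_micro_batches * len(fetch_vars)
# result_list[s * len(fetch_vars) + v] holds variable v fetched from scope s.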
python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py

@@ -14,6 +14,8 @@
 import unittest
 
+import numpy as np
+
 import paddle
 import paddle.fluid.core as core
 from paddle.distributed.fleet.fleet_executor_utils import TaskNode
@@ -21,13 +23,26 @@ from paddle.distributed.fleet.fleet_executor_utils import TaskNode
 paddle.enable_static()
 
 
-def cond(i, ten):
+def cond(i, ten, data):
     return i < ten
 
 
-def body(i, ten):
+def body(i, ten, data):
     i = i + 1
-    return [i, ten]
+    data = data + 1
+    return [i, ten, data]
+
+
+num_micro_batches = 3
+
+
+def batch_generator_creator():
+    def __reader__():
+        for i in range(num_micro_batches):
+            data = np.full(shape=[1, 1], fill_value=i, dtype=np.float32)
+            yield data
+
+    return __reader__
 
 
 class TestFleetExecutor(unittest.TestCase):
@@ -41,7 +56,16 @@ class TestFleetExecutor(unittest.TestCase):
             ten = paddle.full(
                 shape=[1], fill_value=10, dtype='int64'
             )  # loop length
-            i, ten = paddle.static.nn.while_loop(cond, body, [i, ten])
+            data = paddle.static.data(name='x', shape=[1])
+            loader = paddle.fluid.io.DataLoader.from_generator(
+                feed_list=[data],
+                capacity=num_micro_batches * 4,
+                iterable=False,
+            )
+            loader.set_batch_generator(
+                batch_generator_creator(), paddle.CUDAPlace(0)
+            )
+            paddle.static.nn.while_loop(cond, body, [i, ten, data])
 
         program_a = paddle.static.Program()
         program_b = paddle.static.Program()
@@ -49,6 +73,15 @@ class TestFleetExecutor(unittest.TestCase):
             for var_name in main_program.block(0).vars:
                 if var_name != "_generated_var_0":
                     var = main_program.block(0).var(var_name)
-                    program_a.block(0).create_var(
-                        name=var_name,
-                        shape=var.shape,
+                    if (
+                        var_name == "create_py_reader_0"
+                        or var_name == "double_buffer_0"
+                    ):
+                        program_a.block(0).create_var(
+                            name=var_name,
+                            persistable=var.persistable,
+                        )
+                    else:
+                        program_a.block(0).create_var(
+                            name=var_name,
+                            shape=var.shape,
...
@@ -89,7 +122,6 @@ class TestFleetExecutor(unittest.TestCase):
         )
 
         cond_var_name = "tmp_0"
-        num_micro_batches = 3
         task_a = TaskNode(
             0,
...
@@ -159,12 +191,19 @@ class TestFleetExecutor(unittest.TestCase):
                 task_e.task_id(): 0,
             },
             'num_micro_batches': num_micro_batches,
+            'inference_generation': True,
+            'fetch_var': ['x'],
         }
 
-        place = paddle.fluid.CUDAPlace(0)
-        exe = paddle.fluid.Executor(place)
-        exe.run(main_program)
+        place = paddle.CUDAPlace(0)
+        exe = paddle.static.Executor(place)
+
+        loader.start()
+        res = exe.run(main_program)
+        ref_res = np.full([1], 10, dtype="float32")
+        for data in res:
+            np.testing.assert_allclose(data, ref_res, rtol=1e-05)
+            ref_res = ref_res + 1
 
 
 if __name__ == "__main__":
...
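For reference, the expected values in the new assertion loop follow from the data pipeline: micro batch i feeds data = i, and body() increments data once per iteration of the while_loop. Assuming the loop counter starts at 0 (it is initialized in unchanged context outside these hunks), each batch runs 10 iterations, so the scope for batch i should yield i + 10; a quick standalone check of that arithmetic:

import numpy as np

ref_res = np.full([1], 10, dtype="float32")
for i in range(3):  # num_micro_batches
    np.testing.assert_allclose(ref_res, np.full([1], 10 + i, dtype="float32"))
    ref_res = ref_res + 1  # next micro batch started one higher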