未验证 提交 d83d59dd 编写于 作者: Y Yuang Liu 提交者: GitHub

[cuda graph] partial program with cuda graph under static mode (#43440)

上级 db58dd27
...@@ -59,9 +59,12 @@ ProgramDesc::ProgramDesc() { ...@@ -59,9 +59,12 @@ ProgramDesc::ProgramDesc() {
ProgramDesc::ProgramDesc(const ProgramDesc &o) { ProgramDesc::ProgramDesc(const ProgramDesc &o) {
desc_ = o.desc_; desc_ = o.desc_;
std::vector<framework::BlockDesc *> old_block_desc;
for (int i = 0; i < desc_.blocks_size(); ++i) { for (int i = 0; i < desc_.blocks_size(); ++i) {
auto *block = desc_.mutable_blocks(i); auto *block = desc_.mutable_blocks(i);
blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this)); blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this));
// record all block desc's ptr from origin program
old_block_desc.emplace_back(o.blocks_[i].get());
} }
for (size_t block_id = 0; block_id < blocks_.size(); ++block_id) { for (size_t block_id = 0; block_id < blocks_.size(); ++block_id) {
auto all_ops = blocks_[block_id]->AllOps(); auto all_ops = blocks_[block_id]->AllOps();
...@@ -70,9 +73,21 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { ...@@ -70,9 +73,21 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) {
for (const std::string &attr_name : op->AttrNames()) { for (const std::string &attr_name : op->AttrNames()) {
if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) { if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) {
framework::BlockDesc *block_desc =
BOOST_GET_CONST(framework::BlockDesc *, op->GetAttr(attr_name));
if (std::find(old_block_desc.begin(), old_block_desc.end(),
block_desc) != old_block_desc.end()) {
// The block is owned by the origin program. Just use id to get
// the corresponding block.
int sub_block_id = int sub_block_id =
o.Block(block_id).Op(op_id)->GetBlockAttrId(attr_name); o.Block(block_id).Op(op_id)->GetBlockAttrId(attr_name);
op->SetBlockAttr(attr_name, MutableBlock(sub_block_id)); op->SetBlockAttr(attr_name, MutableBlock(sub_block_id));
} else {
// The block is not owned by the origin program. Should copy
// the real block desc instead of logical block in the program.
VLOG(3) << "Set op's block attr with the original block";
op->SetBlockAttr(attr_name, block_desc);
}
} else if (op->GetAttrType(attr_name) == proto::AttrType::BLOCKS) { } else if (op->GetAttrType(attr_name) == proto::AttrType::BLOCKS) {
std::vector<int> sub_block_ids = std::vector<int> sub_block_ids =
o.Block(block_id).Op(op_id)->GetBlocksAttrIds(attr_name); o.Block(block_id).Op(op_id)->GetBlocksAttrIds(attr_name);
......
...@@ -257,7 +257,12 @@ class RunProgramOpKernel : public framework::OpKernel<T> { ...@@ -257,7 +257,12 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
auto input_var_names = ctx.InputNames("X"); auto input_var_names = ctx.InputNames("X");
auto output_var_names = ctx.OutputNames("Out"); auto output_var_names = ctx.OutputNames("Out");
auto dout_var_names = ctx.OutputNames("DOut"); std::vector<std::string> dout_var_names;
if (!dout_vars.empty()) {
// DOut is a dispensable out, only get the names when it exists.
// Otherwise, it will throw a NotFound error.
dout_var_names = ctx.OutputNames("DOut");
}
// current program may not hold parameters // current program may not hold parameters
std::vector<std::string> param_names; std::vector<std::string> param_names;
...@@ -272,10 +277,23 @@ class RunProgramOpKernel : public framework::OpKernel<T> { ...@@ -272,10 +277,23 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
// NOTE(chenweihang): In order not to add new variable type, use vector // NOTE(chenweihang): In order not to add new variable type, use vector
// here. Originally, here can use scope directly. // here. Originally, here can use scope directly.
auto *out_scope_vec = ctx.Output<StepScopeVar>("OutScope"); auto *out_scope_vec = ctx.Output<StepScopeVar>("OutScope");
std::unique_ptr<framework::Scope> inner_scope{nullptr};
if (out_scope_vec->size() == 0) {
// For cuda graph under static mode usage.
// For static mode, we cannot set value of a tensor before any run,
// the OutScope variable passed to the op actually contains nothing.
// Just create a tmp scope to run the program.
PADDLE_ENFORCE_EQ(
use_cuda_graph, true,
platform::errors::InvalidArgument(
"If not provide OutScope then must run under cuda graph mode."));
inner_scope = std::make_unique<framework::Scope>();
} else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
out_scope_vec->size(), 1, out_scope_vec->size(), 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The OutScope of RunProgramGradOp should only hold one scope.")); "The OutScope of RunProgramGradOp should only hold one scope."));
}
// Step 2. prepare executor and init persistable variables // Step 2. prepare executor and init persistable variables
...@@ -284,9 +302,10 @@ class RunProgramOpKernel : public framework::OpKernel<T> { ...@@ -284,9 +302,10 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
// Learning. Tensor data in multi-step training should be saved into single // Learning. Tensor data in multi-step training should be saved into single
// scope separately. Otherwise, the gradients can be miscalculated because // scope separately. Otherwise, the gradients can be miscalculated because
// always using the Tensor data of the last step in forward. // always using the Tensor data of the last step in forward.
framework::Scope *global_inner_scope = out_scope_vec->front(); framework::Scope *global_inner_scope =
out_scope_vec->size() == 0 ? inner_scope.get() : out_scope_vec->front();
VLOG(2) << "The number of sub scopes before forward: " VLOG(2) << "The number of sub scopes before forward: "
<< out_scope_vec->front()->kids().size(); << global_inner_scope->kids().size();
framework::Scope &scope = global_inner_scope->NewScope(); framework::Scope &scope = global_inner_scope->NewScope();
// share input_vars & parameters into scope // share input_vars & parameters into scope
...@@ -341,13 +360,19 @@ class RunProgramOpKernel : public framework::OpKernel<T> { ...@@ -341,13 +360,19 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
&scope); &scope);
// Debug info: scope info when run end // Debug info: scope info when run end
VLOG(3) << framework::GenScopeTreeDebugInfo(out_scope_vec->front()); framework::Scope *target_scope{nullptr};
if (out_scope_vec->size() == 0) {
target_scope = inner_scope.get();
} else {
target_scope = out_scope_vec->front();
}
VLOG(3) << framework::GenScopeTreeDebugInfo(target_scope);
// Step 5. Drop all children scopes while testing. // Step 5. Drop all children scopes while testing.
if (is_test) { if (is_test) {
out_scope_vec->front()->DropKids(); target_scope->DropKids();
} }
VLOG(2) << "The number of sub scopes after forward: " VLOG(2) << "The number of sub scopes after forward: "
<< out_scope_vec->front()->kids().size(); << target_scope->kids().size();
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) platform::DontClearMKLDNNCache(ctx.GetPlace()); if (FLAGS_use_mkldnn) platform::DontClearMKLDNNCache(ctx.GetPlace());
#endif #endif
......
...@@ -28,6 +28,16 @@ void BeginCUDAGraphCapture(platform::CUDAPlace place, ...@@ -28,6 +28,16 @@ void BeginCUDAGraphCapture(platform::CUDAPlace place,
auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place); auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
dev_ctx->cudnn_workspace_handle().ResetWorkspace(); dev_ctx->cudnn_workspace_handle().ResetWorkspace();
// After PR(#43206), cudnn related initializations will change to lazy mode.
// It will only be initialized when op calls them. But cuda graph not support
// capture such kind of init, need to init all these handle before cuda graph.
dev_ctx->cublas_handle();
#if CUDA_VERSION >= 11060
dev_ctx->cublaslt_handle();
#endif
dev_ctx->cudnn_handle();
dev_ctx->cusolver_dn_handle();
auto stream = dev_ctx->stream(); auto stream = dev_ctx->stream();
CUDAGraph::BeginCapture(place, stream, mode); CUDAGraph::BeginCapture(place, stream, mode);
......
...@@ -14,7 +14,10 @@ ...@@ -14,7 +14,10 @@
import os import os
import paddle import paddle
from paddle.fluid import core
from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDAPlace from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDAPlace
import warnings
if is_compiled_with_cuda() and not is_compiled_with_rocm(): if is_compiled_with_cuda() and not is_compiled_with_rocm():
from paddle.fluid.core import CUDAGraph as CoreCUDAGraph from paddle.fluid.core import CUDAGraph as CoreCUDAGraph
...@@ -106,3 +109,335 @@ def wrap_cuda_graph(function, mode="thread_local", memory_pool="default"): ...@@ -106,3 +109,335 @@ def wrap_cuda_graph(function, mode="thread_local", memory_pool="default"):
else: else:
mock_func._cuda_graph_pool_id = memory_pool._cuda_graph_pool_id mock_func._cuda_graph_pool_id = memory_pool._cuda_graph_pool_id
return new_function return new_function
def copy_var_desc(dst, src):
"""
copy var desc from src to dst
:param dst: framework.VarDesc(cpp), dst var desc, cpp VarDesc instance
:param src: framework.VarDesc(cpp), src var desc, cpp VarDesc instance
:return: no return
"""
dst.set_shape(src.shape)
dst.set_dtype(src.dtype)
dst.set_lod_level(src.lod_level)
dst.set_type(src.type)
dst.set_persistable(src.persistable)
dst.set_is_parameter(src.is_parameter)
dst.set_stop_gradient(src.stop_gradient)
def all_inputs_of_later_op(block, begin_idx):
"""
find all inputs of ops after an idx, used to determine the logical output of a cuda graph section
:param block: framework.Block, the original block
:param begin_idx: int, from which idx (not include) to find the later ins
:return: a list of inputs names for all ops behind begin_idx
"""
ins = []
for idx, op in enumerate(block.ops):
if idx <= begin_idx:
continue
for in_name in op.input_arg_names:
ins.append(in_name)
return list(set(ins))
def construct_program_and_find_ins_outs(section, origin_program, section_idx):
"""
1. Construct a new program for corresponding section
2. Find all the logical inputs and outputs of a program section
:param section: list, one cuda graph section, list of ops
:param origin_program: framework.Program, origin program
:param section_idx: list, the section ops' idx corresponding to the cuda graph section, a list of idx
:return: a new program for the cuda graph section
the logical ins and outs of the cuda graph section
"""
program = paddle.static.Program()
block = program.global_block()
origin_block = origin_program.global_block()
ins = []
outs = []
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
later_ins = all_inputs_of_later_op(origin_block, section_idx[-1])
for op in section:
for in_name in op.input_arg_names:
var = origin_block.var(in_name)
new_var_desc = block.desc.var(var.name.encode("ascii"))
copy_var_desc(new_var_desc, var)
if outs.count(in_name) == 0 and ins.count(in_name) == 0:
# This in var is generated from op outside this section
# Only record once for same input
ins.append(in_name)
elif later_ins.count(in_name) == 0:
# this is var is generated from op inside this section, and only will be used inside this section
outs.remove(in_name)
for out_name in op.output_arg_names:
var = origin_block.var(out_name)
new_var_desc = block.desc.var(var.name.encode("ascii"))
copy_var_desc(new_var_desc, var)
# for every output, we add it to the section's outs
if outs.count(out_name) == 0:
# Only record one out var even if it will be generated by multi ops.
# For scenario like this:
# A = op1(a)
# A = op2(b)
# B = op3(A)
outs.append(out_name)
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(op.desc)
new_op_desc._set_attr(op_role_attr_name, op.attr(op_role_attr_name))
program._sync_with_cpp()
return program, [ins, outs]
def get_cuda_graph_sections(program):
"""
get all sections that should run under cuda graph and the corresponding idx
:param program: framework.Program, the original program
:return: A list of cuda graph sections and the corresponding ops' idx in the block.
The program is under is test or not.
"""
block = program.global_block()
cuda_graph_sections = [] # record all ops in every cuda graph sections
sections_idx = [] # idx of all ops in every cuda graph sections
is_test = False # will be set to True is any op's 'is_test' attr is True
# ops and it's idx between cuda graph wrapped op, may belong to a section
internal_section = []
internal_idx = []
current_section = [] # current recording cuda graph sections
current_idx = [] # current recording cuda graph ops' idx
current_cuda_graph_id = -1 # current recording cuda graph id
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
loss_op_role = int(core.op_proto_and_checker_maker.OpRole.Loss)
backward_op_role = int(core.op_proto_and_checker_maker.OpRole.Backward)
loss_grad_op_role = loss_op_role | backward_op_role
for idx, op in enumerate(block.ops):
if op.type == 'conditional_block' or op.type == 'while':
assert op._cuda_graph_attr is None, "Cuda graph not support conditional block op and while op."
if op.has_attr('is_test') and op.attr('is_test'):
is_test = True
# find cuda graph sections
if op._cuda_graph_attr is not None:
assert isinstance(op._cuda_graph_attr,
str), "cuda_graph_attr should be a str"
cuda_graph_attrs = op._cuda_graph_attr.split(';')
assert len(cuda_graph_attrs) == 3, "cuda graph attr should have three fields: " \
"cuda graph mode, cuda graph memory pool id, cuda graph id"
local_cuda_graph_id = int(cuda_graph_attrs[2])
if local_cuda_graph_id == current_cuda_graph_id:
if len(internal_section) > 0:
assert len(internal_section) == len(
internal_idx
), "len of internal section should be equal with len of internal idx"
for internal_op in internal_section:
loss_related = (int(internal_op.attr(op_role_attr_name))
== loss_op_role) or int(
(internal_op.attr(op_role_attr_name)
) == loss_grad_op_role)
sub_block_related = (op.type == 'conditional_block'
or op.type == 'while')
if loss_related or sub_block_related:
# if loss_related is True
# The internal section contains loss related ops,
# although these ops are between two cuda graph sections with same graph id,
# they belong to none of these two sections.
# The loss related op should be wrapped by user explicitly.
# if sub_block_related is True
# The internal section contains while op or conditional block op.
# These two ops are not supported by cuda graph. Won't extend the section.
internal_section = []
internal_idx = []
# Beside clear the internal section, a new cuda graph section should be recorded
assert len(current_section) == len(current_idx), \
"num of section's op is not equal with the idx"
if len(current_section) > 0:
# store previous section
cuda_graph_sections.append(current_section)
sections_idx.append(current_idx)
current_section = []
current_idx = []
break
# some ops inserted by some optimizer, should be added to current section
for i in range(len(internal_section)):
current_section.append(internal_section[i])
current_idx.append(internal_idx[i])
internal_section = []
current_section.append(op)
current_idx.append(idx)
else:
# current graph id is different with previous, start a new section of cuda graph
# internal ops and idx belong to no section, just clear it
internal_section = []
internal_idx = []
current_cuda_graph_id = local_cuda_graph_id # start record a new section
assert len(current_section) == len(
current_idx
), "num of section's op is not equal with num of idx"
if len(current_section) > 0:
# store previous section
cuda_graph_sections.append(current_section)
sections_idx.append(current_idx)
current_section = [op]
current_idx = [idx]
else:
# recode ops which cuda_graph_attr is None, may belong to a section
internal_section.append(op)
internal_idx.append(idx)
# handle the last section
assert len(current_section) == len(
current_idx), "num of section's op is not equal with num of idx"
if len(current_section) > 0:
# store previous section
cuda_graph_sections.append(current_section)
sections_idx.append(current_idx)
return cuda_graph_sections, sections_idx, is_test
def replace_cuda_graph_section(ins_and_outs, section_program, section_idx,
origin_program, cuda_graph_section, order,
is_test):
"""
Use section_program and ins_and_outs to initialize a run_program_op,
and replace the section_idx marks ops in the origin program.
:param ins_and_outs: list, the logical ins and outs of the section program
:param section_program: framework.Program, the partial program need to run under cuda graph
:param section_idx: list, the idx need to be removed from origin program
:param origin_program: framework.Program, the origin program
:param cuda_graph_section: list, the ops in current sections, used to get the mode, memory pool id and is_test
:param order: int, the order of current section, used to create unique cuda graph var
:param is_test: bool, the program is running under is_test or not
:return: no return
"""
ins = ins_and_outs[0]
outs = ins_and_outs[1]
insert_idx = section_idx[0]
origin_block = origin_program.global_block()
for idx in reversed(section_idx):
# remove all cuda graph marked ops from origin block
origin_block._remove_op(idx, sync=False)
mode = None
memory_pool_id = None
for op in cuda_graph_section:
# find the cuda graph mode and memory pool id, determine is test or not
if op._cuda_graph_attr is not None:
attrs = op._cuda_graph_attr.split(';')
mode = attrs[0]
memory_pool_id = int(attrs[1])
break
assert mode is not None and memory_pool_id is not None, \
"mode and memory pool id should be specified in cuda graph attr"
cuda_graph_var = origin_block.create_var(
name="cuda_graph_" + str(order),
type=core.VarDesc.VarType.RAW,
persistable=True,
stop_gradient=True,
)
# not used for the run_program_op, just needed by the op, but won't be used
out_scope_var = origin_block.create_var(
name="program_out_scope_" + str(order),
type=core.VarDesc.VarType.STEP_SCOPES,
persistable=True,
stop_gradient=True,
)
program_id = _hash_with_id(section_program, ins_and_outs)
# insert the run_program_op into the block
origin_block._insert_op(insert_idx,
type='run_program',
inputs={'X': ins},
outputs={
'Out': outs,
'OutScope': out_scope_var,
'CUDAGraph': cuda_graph_var
},
attrs={
'global_block':
section_program.global_block(),
'start_op_index':
0,
'end_op_index':
len(section_program.global_block().ops),
'is_test':
is_test,
'program_id':
program_id,
'cuda_graph_capture_mode':
mode,
'cuda_graph_pool_id':
memory_pool_id,
})
def cuda_graph_transform(program):
"""
replace the ops marked with cuda_graph_attr to run_program_op to use cuda graph
:param program: framework.Program, the program to be transformed
:return: the cuda graph section program, user should hold these programs!
"""
if len(program.blocks) > 1:
# some sub blocks may be inserted by optimizer but will not use during training, just warn here
warnings.warn(
"Sub block(s) has been detected in the program. "
"Cuda graph not support op with sub block, and it will only handle the global block."
)
# step 1: get all cuda graph sections.
# A cuda graph section contains all ops marked with same cuda graph id and
# some ops inserted by some optimizers (amp, sharding for example) between ops with same id.
cuda_graph_sections, sections_idx, is_test = get_cuda_graph_sections(
program)
assert len(cuda_graph_sections) == len(sections_idx), \
"num of cuda graph sections is not equal with num of idx sections"
# step 2: construct new program for each section and find inputs and outputs of each section.
# The inputs are variables generated outside the section but will be used by this section.
# The outputs are variables generated by this section and will be used after the end of the section.
ins_and_outs = []
section_programs = []
for i in range(len(cuda_graph_sections)):
# creating new program for current section
section_program, ins_outs = construct_program_and_find_ins_outs(
cuda_graph_sections[i], program, sections_idx[i])
ins_and_outs.append(ins_outs)
section_programs.append(section_program)
assert len(section_programs) == len(cuda_graph_sections), \
"the num of cuda graph sections should be equal with the num of new program"
# step 3: replace the ops in original program with run_program_op.
# Will remove all ops in the section from origin program, and use run_program_op to replace them.
for i in reversed(range(len(cuda_graph_sections))):
# carry out the replacement in reversed order, to keep the previous idx intact
replace_cuda_graph_section(ins_and_outs[i],
section_programs[i],
sections_idx[i],
program,
cuda_graph_sections[i],
order=i,
is_test=is_test)
# NOTE: user should hold these program, for now just return these program back to caller
return section_programs
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import unittest
import numpy as np
from paddle.device.cuda.graphs import wrap_cuda_graph, is_cuda_graph_supported, cuda_graph_transform
paddle.enable_static()
class SimpleModel(nn.Layer):
def __init__(self, in_size, out_size):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(in_size, out_size)
self.dropout_1 = paddle.nn.Dropout(0.1)
self.relu = nn.ReLU()
self.dropout_2 = paddle.nn.Dropout(0.5)
self.gelu = nn.GELU()
def forward(self, x):
x = self.linear(x)
x = self.dropout_1(x)
x = self.relu(x)
x = self.dropout_2(x)
x = self.gelu(x)
return x
class TestCudaGraphAttrAll(unittest.TestCase):
def setUp(self):
paddle.set_flags({'FLAGS_eager_delete_tensor_gb': 0.0})
def get_model(self, use_cuda_graph=False):
x = paddle.static.data(shape=[3, 10], dtype='float32', name='x')
model_start = SimpleModel(10, 20)
if use_cuda_graph:
model_start = wrap_cuda_graph(model_start)
model_inter = SimpleModel(20, 20)
model_end = SimpleModel(20, 10)
if use_cuda_graph:
model_end = wrap_cuda_graph(model_end, memory_pool='new')
start_out = model_start(x)
inter_out = model_inter(start_out)
end_out = model_end(inter_out)
loss = paddle.mean(end_out)
opt = paddle.optimizer.SGD()
opt.minimize(loss)
return loss
def run_with_cuda_graph(self, x_data):
# run with cuda graph
paddle.seed(1024)
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
loss = self.get_model(use_cuda_graph=True)
section_programs = cuda_graph_transform(main_prog)
assert len(section_programs) == 4
block = main_prog.global_block()
run_program_op_num = 0
for op in block.ops:
if op.type == 'run_program':
run_program_op_num += 1
assert run_program_op_num == 4
exe = paddle.static.Executor(paddle.CUDAPlace(0))
exe.run(start_prog)
for i in range(10):
rst = exe.run(main_prog, feed={'x': x_data}, fetch_list=[loss])
return rst
def normal_run(self, x_data):
# run without cuda graph
paddle.seed(1024)
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
loss = self.get_model()
exe = paddle.static.Executor(paddle.CUDAPlace(0))
exe.run(start_prog)
for i in range(10):
rst = exe.run(main_prog, feed={'x': x_data}, fetch_list=[loss])
return rst
def test_static_mode_cuda_graph(self):
if not is_cuda_graph_supported():
return
x_data = np.random.random((3, 10)).astype('float32')
cuda_graph_rst = self.run_with_cuda_graph(x_data)
normal_run_rst = self.normal_run(x_data)
assert np.array_equal(cuda_graph_rst, normal_run_rst)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册