未验证 提交 146e942c 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #15250 from tensor-tang/refine/seqpool/feed

Refine/seqpool/feed with infer zerocopytensor
......@@ -69,6 +69,7 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
if (WITH_MKLDNN)
cc_test(test_depthwise_conv_mkldnn_pass SRCS depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
......
......@@ -39,21 +39,25 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
auto is_seqpool_op_with_pootype_of_nth_input_of_concat = [=](
Node* x, const std::string& type, int idx) -> bool {
bool ok = x && x->IsOp() && x->Op()->Type() == "sequence_pool" &&
x->Op()->HasAttr("pooltype") &&
boost::get<std::string>(x->Op()->GetAttr("pooltype")) == type &&
x->outputs.size() == 2; // seqpool should only have 2 outputs
if (ok) {
// only one output of seqpool_op is nth_input_var of concat
// the other one should be unused empty var
bool this_is_seqpool_op =
x && x->IsOp() && x->Op()->Type() == "sequence_pool" &&
x->Op()->HasAttr("pooltype") &&
boost::get<std::string>(x->Op()->GetAttr("pooltype")) == type &&
x->outputs.size() == 2; // seqpool should only have 2 outputs
bool satisfied_all = this_is_seqpool_op;
if (this_is_seqpool_op) {
// Only one output of seqpool_op is nth_input_var of concat,
// the other one should be unused empty var.
if (is_nth_input_var_of_concat(x->outputs[0], idx)) {
ok = ok && x->outputs[1]->IsVar() && x->outputs[1]->outputs.size() == 0;
satisfied_all = satisfied_all && x->outputs[1]->IsVar() &&
x->outputs[1]->outputs.size() == 0;
} else {
ok = ok && is_nth_input_var_of_concat(x->outputs[1], idx) &&
x->outputs[0]->IsVar() && x->outputs[0]->outputs.size() == 0;
satisfied_all =
satisfied_all && is_nth_input_var_of_concat(x->outputs[1], idx) &&
x->outputs[0]->IsVar() && x->outputs[0]->outputs.size() == 0;
}
}
return ok;
return satisfied_all;
};
auto* concat_op = pattern->NewNode(
......@@ -72,6 +76,7 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
std::vector<PDNode*> seqpool_ops_input_var(num_inputs);
std::vector<PDNode*> seqpool_ops_output_var(num_inputs);
std::vector<PDNode*> seqpool_ops_output_unused_var(num_inputs);
std::vector<PDNode*> seqpool_ops(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
......@@ -84,6 +89,15 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
},
name_scope + "/sequence_pool_out_" + std::to_string(i));
seqpool_ops_output_unused_var[i] = pattern->NewNode(
[=](Node* x) {
return x && x->IsVar() && x->inputs.size() == 1 &&
x->outputs.size() == 0 &&
is_seqpool_op_with_pootype_of_nth_input_of_concat(x->inputs[0],
"SUM", i);
},
name_scope + "/sequence_pool_unused_out_" + std::to_string(i));
seqpool_ops[i] = pattern->NewNode(
[=](Node* x) {
return x && x->IsOp() &&
......@@ -93,23 +107,29 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
seqpool_ops_input_var[i] = pattern->NewNode(
[=](Node* x) {
return x && x->IsVar() && x->outputs.size() >= 1 &&
is_seqpool_op_with_pootype_of_nth_input_of_concat(
x->outputs[0], "SUM", i);
bool basic = x && x->IsVar() && x->outputs.size() >= 1;
bool next_is_fine = false;
for (auto* o : x->outputs) {
if (is_seqpool_op_with_pootype_of_nth_input_of_concat(o, "SUM",
i)) {
next_is_fine = true;
break;
}
}
return basic && next_is_fine;
},
name_scope + "/sequence_pool_in_" + std::to_string(i));
// Links
seqpool_ops[i]
->LinksFrom({seqpool_ops_input_var[i]})
.LinksTo({seqpool_ops_output_var[i]});
.LinksTo({seqpool_ops_output_var[i], seqpool_ops_output_unused_var[i]});
}
concat_op->LinksFrom(seqpool_ops_output_var).LinksTo({concat_out_var});
return concat_out_var;
}
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
int num_inputs) {
int BuildFusion(Graph* graph, const std::string& name_scope, int num_inputs) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
BuildSeqPoolConcatPattern(pattern, name_scope, num_inputs);
......@@ -178,8 +198,8 @@ std::unique_ptr<ir::Graph> SeqPoolConcatFusePass::ApplyImpl(
FusePassBase::Init(name_scope_, graph.get());
int fusion_count = 0;
for (int i = MAX_CONCAT_INPUTS; i > 0; --i) {
fusion_count += BuildFusion(
graph.get(), name_scope_ + "/" + std::to_string(i), param_scope(), i);
fusion_count +=
BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(i), i);
}
AddStatis(fusion_count);
......
......@@ -23,6 +23,20 @@ namespace paddle {
namespace framework {
namespace ir {
/**
* Fuse SequencePool(with sum pooltype yet) and Concat;
*
* Before fuse:
* | | |
* seq_pool, seq_pool, ... seq_pool
* \ | ... /
* concat
* |
* After fuse:
* \ | /
* FusionSeqPoolConcat
* |
*/
class SeqPoolConcatFusePass : public FusePassBase {
public:
virtual ~SeqPoolConcatFusePass() {}
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
if (type == "sequence_pool") {
op->SetInput("X", {inputs[0]});
std::string pooltype = "SUM";
op->SetAttr("pooltype", pooltype);
op->SetOutput("MaxIndex", {outputs[0]});
op->SetOutput("Out", {outputs[1]});
} else if (type == "concat") {
op->SetInput("X", inputs);
op->SetAttr("axis", 1);
op->SetOutput("Out", {outputs[0]});
} else {
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
}
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
int CountOpType(const ir::Graph* graph,
const std::string& op_type = "fusion_seqpool_concat") {
int count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == op_type) {
++count;
}
}
return count;
}
std::unique_ptr<ir::Graph> GetNumNodesOfBeforeAfter(
std::unique_ptr<ir::Graph> graph, int* before, int* after,
const std::string& pass_type = "seqpool_concat_fuse_pass") {
auto pass = PassRegistry::Instance().Get(pass_type);
*before = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
*after = graph->Nodes().size();
return graph;
}
/*
* Before fuse:
* a b c
* | | |
* op1 op2 op3
* / \ / \ / \
* d e f g h i
* \ | /
* concat
* |
* j
* Type of op1, op2 and op3 are sequence_pool, with "SUM" pooltype attr
*
* After fuse:
* a b c
* \ | /
* fusion_seqpool_concat
* |
* j
*/
TEST(SeqPoolConcatFusePass, basic) {
ProgramDesc prog;
for (auto& v : std::vector<std::string>(
{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR);
}
SetOp(&prog, "sequence_pool", std::vector<std::string>({"a"}),
std::vector<std::string>({"d", "e"}));
SetOp(&prog, "sequence_pool", std::vector<std::string>({"b"}),
std::vector<std::string>({"f", "g"}));
SetOp(&prog, "sequence_pool", std::vector<std::string>({"c"}),
std::vector<std::string>({"h", "i"}));
SetOp(&prog, "concat", std::vector<std::string>({"e", "g", "i"}),
std::vector<std::string>({"j"}));
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int before, after;
graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after);
// Remove 10 Nodes: op1, op2, op3, d, e, f, g, h, i, concat_op
// Add 1 Node: fusion_seqpool_concat
EXPECT_EQ(after, before - 9);
EXPECT_EQ(CountOpType(graph.get()), 1);
}
/*
* Before fuse:
* a b
* | / \
* op1 op2 op3
* / \ / \ \
* c d e f g
* \ /
* concat
* |
* h
* Type of op1 and op2 are sequence_pool, with "SUM" pooltype attr
*
* After fuse:
* a b
* \ / \
* fusion_seqpool_concat op3
* | |
* h g
*/
TEST(SeqPoolConcatFusePass, advanced) {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "c", "d", "e", "f", "g", "h"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR);
}
SetOp(&prog, "sequence_pool", std::vector<std::string>({"a"}),
std::vector<std::string>({"c", "d"}));
SetOp(&prog, "sequence_pool", std::vector<std::string>({"b"}),
std::vector<std::string>({"e", "f"}));
SetOp(&prog, "op3", std::vector<std::string>({"b"}),
std::vector<std::string>({"g"}));
SetOp(&prog, "concat", std::vector<std::string>({"d", "f"}),
std::vector<std::string>({"h"}));
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int before, after;
graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after);
// Remove 7 Nodes: op1, op2, c, d, e, f concat_op
// Add 1 Node: fusion_seqpool_concat
EXPECT_EQ(after, before - 6);
EXPECT_EQ(CountOpType(graph.get()), 1);
}
ProgramDesc BuildProgramDesc(int num_inputs_of_concat) {
ProgramDesc prog;
auto new_var = [&](const std::string& name) {
auto* var = prog.MutableBlock(0)->Var(name);
var->SetType(proto::VarType::LOD_TENSOR);
};
std::vector<std::string> concat_inputs;
for (int i = 0; i < num_inputs_of_concat; ++i) {
std::string prefix = "seqpool_op_" + i;
new_var(prefix + "in");
new_var(prefix + "out");
new_var(prefix + "out_unused");
SetOp(&prog, "sequence_pool", std::vector<std::string>({prefix + "in"}),
std::vector<std::string>({prefix + "out", prefix + "out_unused"}));
concat_inputs.push_back(prefix + "out");
}
SetOp(&prog, "concat", concat_inputs,
std::vector<std::string>({"concat_out"}));
return prog;
}
// test more inputs of concat
TEST(SeqPoolConcatFusePass, more_inputs) {
for (int num : {1, 2, 10}) {
ProgramDesc prog = BuildProgramDesc(num);
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int before, after;
graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after);
// Remove Nodes: n * (seqpool_op, out, out_unused), and concat_op
// Add Node: fusion_seqpool_concat op
EXPECT_EQ(after, before - num * 3);
EXPECT_EQ(CountOpType(graph.get()), 1);
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(seqpool_concat_fuse_pass);
......@@ -204,11 +204,14 @@ static std::string DescribeTensor(const PaddleTensor &tensor) {
os << to_string(l) << "; ";
}
os << "\n";
os << " - data: ";
os << " - memory length: " << tensor.data.length();
os << "\n";
os << " - data: ";
int dim = VecReduceToInt(tensor.shape);
float *pdata = static_cast<float *>(tensor.data.data());
for (int i = 0; i < dim; i++) {
os << static_cast<float *>(tensor.data.data())[i] << " ";
os << pdata[i] << " ";
}
os << '\n';
return os.str();
......@@ -224,10 +227,12 @@ static std::string DescribeZeroCopyTensor(const ZeroCopyTensor &tensor) {
os << to_string(l) << "; ";
}
os << "\n";
os << " - data: ";
PaddlePlace place;
int size;
const auto *data = tensor.data<float>(&place, &size);
os << " - numel: " << size;
os << "\n";
os << " - data: ";
for (int i = 0; i < size; i++) {
os << data[i] << " ";
}
......
......@@ -123,7 +123,8 @@ class ZeroCopyTensor {
*/
template <typename T>
T* mutable_data(PaddlePlace place);
/** Get the memory directly, will return the place and memory size by pointer.
/** Get the memory directly, will return the place and element size by
* pointer.
* This is for reading the output tensor.
*/
template <typename T>
......
......@@ -351,10 +351,10 @@ TEST(Analyzer_rnn1, ZeroCopy) {
ASSERT_TRUE(native_predictor->Run(native_inputs.front(), &native_outputs));
LOG(INFO) << "native output " << DescribeTensor(native_outputs.front());
int output_size{0};
int output_size{0}; // this is the number of elements not memory size
auto *zero_copy_data = output_tensor->data<float>(&place, &output_size);
auto *native_data = static_cast<float *>(native_outputs.front().data.data());
for (size_t i = 0; i < output_size / sizeof(float); i++) {
for (int i = 0; i < output_size; i++) {
EXPECT_NEAR(zero_copy_data[i], native_data[i], 1e-3);
}
}
......
......@@ -121,14 +121,6 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data) {
}
}
void SetConfig(AnalysisConfig *cfg) {
cfg->SetModel(FLAGS_infer_model + "/model", FLAGS_infer_model + "/params");
cfg->DisableGpu();
cfg->SwitchSpecifyInputNames();
cfg->pass_builder()->TurnOnDebug();
cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
}
void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
std::vector<PaddleTensor> input_slots;
......@@ -141,15 +133,22 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
}
}
void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
cfg->SetModel(FLAGS_infer_model + "/model", FLAGS_infer_model + "/params");
cfg->DisableGpu();
cfg->SwitchSpecifyInputNames();
cfg->pass_builder()->TurnOnDebug();
cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
if (use_mkldnn) {
cfg->EnableMKLDNN();
}
}
void profile(bool use_mkldnn = false) {
AnalysisConfig cfg;
SetConfig(&cfg);
SetConfig(&cfg, use_mkldnn);
if (use_mkldnn) {
cfg.EnableMKLDNN();
}
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
......@@ -169,22 +168,112 @@ TEST(Analyzer_seq_pool1, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
// Check the fuse status
TEST(Analyzer_seq_pool1, fuse_statis) {
// Compare Deterministic result
TEST(Analyzer_seq_pool1, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
void analysis_fuse_statis(bool use_zerocopy) {
AnalysisConfig cfg;
SetConfig(&cfg);
cfg.SwitchUseFeedFetchOps(!use_zerocopy);
int num_ops;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
auto fuse_statis = GetFuseStatis(
static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
ASSERT_EQ(fuse_statis.at("fc_fuse"), 10);
ASSERT_TRUE(fuse_statis.count("seqpool_concat_fuse"));
EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2);
LOG(INFO) << "num_ops: " << num_ops;
EXPECT_EQ(num_ops, 195);
}
// Check the fuse status
TEST(Analyzer_seq_pool1, fuse_statis) { analysis_fuse_statis(false); }
void PrepareZeroCopyInputs(
const std::unique_ptr<PaddlePredictor> &predictor,
std::vector<std::unique_ptr<ZeroCopyTensor>> *inputs) {
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
// only feed one batch
const auto &one_batch = data.NextBatch();
inputs->clear();
for (size_t i = 0; i < one_batch.size(); ++i) {
auto &slot = one_batch[i];
auto tensor = predictor->GetInputTensor(slot.name + "_embed");
tensor->Reshape(slot.shape);
tensor->SetLoD({slot.lod});
ZeroCopyTensorAssignData<float>(tensor.get(), slot.data);
inputs->emplace_back(std::move(tensor));
}
}
// return the output values
std::vector<float> zerocopy_profile(int repeat_times) {
AnalysisConfig config;
SetConfig(&config);
config.SwitchUseFeedFetchOps(false);
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
std::vector<std::unique_ptr<ZeroCopyTensor>> inputs;
PrepareZeroCopyInputs(predictor, &inputs);
auto output_tensor = predictor->GetOutputTensor("reduce_sum_0.tmp_0");
Timer timer;
LOG(INFO) << "Warm up run...";
timer.tic();
predictor->ZeroCopyRun();
PrintTime(FLAGS_batch_size, 1, 1, 0, timer.toc(), 1);
if (FLAGS_profile) {
paddle::platform::ResetProfiler();
}
LOG(INFO) << "Run " << repeat_times << " times...";
timer.tic();
for (int i = 0; i < repeat_times; i++) {
predictor->ZeroCopyRun();
}
PrintTime(FLAGS_batch_size, repeat_times, 1, 0, timer.toc() / repeat_times,
1);
VLOG(3) << "ZeroCopy output: " << DescribeZeroCopyTensor(*output_tensor);
PaddlePlace place;
int output_size{0};
auto *pdata = output_tensor->data<float>(&place, &output_size);
std::vector<float> res(output_size);
for (int i = 0; i < output_size; ++i) {
res[i] = pdata[i];
}
return res;
}
TEST(Analyzer_seq_pool1, zerocopy_profile) { zerocopy_profile(FLAGS_repeat); }
TEST(Analyzer_seq_pool1, zerocopy_fuse_statis) { analysis_fuse_statis(true); }
TEST(Analyzer_seq_pool1, zerocopy_compare_native) {
AnalysisConfig config;
SetConfig(&config);
config.SwitchUseFeedFetchOps(true);
auto predictor = CreatePaddlePredictor<NativeConfig>(config.ToNativeConfig());
std::vector<PaddleTensor> native_outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
ASSERT_TRUE(predictor->Run(input_slots_all[0], &native_outputs));
EXPECT_EQ(native_outputs.size(), 1UL);
auto zerocopy_output = zerocopy_profile(1);
EXPECT_EQ(zerocopy_output.size() * sizeof(float),
native_outputs.front().data.length());
auto *native_data = static_cast<float *>(native_outputs.front().data.data());
for (size_t i = 0; i < zerocopy_output.size(); ++i) {
EXPECT_NEAR(zerocopy_output[i], native_data[i], 1e-3);
}
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -23,7 +23,7 @@ namespace operators {
void FusionSeqPoolConcatOp::InferShape(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
"Inputs(X) of FusionSeqPoolConcatOp should be empty.");
"Inputs(X) of FusionSeqPoolConcatOp should not be empty.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FusionSeqPoolConcatOp should not be null.");
int axis = ctx->Attrs().Get<int>("axis");
......@@ -54,12 +54,13 @@ void FusionSeqPoolConcatOpMaker::Make() {
AddInput("X", "(LoDTensor) Input tensors of this operator.").AsDuplicable();
AddOutput("Out", "(LoDTensor) Output tensor of concat operator.");
AddAttr<std::string>("pooltype",
"(string, default 'AVERAGE') some of the pooling "
"(string, default 'SUM') some of the pooling "
"pooltype of SequencePoolOp.")
.SetDefault("SUM")
.InEnum({"AVERAGE", "SUM", "SQRT"});
AddAttr<int>("axis",
"The axis along which the input tensors will be concatenated.")
"The axis along which the input tensors will be concatenated. "
"Only supports concat axis=1 yet.")
.SetDefault(1);
AddComment(R"DOC(
Fusion Sequence Pool of pooltype(sum, average and sqrt) and Concat Operator.
......@@ -100,6 +101,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
attr);
size_t n = ins.size();
size_t dst_step_size = n * w;
for (size_t i = 0; i < n; ++i) {
auto x_dims = ins[i]->dims();
auto x_lod = ins[i]->lod()[0];
......@@ -112,7 +114,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
for (size_t j = 0; j < bs; ++j) {
attr.h = static_cast<int>(x_lod[j + 1] - x_lod[j]);
seqpool(src, dst, &attr);
dst += n * w;
dst += dst_step_size;
src += attr.h * attr.w;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册