未验证 提交 8b97903d 编写于 作者: Z zhxfl 提交者: GitHub

Merge pull request #616 from zhxfl/fix-610

first version of data load
"""This model read the sample from disk.
use multiprocessing to reading samples
push samples from one block to multiprocessing queue
Todos:
1. multiprocess read block from disk
"""
import random
import Queue
import numpy as np
import struct
import data_utils.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.trans_add_delta as trans_add_delta
class OneBlock(object):
""" struct for one block :
contain label, label desc, feature, feature_desc
Attributes:
label(str) : label path of one block
label_desc(str) : label description path of one block
feature(str) : feature path of on block
feature_desc(str) : feature description path of on block
"""
def __init__(self):
"""the constructor."""
self.label = "label"
self.label_desc = "label_desc"
self.feature = "feature"
self.feature_desc = "feature_desc"
class DataRead(object):
"""
Attributes:
_lblock(obj:`OneBlock`) : the list of OneBlock
_ndrop_sentence_len(int): dropout the sentence which's frame_num large than _ndrop_sentence_len
_que_sample(obj:`Queue`): sample buffer
_nframe_dim(int): the batch sample frame_dim(todo remove)
_nstart_block_idx(int): the start block id
_nload_block_num(int): the block num
"""
def __init__(self, sfeature_lst, slabel_lst, ndrop_sentence_len=512):
"""
Args:
sfeature_lst(str):feature lst path
slabel_lst(str):label lst path
Returns:
None
"""
self._lblock = []
self._ndrop_sentence_len = ndrop_sentence_len
self._que_sample = Queue.Queue()
self._nframe_dim = 120 * 11
self._nstart_block_idx = 0
self._nload_block_num = 1
self._ndrop_frame_len = 256
self._load_list(sfeature_lst, slabel_lst)
def _load_list(self, sfeature_lst, slabel_lst):
""" load list and shuffle
Args:
sfeature_lst(str):feature lst path
slabel_lst(str):label lst path
Returns:
None
"""
lfeature = open(sfeature_lst).readlines()
llabel = open(slabel_lst).readlines()
assert len(llabel) == len(lfeature)
for i in range(0, len(lfeature), 2):
one_block = OneBlock()
one_block.label = llabel[i]
one_block.label_desc = llabel[i + 1]
one_block.feature = lfeature[i]
one_block.feature_desc = lfeature[i + 1]
self._lblock.append(one_block)
random.shuffle(self._lblock)
def _load_one_block(self, lsample, id):
"""read one block by id and push load sample in list lsample
Args:
lsample(list): return sample list
id(int): block id
Returns:
None
"""
if id >= len(self._lblock):
return
slabel_path = self._lblock[id].label.strip()
slabel_desc_path = self._lblock[id].label_desc.strip()
sfeature_path = self._lblock[id].feature.strip()
sfeature_desc_path = self._lblock[id].feature_desc.strip()
llabel_line = open(slabel_desc_path).readlines()
lfeature_line = open(sfeature_desc_path).readlines()
file_lable_bin = open(slabel_path, "r")
file_feature_bin = open(sfeature_path, "r")
sample_num = int(llabel_line[0].split()[1])
assert sample_num == int(lfeature_line[0].split()[1])
llabel_line = llabel_line[1:]
lfeature_line = lfeature_line[1:]
for i in range(sample_num):
# read label
llabel_split = llabel_line[i].split()
nlabel_start = int(llabel_split[2])
nlabel_size = int(llabel_split[3])
nlabel_frame_num = int(llabel_split[4])
file_lable_bin.seek(nlabel_start, 0)
label_bytes = file_lable_bin.read(nlabel_size)
assert nlabel_frame_num * 4 == len(label_bytes)
label_array = struct.unpack('I' * nlabel_frame_num, label_bytes)
label_data = np.array(label_array, dtype="int64")
label_data = label_data.reshape((nlabel_frame_num, 1))
# read feature
lfeature_split = lfeature_line[i].split()
nfeature_start = int(lfeature_split[2])
nfeature_size = int(lfeature_split[3])
nfeature_frame_num = int(lfeature_split[4])
nfeature_frame_dim = int(lfeature_split[5])
file_feature_bin.seek(nfeature_start, 0)
feature_bytes = file_feature_bin.read(nfeature_size)
assert nfeature_frame_num * nfeature_frame_dim * 4 == len(
feature_bytes)
feature_array = struct.unpack('f' * nfeature_frame_num *
nfeature_frame_dim, feature_bytes)
feature_data = np.array(feature_array, dtype="float32")
feature_data = feature_data.reshape(
(nfeature_frame_num, nfeature_frame_dim))
#drop long sentence
if self._ndrop_frame_len < feature_data.shape[0]:
continue
lsample.append((feature_data, label_data))
def get_one_batch(self, nbatch_size):
"""construct one batch(feature, label), batch size is nbatch_size
Args:
nbatch_size(int): batch size
Returns:
None
"""
if self._que_sample.empty():
lsample = self._load_block(
range(self._nstart_block_idx, self._nstart_block_idx +
self._nload_block_num, 1))
self._move_sample(lsample)
self._nstart_block_idx += self._nload_block_num
if self._que_sample.empty():
self._nstart_block_idx = 0
return None
#cal all frame num
ncur_len = 0
lod = [0]
samples = []
bat_feature = np.zeros((nbatch_size, self._nframe_dim))
for i in range(nbatch_size):
# empty clear zero
if self._que_sample.empty():
self._nstart_block_idx = 0
# copy
else:
(one_feature, one_label) = self._que_sample.get()
samples.append((one_feature, one_label))
ncur_len += one_feature.shape[0]
lod.append(ncur_len)
bat_feature = np.zeros((ncur_len, self._nframe_dim), dtype="float32")
bat_label = np.zeros((ncur_len, 1), dtype="int64")
ncur_len = 0
for sample in samples:
one_feature = sample[0]
one_label = sample[1]
nframe_num = one_feature.shape[0]
nstart = ncur_len
nend = ncur_len + nframe_num
bat_feature[nstart:nend, :] = one_feature
bat_label[nstart:nend, :] = one_label
ncur_len += nframe_num
return (bat_feature, bat_label, lod)
def set_trans(self, ltrans):
""" set transform list
Args:
ltrans(list): data tranform list
Returns:
None
"""
self._ltrans = ltrans
def _load_block(self, lblock_id):
"""read blocks
"""
lsample = []
for id in lblock_id:
self._load_one_block(lsample, id)
# transform sample
for (nidx, sample) in enumerate(lsample):
for trans in self._ltrans:
sample = trans.perform_trans(sample)
#print nidx
lsample[nidx] = sample
return lsample
def load_block(self, lblock_id):
"""read blocks
Args:
lblock_id(list):the block list id
Returns:
None
"""
lsample = []
for id in lblock_id:
self._load_one_block(lsample, id)
# transform sample
for (nidx, sample) in enumerate(lsample):
for trans in self._ltrans:
sample = trans.perform_trans(sample)
#print nidx
lsample[nidx] = sample
return lsample
def _move_sample(self, lsample):
"""move sample to queue
Args:
lsample(list): one block of samples read from disk
Returns:
None
"""
# random
random.shuffle(lsample)
for sample in lsample:
self._que_sample.put(sample)
import numpy as np
import math
import copy
class TransAddDelta(object):
""" add delta of feature data
trans feature for shape(a, b) to shape(a, b * 3)
Attributes:
_norder(int):
_window(int):
"""
def __init__(self, norder=2, nwindow=2):
""" init construction
Args:
norder: default 2
nwindow: default 2
"""
self._norder = norder
self._nwindow = nwindow
def perform_trans(self, sample):
""" add delta for feature
trans feature shape from (a,b) to (a, b * 3)
Args:
sample(object,tuple): contain feature numpy and label numpy
Returns:
(feature, label)
"""
(feature, label) = sample
frame_dim = feature.shape[1]
d_frame_dim = frame_dim * 3
head_filled = 5
tail_filled = 5
mat = np.zeros(
(feature.shape[0] + head_filled + tail_filled, d_frame_dim),
dtype="float32")
#copy first frame
for i in xrange(head_filled):
np.copyto(mat[i, 0:frame_dim], feature[0, :])
np.copyto(mat[head_filled:head_filled + feature.shape[0], 0:frame_dim],
feature[:, :])
# copy last frame
for i in xrange(head_filled + feature.shape[0], mat.shape[0], 1):
np.copyto(mat[i, 0:frame_dim], feature[feature.shape[0] - 1, :])
nframe = feature.shape[0]
start = head_filled
tmp_shape = mat.shape
mat = mat.reshape((tmp_shape[0] * tmp_shape[1]))
self._regress(mat, start * d_frame_dim, mat,
start * d_frame_dim + frame_dim, frame_dim, nframe,
d_frame_dim)
self._regress(mat, start * d_frame_dim + frame_dim, mat,
start * d_frame_dim + 2 * frame_dim, frame_dim, nframe,
d_frame_dim)
mat.shape = tmp_shape
return (mat[head_filled:mat.shape[0] - tail_filled, :], label)
def _regress(self, data_in, start_in, data_out, start_out, size, n, step):
""" regress
Args:
data_in: in data
start_in: start index of data_in
data_out: out data
start_out: start index of data_out
size: frame dimentional
n: frame num
step: 3 * (frame num)
Returns:
None
"""
sigma_t2 = 0.0
delta_window = self._nwindow
for t in xrange(1, delta_window + 1):
sigma_t2 += t * t
sigma_t2 *= 2.0
for i in xrange(n):
fp1 = start_in
fp2 = start_out
for j in xrange(size):
back = fp1
forw = fp1
sum = 0.0
for t in xrange(1, delta_window + 1):
back -= step
forw += step
sum += t * (data_in[forw] - data_in[back])
data_out[fp2] = sum / sigma_t2
fp1 += 1
fp2 += 1
start_in += step
start_out += step
import numpy as np
import math
class TransMeanVarianceNorm(object):
""" normalization of mean variance for feature data
Attributes:
_mean(numpy.array): the feature mean vector
_var(numpy.array): the feature variance
"""
def __init__(self, snorm_path):
"""init construction
Args:
snorm_path: the path of mean and variance
"""
self._mean = None
self._var = None
self._load_norm(snorm_path)
def _load_norm(self, snorm_path):
""" load mean var file
Args:
snorm_path(str):the file path
"""
lLines = open(snorm_path).readlines()
nLen = len(lLines)
self._mean = np.zeros((nLen), dtype="float32")
self._var = np.zeros((nLen), dtype="float32")
self._nLen = nLen
for nidx, l in enumerate(lLines):
s = l.split()
assert len(s) == 2
self._mean[nidx] = float(s[0])
self._var[nidx] = 1.0 / math.sqrt(float(s[1]))
if self._var[nidx] > 100000.0:
self._var[nidx] = 100000.0
def get_mean_var(self):
""" get mean and var
Args:
Returns:
(mean, var)
"""
return (self._mean, self._var)
def perform_trans(self, sample):
""" feature = (feature - mean) * var
Args:
sample(object):input sample, contain feature numpy and label numpy
Returns:
(feature, label)
"""
(feature, label) = sample
shape = feature.shape
assert len(shape) == 2
nfeature_len = shape[0] * shape[1]
assert nfeature_len % self._nLen == 0
ncur_idx = 0
feature = feature.reshape((nfeature_len))
while ncur_idx < nfeature_len:
block = feature[ncur_idx:ncur_idx + self._nLen]
block = (block - self._mean) * self._var
feature[ncur_idx:ncur_idx + self._nLen] = block
ncur_idx += self._nLen
feature = feature.reshape(shape)
return (feature, label)
import numpy as np
import math
class TransSplice(object):
""" copy feature context to construct new feature
expand feature data from shape (frame_num, frame_dim)
to shape (frame_num, frame_dim * 11)
Attributes:
_nleft_context(int): copy left context number
_nright_context(int): copy right context number
"""
def __init__(self, nleft_context=5, nright_context=5):
""" init construction
Args:
nleft_context(int):
nright_context(int):
"""
self._nleft_context = nleft_context
self._nright_context = nright_context
def perform_trans(self, sample):
""" copy feature context
Args:
sample(object): input sample(feature, label)
Return:
(feature, label)
"""
(feature, label) = sample
nframe_num = feature.shape[0]
nframe_dim = feature.shape[1]
nnew_frame_dim = nframe_dim * (
self._nleft_context + self._nright_context + 1)
mat = np.zeros(
(nframe_num + self._nleft_context + self._nright_context,
nframe_dim),
dtype="float32")
ret = np.zeros((nframe_num, nnew_frame_dim), dtype="float32")
#copy left
for i in xrange(self._nleft_context):
mat[i, :] = feature[0, :]
#copy middle
mat[self._nleft_context:self._nleft_context +
nframe_num, :] = feature[:, :]
#copy right
for i in xrange(self._nright_context):
mat[i + self._nleft_context + nframe_num, :] = feature[-1, :]
mat = mat.reshape(mat.shape[0] * mat.shape[1])
ret = ret.reshape(ret.shape[0] * ret.shape[1])
for i in xrange(nframe_num):
np.copyto(ret[i * nnew_frame_dim:(i + 1) * nnew_frame_dim],
mat[i * nframe_dim:i * nframe_dim + nnew_frame_dim])
ret = ret.reshape((nframe_num, nnew_frame_dim))
return (ret, label)
def to_lodtensor(data, place):
"""convert tensor to lodtensor
"""
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = numpy.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def lodtensor_to_ndarray(lod_tensor):
"""conver lodtensor to ndarray
"""
dims = lod_tensor.get_dims()
ret = np.zeros(shape=dims).astype('float32')
for i in xrange(np.product(dims)):
ret.ravel()[i] = lod_tensor.get_float_element(i)
return ret, lod_tensor.lod()
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import argparse
import time
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import paddle.v2.fluid.profiler as profiler
import data_utils.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.trans_add_delta as trans_add_delta
import data_utils.trans_splice as trans_splice
import data_utils.data_read as reader
def parse_args():
parser = argparse.ArgumentParser("LSTM model benchmark.")
parser.add_argument(
'--batch_size',
type=int,
default=32,
help='The sequence number of a batch data. (default: %(default)d)')
parser.add_argument(
'--stacked_num',
type=int,
default=5,
help='Number of lstm layers to stack. (default: %(default)d)')
parser.add_argument(
'--proj_dim',
type=int,
default=512,
help='Project size of lstm unit. (default: %(default)d)')
parser.add_argument(
'--hidden_dim',
type=int,
default=1024,
help='Hidden size of lstm unit. (default: %(default)d)')
parser.add_argument(
'--pass_num',
type=int,
default=100,
help='Epoch number to train. (default: %(default)d)')
parser.add_argument(
'--learning_rate',
type=float,
default=0.002,
help='Learning rate used to train. (default: %(default)f)')
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help='The device type. (default: %(default)s)')
parser.add_argument(
'--infer_only', action='store_true', help='If set, run forward only.')
parser.add_argument(
'--use_cprof', action='store_true', help='If set, use cProfile.')
parser.add_argument(
'--use_nvprof',
action='store_true',
help='If set, use nvprof for CUDA.')
parser.add_argument('--mean_var', type=str, help='mean var path')
parser.add_argument('--feature_lst', type=str, help='mean var path')
parser.add_argument('--label_lst', type=str, help='mean var path')
args = parser.parse_args()
return args
def print_arguments(args):
vars(args)['use_nvprof'] = (vars(args)['use_nvprof'] and
vars(args)['device'] == 'GPU')
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def dynamic_lstmp_model(hidden_dim,
proj_dim,
stacked_num,
class_num=1749,
is_train=True):
feature = fluid.layers.data(
name="feature", shape=[-1, 120 * 11], dtype="float32", lod_level=1)
seq_conv1 = fluid.layers.sequence_conv(
input=feature,
num_filters=1024,
filter_size=3,
filter_stride=1,
bias_attr=True)
bn1 = fluid.layers.batch_norm(
input=seq_conv1,
act="sigmoid",
is_test=False,
momentum=0.9,
epsilon=1e-05,
data_layout='NCHW')
stack_input = bn1
for i in range(stacked_num):
fc = fluid.layers.fc(input=stack_input,
size=hidden_dim * 4,
bias_attr=True)
proj, cell = fluid.layers.dynamic_lstmp(
input=fc,
size=hidden_dim * 4,
proj_size=proj_dim,
bias_attr=True,
use_peepholes=True,
is_reverse=False,
cell_activation="tanh",
proj_activation="tanh")
bn = fluid.layers.batch_norm(
input=proj,
act="sigmoid",
is_test=False,
momentum=0.9,
epsilon=1e-05,
data_layout='NCHW')
stack_input = bn
prediction = fluid.layers.fc(input=stack_input,
size=class_num,
act='softmax')
if not is_train: return feature, prediction
label = fluid.layers.data(
name="label", shape=[-1, 1], dtype="int64", lod_level=1)
cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost)
return prediction, label, avg_cost
def train(args):
if args.use_cprof:
pr = cProfile.Profile()
pr.enable()
prediction, label, avg_cost = dynamic_lstmp_model(
args.hidden_dim, args.proj_dim, args.stacked_num)
adam_optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
adam_optimizer.minimize(avg_cost)
accuracy = fluid.evaluator.Accuracy(input=prediction, label=label)
# clone from default main program
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
test_accuracy = fluid.evaluator.Accuracy(input=prediction, label=label)
test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states
inference_program = fluid.io.get_inference_program(test_target)
place = fluid.CPUPlace() if args.device == 'CPU' else fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
ltrans = [
trans_add_delta.TransAddDelta(2, 2),
trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
trans_splice.TransSplice()
]
data_reader = reader.DataRead(args.feature_lst, args.label_lst)
data_reader.set_trans(ltrans)
res_feature = fluid.LoDTensor()
res_label = fluid.LoDTensor()
for pass_id in xrange(args.pass_num):
pass_start_time = time.time()
words_seen = 0
accuracy.reset(exe)
batch_id = 0
while True:
# load_data
one_batch = data_reader.get_one_batch(args.batch_size)
if one_batch == None:
break
(bat_feature, bat_label, lod) = one_batch
res_feature.set(bat_feature, place)
res_feature.set_lod([lod])
res_label.set(bat_label, place)
res_label.set_lod([lod])
batch_id += 1
words_seen += lod[-1]
loss, acc = exe.run(
fluid.default_main_program(),
feed={"feature": res_feature,
"label": res_label},
fetch_list=[avg_cost] + accuracy.metrics,
return_numpy=False)
train_acc = accuracy.eval(exe)
print("acc:", lodtensor_to_ndarray(loss))
pass_end_time = time.time()
time_consumed = pass_end_time - pass_start_time
words_per_sec = words_seen / time_consumed
def lodtensor_to_ndarray(lod_tensor):
dims = lod_tensor.get_dims()
ret = np.zeros(shape=dims).astype('float32')
for i in xrange(np.product(dims)):
ret.ravel()[i] = lod_tensor.get_float_element(i)
return ret, lod_tensor.lod()
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
if args.infer_only:
pass
else:
if args.use_nvprof and args.device == 'GPU':
with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
train(args)
else:
train(args)
#by zhxfl 2018.01.31
import sys
import unittest
import numpy as np
sys.path.append("../")
import data_utils.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.trans_add_delta as trans_add_delta
import data_utils.trans_splice as trans_splice
class TestTransMeanVarianceNorm(unittest.TestCase):
"""unit test for TransMeanVarianceNorm
"""
def test(self):
feature = np.zeros((2, 120), dtype="float32")
feature.fill(1)
trans = trans_mean_variance_norm.TransMeanVarianceNorm(
"../data/global_mean_var_search26kHr")
(feature1, label1) = trans.perform_trans((feature, None))
(mean, var) = trans.get_mean_var()
feature_flat1 = feature1.flatten()
feature_flat = feature.flatten()
one = np.ones((1), dtype="float32")
for idx, val in enumerate(feature_flat1):
cur_idx = idx % 120
self.assertAlmostEqual(val, (one[0] - mean[cur_idx]) * var[cur_idx])
class TestTransAddDelta(unittest.TestCase):
"""unit test TestTransAddDelta
"""
def test_regress(self):
"""test regress
"""
feature = np.zeros((14, 120), dtype="float32")
feature[0:5, 0:40].fill(1)
feature[0 + 5, 0:40].fill(1)
feature[1 + 5, 0:40].fill(2)
feature[2 + 5, 0:40].fill(3)
feature[3 + 5, 0:40].fill(4)
feature[8:14, 0:40].fill(4)
trans = trans_add_delta.TransAddDelta()
feature = feature.reshape((14 * 120))
trans._regress(feature, 5 * 120, feature, 5 * 120 + 40, 40, 4, 120)
trans._regress(feature, 5 * 120 + 40, feature, 5 * 120 + 80, 40, 4, 120)
feature = feature.reshape((14, 120))
tmp_feature = feature[5:5 + 4, :]
self.assertAlmostEqual(1.0, tmp_feature[0][0])
self.assertAlmostEqual(0.24, tmp_feature[0][119])
self.assertAlmostEqual(2.0, tmp_feature[1][0])
self.assertAlmostEqual(0.13, tmp_feature[1][119])
self.assertAlmostEqual(3.0, tmp_feature[2][0])
self.assertAlmostEqual(-0.13, tmp_feature[2][119])
self.assertAlmostEqual(4.0, tmp_feature[3][0])
self.assertAlmostEqual(-0.24, tmp_feature[3][119])
def test_perform(self):
"""test perform
"""
feature = np.zeros((4, 40), dtype="float32")
feature[0, 0:40].fill(1)
feature[1, 0:40].fill(2)
feature[2, 0:40].fill(3)
feature[3, 0:40].fill(4)
trans = trans_add_delta.TransAddDelta()
(feature, label) = trans.perform_trans((feature, None))
self.assertAlmostEqual(feature.shape[0], 4)
self.assertAlmostEqual(feature.shape[1], 120)
self.assertAlmostEqual(1.0, feature[0][0])
self.assertAlmostEqual(0.24, feature[0][119])
self.assertAlmostEqual(2.0, feature[1][0])
self.assertAlmostEqual(0.13, feature[1][119])
self.assertAlmostEqual(3.0, feature[2][0])
self.assertAlmostEqual(-0.13, feature[2][119])
self.assertAlmostEqual(4.0, feature[3][0])
self.assertAlmostEqual(-0.24, feature[3][119])
class TestTransSplict(unittest.TestCase):
"""unit test Test TransSplict
"""
def test_perfrom(self):
feature = np.zeros((8, 10), dtype="float32")
for i in xrange(feature.shape[0]):
feature[i, :].fill(i)
trans = trans_splice.TransSplice()
(feature, label) = trans.perform_trans((feature, None))
self.assertEqual(feature.shape[1], 110)
for i in xrange(8):
nzero_num = 5 - i
cur_val = 0.0
if nzero_num < 0:
cur_val = i - 5 - 1
for j in xrange(11):
if j <= nzero_num:
for k in xrange(10):
self.assertAlmostEqual(feature[i][j * 10 + k], cur_val)
else:
if cur_val < 7:
cur_val += 1.0
for k in xrange(10):
self.assertAlmostEqual(feature[i][j * 10 + k], cur_val)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册