Commit f1709803 authored by walloollaw, committed by qingqing01

caffe2fluid: support Reshape and Scale conversion; fix bug in dropout; add more tools for comparing results (#880)
Parent 98a47232
### Caffe2Fluid
This tool is used to convert a Caffe model to a Fluid model
### Key Features
1. Convert a Caffe model to a Fluid model together with the code that defines the network (useful for re-training)
2. PyCaffe is not required if you only want to convert a model without running Caffe inference
3. Conversion of Caffe's custom layers is also supported by extending this tool
4. A set of tools in '*examples/imagenet/tools*' is provided to compare the inference results of Caffe and Fluid
### HowTo
1. Prepare caffepb.py in ./proto if your Python has no 'pycaffe' module; two options are provided here:
    - Generate pycaffe from caffe.proto (a sketch follows)
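    A minimal sketch, assuming protoc from the protobuf toolchain is installed and caffe.proto has been copied into ./proto (the output file name is an assumption):
    ```
    protoc --proto_path=./proto --python_out=./proto ./proto/caffe.proto
    # protoc emits caffe_pb2.py; rename it if the tool expects caffepb.py
    mv ./proto/caffe_pb2.py ./proto/caffepb.py
    ```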
......@@ -29,10 +38,10 @@ This tool is used to convert a Caffe model to a Fluid model
```
3. Use the converted model to infer
- See more details in '*examples/imagenet/run.sh*'
- See more details in '*examples/imagenet/tools/run.sh*'
4. Compare the inference results with Caffe
- See more details in '*examples/imagenet/diff.sh*'
- See more details in '*examples/imagenet/tools/diff.sh*'
### How to convert custom layer
1. Implement your custom layer in a file under '*kaffe/custom_layers*', eg: mylayer.py (a minimal sketch follows)
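    A hedged sketch of such a file, using the same register() pattern as reshape.py further down; 'MyLayer' and both functions are hypothetical names:
    ```python
    # mylayer.py -- a minimal sketch of a custom layer
    from .register import register


    def mylayer_shape(input_shape):
        """ infer this layer's output shape from its input shape """
        return list(input_shape)


    def mylayer_layer(input, name):
        """ build this layer using fluid; here just an identity-like scale op """
        import paddle.fluid as fluid
        return fluid.layers.scale(input, scale=1.0, name=name)


    register(kind='MyLayer', shape=mylayer_shape, layer=mylayer_layer)
    ```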
......
A demo to show converting caffe models on 'imagenet' using caffe2fluid
A demo to show converting caffe models trained on 'imagenet' using caffe2fluid
---
......@@ -10,28 +10,32 @@ A demo to show converting caffe models on 'imagenet' using caffe2fluid
3. Convert the Caffe model to a Fluid model
- generate the Fluid code and weight file
<pre><code>python convert.py alexnet.prototxt \
```python convert.py alexnet.prototxt \
--caffemodel alexnet.caffemodel \
--data-output-path alexnet.npy \
--code-output-path alexnet.py
</code></pre>
```
- save the weights as a Fluid model file
<pre><code>python alexnet.py alexnet.npy ./fluid_model
</code></pre>
```
python alexnet.py alexnet.npy ./fluid
```
4. Do inference
<pre><code>python infer.py infer ./fluid_mode data/65.jpeg
</code></pre>
```
python infer.py infer ./fluid data/65.jpeg
```
5. Convert the model and do inference together
<pre><code>bash ./run.sh alexnet ./models.caffe/alexnet ./models/alexnet
</code></pre>
The Caffe model is stored in './models.caffe/alexnet/alexnet.prototxt|caffemodel'
and the Fluid model will be save in './models/alexnet/alexnet.py|npy'
```
bash ./tools/run.sh alexnet ./models.caffe/alexnet ./models/alexnet
```
* Assume the Caffe model is stored in '*./models.caffe/alexnet/alexnet.prototxt|caffemodel*'
* The converted model will be stored in '*./models/alexnet/alexnet.py|npy*'
6. Test the difference against Caffe's results (pycaffe must be installed)
<pre><code>bash ./diff.sh resnet
</code></pre>
Make sure your caffemodel stored in './models.caffe/resnet'.
The results will be stored in './results/resnet.paddle|caffe'
```
bash ./tools/diff.sh resnet
```
* Make sure your caffemodel is stored in '*./models.caffe/resnet*'
* The results will be stored in '*./results/resnet.paddle|caffe*'
......@@ -17,8 +17,21 @@ def walk_dir(rootdir):
def calc_diff(f1, f2):
import numpy as np
d1 = np.load(f1).flatten()
d2 = np.load(f2).flatten()
d1 = np.load(f1)
d2 = np.load(f2)
    print d1.shape
    print d2.shape
    d1 = d1.flatten()
    d2 = d2.flatten()
d1_num = reduce(lambda x, y: x * y, d1.shape)
d2_num = reduce(lambda x, y: x * y, d2.shape)
......@@ -36,15 +49,16 @@ def calc_diff(f1, f2):
return -1.0, -1.0
def compare(path1, path2):
def compare(path1, path2, no_exception):
def diff(f1, f2):
max_df, sq_df = calc_diff(f1, f2)
print('compare %s <=> %s with result[max_df:%.4e, sq_df:%.4e]' %
(f1, f2, max_df, sq_df))
assert (max_df < 1e-5), \
'max_df is too large with value[%.6e]' % (max_df)
assert (sq_df < 1e-10), \
'sq_df is too large with value[%.6e]' % (sq_df)
print('[max_df:%.4e, sq_df:%.4e] when compare %s <=> %s' %
(max_df, sq_df, os.path.basename(f1), os.path.basename(f2)))
if no_exception is False:
assert (max_df < 1e-5), \
'max_df is too large with value[%.6e]' % (max_df)
assert (sq_df < 1e-10), \
'sq_df is too large with value[%.6e]' % (sq_df)
if os.path.exists(path1) is False:
print('not found %s' % (path1))
......@@ -73,13 +87,17 @@ if __name__ == "__main__":
    if len(sys.argv) == 1:
        path1 = 'lenet.tf/results'
        path2 = 'lenet.paddle/results'
        no_exception = False
elif len(sys.argv) == 3:
elif len(sys.argv) >= 3:
path1 = sys.argv[1]
path2 = sys.argv[2]
if len(sys.argv) == 4:
no_exception = True
else:
no_exception = False
else:
        print('usage:')
        print(' %s [path1] [path2] [no_exception]' % (sys.argv[0]))
exit(1)
print('compare inner result in %s %s' % (path1, path2))
exit(compare(path1, path2))
#print('compare inner result in %s %s' % (path1, path2))
exit(compare(path1, path2, no_exception))
......@@ -213,7 +213,6 @@ def caffe_infer(prototxt, caffemodel, datafile):
results = []
names = []
for k, v in net.blobs.items():
k = k.rstrip('_output')
k = k.replace('/', '_')
names.append(k)
results.append(v.data.copy())
......@@ -259,7 +258,7 @@ if __name__ == "__main__":
print('usage:')
print('\tpython %s dump [net_file] [weight_file] [datafile] [net_name]' \
% (sys.argv[0]))
print('\teg:python dump %s %s %s %s %s' % (sys.argv[0],\
print('\teg:python %s dump %s %s %s %s' % (sys.argv[0],\
net_file, weight_file, datafile, net_name))
sys.exit(1)
......
#!/bin/bash
#
#function:
# a tool used to compare the results produced by paddle and caffe
#
if [[ $# -lt 2 ]];then
echo "usage:"
echo " bash $0 [model_name] [param_name] [caffe_name]"
exit 1
fi
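#e.g. (hypothetical blob name): bash <this_script> alexnet fc8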
model_name=$1
param_name=$2
paddle_file="./results/${model_name}.paddle/${param_name}.npy"
if [[ $# -eq 3 ]];then
caffe_file="./results/${model_name}.caffe/${3}.npy"
else
caffe_file="./results/${model_name}.caffe/${2}.npy"
fi
python ./compare.py $paddle_file $caffe_file
#!/bin/bash
#function:
# a tool used to compare all layers' results
#
if [[ $# -ne 1 ]];then
echo "usage:"
echo " bash $0 [model_name]"
echo " eg:bash $0 alexnet"
exit 1
fi
model_name=$1
prototxt="models.caffe/$model_name/${model_name}.prototxt"
layers=$(cat $prototxt | perl -ne 'if(/^\s+name\s*:\s*\"([^\"]+)/){print $1."\n";}')
for i in $layers;do
cf_npy="results/${model_name}.caffe/${i}.npy"
pd_npy="results/${model_name}.paddle/${i}.npy"
if [[ ! -e $cf_npy ]];then
echo "caffe's result not exist[$cf_npy]"
continue
fi
if [[ ! -e $pd_npy ]];then
echo "paddle's result not exist[$pd_npy]"
continue
fi
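    #the extra third argument makes compare.py only report diffs instead of raising (see compare.py above)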
python compare.py $cf_npy $pd_npy no_exception
if [[ $? -eq 0 ]];then
echo "succeed to compare layer[$i]"
else
echo "failed to compare layer[$i]"
fi
done
......@@ -36,7 +36,7 @@ model_caffemodel="models.caffe/${model_name}/${model_name}.caffemodel"
paddle_results="$results_root/${model_name}.paddle"
rm -rf $paddle_results
rm -rf "results.paddle"
bash run.sh $model_name ./models.caffe/$model_name ./models/$model_name
bash ./tools/run.sh $model_name ./models.caffe/$model_name ./models/$model_name
if [[ $? -ne 0 ]] || [[ ! -e "results.paddle" ]];then
echo "not found paddle's results, maybe failed to convert"
exit 1
......
......@@ -6,7 +6,7 @@
# 2, do inference(only in fluid) using this model
#
#usage:
# bash run.sh resnet50 ./models.caffe/resnet50 ./models/resnet50
# cd caffe2fluid/examples/imagenet && bash run.sh resnet50 ./models.caffe/resnet50 ./models/resnet50
#
#set -x
......
......@@ -7,13 +7,14 @@ from .register import get_registered_layers
import axpy
import flatten
import argmax
import reshape
#custom layer import ends
custom_layers = get_registered_layers()
def set_args(f, params):
def set_args(f, params, node=None):
""" set args for function 'f' using the parameters in node.layer.parameters
Args:
......@@ -24,19 +25,15 @@ def set_args(f, params):
arg_names (list): a list of argument names
kwargs (dict): a dict contains needed arguments
"""
from ..protobuf_to_dict import protobuf_to_dict
argc = f.__code__.co_argcount
arg_list = f.__code__.co_varnames[0:argc]
kwargs = {}
for arg_name in arg_list:
try:
v = getattr(params, arg_name, None)
except Exception as e:
#maybe failed to extract caffe's parameters
v = None
if v is not None:
kwargs[arg_name] = v
if arg_name in params:
kwargs[arg_name] = params[arg_name]
return arg_list, kwargs
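#e.g. (hypothetical): for 'def f(axis=0, num_axes=-1)' and params={'axis': 1, 'shape': {...}},
#only 'axis' matches an argument name, so set_args returns (('axis', 'num_axes'), {'axis': 1})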
......@@ -54,7 +51,7 @@ def compute_output_shape(kind, node):
parents = node.parents
inputs = [list(p.output_shape) for p in parents]
arg_names, kwargs = set_args(shape_func, node.layer.parameters)
arg_names, kwargs = set_args(shape_func, node.params)
if len(inputs) == 1:
inputs = inputs[0]
......@@ -80,7 +77,7 @@ def make_node(template, kind, node):
layer_func = custom_layers[kind]['layer']
#construct arguments needed by custom layer function from node's parameters
arg_names, kwargs = set_args(layer_func, node.layer.parameters)
arg_names, kwargs = set_args(layer_func, node.params, node)
return template('custom_layer', kind, **kwargs)
......
""" a custom layer for 'reshape', maybe we should implement this in standard way.
more info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/reshape.html
"""
from .register import register
def import_fluid():
import paddle.fluid as fluid
return fluid
def reshape_shape(input_sp, shape, axis=0, num_axes=-1):
""" calculate the output shape of this layer using input shape
Args:
@input_sp (list of num): a list of numbers which represents the input shape
@shape (object): parameter from caffe's Reshape layer
@axis (int): parameter from caffe's Reshape layer
@num_axes(int): parameter from caffe's Reshape layer
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
def count(num_list):
return reduce(lambda a, b: a * b, num_list)
input_shape = list(input_sp)
input_count = count(input_shape)
input_num_axes = len(input_shape)
input_start_axis = axis
start_axis = input_start_axis if input_start_axis >= 0 \
else input_num_axes + input_start_axis + 1
assert start_axis >= 0, "[Reshape]axis %d out of range" % (input_start_axis)
assert start_axis <= input_num_axes, "[Reshape]axis %d out of range for %d-D input data"\
% (input_start_axis, input_num_axes)
assert num_axes >= -1, "[Reshape]num_axes must be >= 0, or -1 for all"
end_axis = input_num_axes if num_axes == -1 else start_axis + num_axes
assert end_axis <= input_num_axes, "end_axis[%d] = axis[%d] + num_axes[%d] is out of range"\
% (end_axis, start_axis, num_axes)
num_axes_replaced = end_axis - start_axis
num_axes_retained = input_num_axes - num_axes_replaced
num_new_axes = len(shape['dim'])
output_shape = []
for i in range(start_axis):
output_shape.append(input_shape[i])
for i in range(num_new_axes):
output_shape.append(shape['dim'][i])
for i in range(end_axis, input_num_axes):
output_shape.append(input_shape[i])
assert len(output_shape) == num_axes_retained + num_new_axes,\
"[Reshape]invalid dims of output shape[%s]" % (str(output_shape))
    inferred_axis = -1
    copy_axes = []
    constant_count = 1
    for i in range(num_new_axes):
        top_dim = shape['dim'][i]
        if top_dim == 0:
            #'0' means copying the corresponding dimension from the input
            copy_axes.append(i)
            copy_axis_index = start_axis + i
            output_shape[copy_axis_index] = input_shape[copy_axis_index]
        elif top_dim == -1:
            assert inferred_axis == -1, "[Reshape]new shape contains multiple -1 dims"
            inferred_axis = i
        else:
            constant_count *= top_dim

    if inferred_axis >= 0:
        #infer the '-1' dimension from the remaining element count
        explicit_count = constant_count
        l = input_shape[0:start_axis]
        if len(l) > 0:
            explicit_count *= count(l)
        l = input_shape[end_axis:]
        if len(l) > 0:
            explicit_count *= count(l)
        for i in range(len(copy_axes)):
            explicit_count *= output_shape[start_axis + copy_axes[i]]
        assert input_count % explicit_count == 0, "[Reshape]bottom count[%d] "\
            "must be divisible by product of the specified dimensions[%d]"\
            % (input_count, explicit_count)
        output_shape[start_axis + inferred_axis] = input_count / explicit_count
output_count = count(output_shape)
assert output_count == input_count, "[Reshape]output count[%d] must match input count[%d]" % (
output_count, input_count)
return output_shape
def reshape_layer(input, name, shape, axis=0, num_axes=-1):
""" build a layer of type 'Flatten' using fluid
Args:
@input (variable): input fluid variable for this layer
@name (str): name for this layer
@shape (object): parameter from caffe's Reshape layer
@axis (int): parameter from caffe's Reshape layer
@num_axes(int): parameter from caffe's Reshape layer
Returns:
output (variable): output variable for this layer
"""
fluid = import_fluid()
input_shape = list(input.shape)
if input_shape[0] == -1:
input_shape[0] = 1
output_shape = reshape_shape(input_shape, shape, axis, num_axes)
output_shape[0] = -1
else:
output_shape = reshape_shape(input_shape, shape, axis, num_axes)
output = fluid.layers.reshape(input, shape=output_shape, name=name)
return output
register(kind='Reshape', shape=reshape_shape, layer=reshape_layer)
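#a worked example (hypothetical values): for a 4-D input and a Caffe Reshape
#layer with dim:[0, -1, 5], axis 0 is copied and axis 1 is inferred:
#  reshape_shape([1, 3, 10, 10], {'dim': [0, -1, 5]}) => [1, 60, 5]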
......@@ -13,8 +13,8 @@ class Node(object):
self.layer = LayerAdapter(layer, kind) if layer else None
self.parents = []
self.children = []
self.data = None
self.output_shape = None
self.data = None #parameters of this node
self.output_shape = None #output shape of this node
self.metadata = {}
def add_parent(self, parent_node):
......@@ -37,10 +37,24 @@ class Node(object):
@property
def parameters(self):
""" get parameters stored in a protobuf object
"""
if self.layer is not None:
return self.layer.parameters
return None
@property
def params(self):
""" get parameters stored in a dict
"""
from .protobuf_to_dict import protobuf_to_dict
p = self.parameters
if p is not None:
return protobuf_to_dict(p)
else:
return None
def __str__(self):
return '[%s] %s' % (self.kind, self.name)
......
......@@ -22,15 +22,13 @@ def layer(op):
layer_input = self.terminals[0]
else:
layer_input = list(self.terminals)
# Perform the operation and get the output.
layer_output = op(self, layer_input, *args, **kwargs)
# Add to layer LUT.
self.layers[name] = layer_output
# This output is now the input for the next layer.
self.feed(layer_output)
#print('output shape of %s:' % (name))
#print layer_output.shape
# Return self for chained calls.
return self
......@@ -129,6 +127,7 @@ class Network(object):
s_w,
name,
relu=True,
relu_negative_slope=0.0,
padding=None,
group=1,
biased=True):
......@@ -144,6 +143,14 @@ class Network(object):
fluid = import_fluid()
prefix = name + '_'
leaky_relu = False
act = 'relu'
if relu is False:
act = None
elif relu_negative_slope != 0.0:
leaky_relu = True
act = None
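            # the fused `act` string cannot carry leaky relu's `alpha`
            # parameter, so the activation is applied as a separate op below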
output = fluid.layers.conv2d(
input=input,
filter_size=[k_h, k_w],
......@@ -153,7 +160,11 @@ class Network(object):
groups=group,
param_attr=fluid.ParamAttr(name=prefix + "weights"),
bias_attr=fluid.ParamAttr(name=prefix + "biases"),
act="relu" if relu is True else None)
act=act)
if leaky_relu:
output = fluid.layers.leaky_relu(output, alpha=relu_negative_slope)
return output
@layer
......@@ -286,8 +297,32 @@ class Network(object):
@layer
def dropout(self, input, drop_prob, name, is_test=True):
fluid = import_fluid()
output = fluid.layers.dropout(
input, dropout_prob=drop_prob, is_test=is_test, name=name)
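        # Caffe's Dropout is an identity op at inference time, so skip
        # fluid's dropout entirely when is_test is True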
if is_test:
output = input
else:
output = fluid.layers.dropout(
input, dropout_prob=drop_prob, is_test=is_test)
return output
@layer
def scale(self, input, axis=1, num_axes=1, name=None):
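        # Caffe's Scale layer computes y = x * scale + offset, with the learned
        # 'scale' and 'offset' vectors broadcast along `axis` (e.g. per-channel
        # for an NCHW input when axis=1)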
fluid = import_fluid()
        assert num_axes == 1, "the scale layer does not support num_axes[%d] yet" % (
            num_axes)
prefix = name + '_'
scale_shape = input.shape[axis:axis + num_axes]
param_attr = fluid.ParamAttr(name=prefix + 'scale')
scale_param = fluid.layers.create_parameter(
shape=scale_shape, dtype=input.dtype, name=name, attr=param_attr)
offset_attr = fluid.ParamAttr(name=prefix + 'offset')
offset_param = fluid.layers.create_parameter(
shape=scale_shape, dtype=input.dtype, name=name, attr=offset_attr)
output = fluid.layers.elementwise_mul(input, scale_param, axis=axis)
output = fluid.layers.elementwise_add(output, offset_param, axis=axis)
return output
def custom_layer_factory(self):
......
"""a util for convert protobuf to dict
"""
from google.protobuf.message import Message
from google.protobuf.descriptor import FieldDescriptor
__all__ = [
"protobuf_to_dict", "TYPE_CALLABLE_MAP", "dict_to_protobuf",
"REVERSE_TYPE_CALLABLE_MAP"
]
EXTENSION_CONTAINER = '___X'
TYPE_CALLABLE_MAP = {
FieldDescriptor.TYPE_DOUBLE: float,
FieldDescriptor.TYPE_FLOAT: float,
FieldDescriptor.TYPE_INT32: int,
FieldDescriptor.TYPE_INT64: long,
FieldDescriptor.TYPE_UINT32: int,
FieldDescriptor.TYPE_UINT64: long,
FieldDescriptor.TYPE_SINT32: int,
FieldDescriptor.TYPE_SINT64: long,
FieldDescriptor.TYPE_FIXED32: int,
FieldDescriptor.TYPE_FIXED64: long,
FieldDescriptor.TYPE_SFIXED32: int,
FieldDescriptor.TYPE_SFIXED64: long,
FieldDescriptor.TYPE_BOOL: bool,
FieldDescriptor.TYPE_STRING: unicode,
FieldDescriptor.TYPE_BYTES: lambda b: b.encode("base64"),
FieldDescriptor.TYPE_ENUM: int,
}
def repeated(type_callable):
return lambda value_list: [type_callable(value) for value in value_list]
def enum_label_name(field, value):
return field.enum_type.values_by_number[int(value)].name
def protobuf_to_dict(pb,
type_callable_map=TYPE_CALLABLE_MAP,
use_enum_labels=False):
result_dict = {}
extensions = {}
for field, value in pb.ListFields():
type_callable = _get_field_value_adaptor(pb, field, type_callable_map,
use_enum_labels)
if field.label == FieldDescriptor.LABEL_REPEATED:
type_callable = repeated(type_callable)
if field.is_extension:
extensions[str(field.number)] = type_callable(value)
continue
result_dict[field.name] = type_callable(value)
if extensions:
result_dict[EXTENSION_CONTAINER] = extensions
return result_dict
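#e.g. (hypothetical): for a caffe ReshapeParameter whose 'shape' field holds
#'dim: 0 dim: -1', protobuf_to_dict(param) returns {'shape': {'dim': [0, -1]}}
#(repeated int64 values come back as Python longs)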
def _get_field_value_adaptor(pb,
field,
type_callable_map=TYPE_CALLABLE_MAP,
use_enum_labels=False):
if field.type == FieldDescriptor.TYPE_MESSAGE:
# recursively encode protobuf sub-message
return lambda pb: protobuf_to_dict(pb,
type_callable_map=type_callable_map,
use_enum_labels=use_enum_labels)
if use_enum_labels and field.type == FieldDescriptor.TYPE_ENUM:
return lambda value: enum_label_name(field, value)
if field.type in type_callable_map:
return type_callable_map[field.type]
raise TypeError("Field %s.%s has unrecognised type id %d" %
(pb.__class__.__name__, field.name, field.type))
def get_bytes(value):
return value.decode('base64')
REVERSE_TYPE_CALLABLE_MAP = {FieldDescriptor.TYPE_BYTES: get_bytes, }
def dict_to_protobuf(pb_klass_or_instance,
values,
type_callable_map=REVERSE_TYPE_CALLABLE_MAP,
strict=True):
"""Populates a protobuf model from a dictionary.
:param pb_klass_or_instance: a protobuf message class, or an protobuf instance
:type pb_klass_or_instance: a type or instance of a subclass of google.protobuf.message.Message
:param dict values: a dictionary of values. Repeated and nested values are
fully supported.
:param dict type_callable_map: a mapping of protobuf types to callables for setting
values on the target instance.
:param bool strict: complain if keys in the map are not fields on the message.
"""
if isinstance(pb_klass_or_instance, Message):
instance = pb_klass_or_instance
else:
instance = pb_klass_or_instance()
return _dict_to_protobuf(instance, values, type_callable_map, strict)
def _get_field_mapping(pb, dict_value, strict):
field_mapping = []
for key, value in dict_value.items():
if key == EXTENSION_CONTAINER:
continue
if key not in pb.DESCRIPTOR.fields_by_name:
if strict:
raise KeyError("%s does not have a field called %s" % (pb, key))
continue
field_mapping.append(
(pb.DESCRIPTOR.fields_by_name[key], value, getattr(pb, key, None)))
for ext_num, ext_val in dict_value.get(EXTENSION_CONTAINER, {}).items():
try:
ext_num = int(ext_num)
except ValueError:
raise ValueError("Extension keys must be integers.")
if ext_num not in pb._extensions_by_number:
if strict:
raise KeyError(
"%s does not have a extension with number %s. Perhaps you forgot to import it?"
% (pb, key))
continue
ext_field = pb._extensions_by_number[ext_num]
        pb_val = pb.Extensions[ext_field]
field_mapping.append((ext_field, ext_val, pb_val))
return field_mapping
def _dict_to_protobuf(pb, value, type_callable_map, strict):
fields = _get_field_mapping(pb, value, strict)
for field, input_value, pb_value in fields:
if field.label == FieldDescriptor.LABEL_REPEATED:
for item in input_value:
if field.type == FieldDescriptor.TYPE_MESSAGE:
m = pb_value.add()
_dict_to_protobuf(m, item, type_callable_map, strict)
elif field.type == FieldDescriptor.TYPE_ENUM and isinstance(
item, basestring):
pb_value.append(_string_to_enum(field, item))
else:
pb_value.append(item)
continue
if field.type == FieldDescriptor.TYPE_MESSAGE:
_dict_to_protobuf(pb_value, input_value, type_callable_map, strict)
continue
if field.type in type_callable_map:
input_value = type_callable_map[field.type](input_value)
if field.is_extension:
pb.Extensions[field] = input_value
continue
if field.type == FieldDescriptor.TYPE_ENUM and isinstance(input_value,
basestring):
input_value = _string_to_enum(field, input_value)
setattr(pb, field.name, input_value)
return pb
def _string_to_enum(field, input_value):
enum_dict = field.enum_type.values_by_name
try:
input_value = enum_dict[input_value].number
except KeyError:
raise KeyError("`%s` is not a valid value for field `%s`" %
(input_value, field.name))
return input_value
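#round-trip sketch (hypothetical message instance 'msg'):
#  d = protobuf_to_dict(msg)
#  msg2 = dict_to_protobuf(type(msg), d)
#  assert protobuf_to_dict(msg2) == d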
......@@ -150,8 +150,8 @@ class DataReshaper(object):
if node.kind not in self.reshaped_node_types:
# Check for 2+ dimensional data
if any(len(tensor.shape) > 1 for tensor in node.data):
notice('parmaters not reshaped for node: {}'.format(node))
            #if any(len(tensor.shape) > 1 for tensor in node.data):
            #    notice('parameters not reshaped for node: {}'.format(node))
continue
transpose_order = self.map(node.kind)
......@@ -233,8 +233,9 @@ class ReLUFuser(SubNodeFuser):
parent.kind in self.allowed_parent_types) and \
child.kind == NodeKind.ReLU)
def merge(self, parent, _):
def merge(self, parent, child):
parent.metadata['relu'] = True
parent.metadata['relu_negative_slope'] = child.parameters.negative_slope
class BatchNormScaleBiasFuser(SubNodeFuser):
......@@ -316,8 +317,11 @@ class ParameterNamer(object):
names = ('mean', 'variance')
if len(node.data) == 4:
names += ('scale', 'offset')
elif node.kind == NodeKind.Scale:
names = ('scale', 'offset')
else:
warn('Unhandled parameters: {}'.format(node.kind))
                warn('Unhandled parameters when naming node of kind[%s]' %
                     (node.kind))
continue
assert len(names) == len(node.data)
node.data = dict(zip(names, node.data))
......