Unverified commit 3be647ed authored by mamingjie-China, committed by GitHub

Merge pull request #1 from PaddlePaddle/develop

Update
......@@ -11,3 +11,6 @@ x2paddle -f tensorflow -m tf.pb -s pd-model --without_data_format_optimization -
```
> 1. Most TensorFlow CV models currently use the `NHWC` input format, while Paddle's default input format is `NCHW`. During conversion, X2Paddle therefore rewrites parameters such as `axis` and `shape` to fit Paddle's NCHW layout; with a very complex TensorFlow model, this rewriting may fail. Specifying `--without_data_format_optimization` disables the optimization of `axis`, `shape`, and similar parameters (which may introduce a certain number of transpose operations).
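As a minimal illustration of the two layouts (a NumPy sketch, not X2Paddle's actual conversion code):

```
import numpy as np

nhwc = np.zeros((1, 224, 224, 3), dtype=np.float32)  # TensorFlow-style: batch, height, width, channels
nchw = nhwc.transpose(0, 3, 1, 2)                    # Paddle-style: batch, channels, height, width
print(nchw.shape)  # (1, 3, 224, 224)
```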
**Q3. During ONNX model conversion, the message 『Unknown shape for input tensor[tensor name: "input"] -> shape: ['batch', 'sequence'], Please define shape of input here』 appears**
A: This message indicates that the shape obtained from the ONNX model for the input tensor (named "input") is the symbolic ['batch', 'sequence'] rather than a shape whose dims are ints, so the shapes of some nodes may not be inferable and the conversion may fail. You can enter a concrete shape manually at the prompt, e.g. -1,3,224,224, where -1 denotes the batch dimension.
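For reference, the interactive prompt emitted by the decoder looks like this (an illustrative session):

```
Unknown shape for input tensor[tensor name: 'input'] -> shape: ['batch', 'sequence'], Please define shape of input here,
Note:you can use visualization tools like Netron to check input shape.
Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: -1,3,224,224
```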
......@@ -10,12 +10,12 @@ X2Paddle has tested the conversion of TensorFlow/Caffe/ONNX models on a number of mainstream CV models
## Environment Dependencies
python == 2.7 | python >= 3.5
paddlepaddle >= 1.6.0
paddlepaddle >= 1.8.0
**Install the following dependencies as needed**
tensorflow : tensorflow == 1.14.0
caffe : none
onnx : onnx == 1.6.0 onnxruntime == 1.0.0
onnx : onnx == 1.6.0
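For example, to set up the environment for ONNX conversion with the versions above (a sketch; pick only the framework dependencies you need):

```
pip install "paddlepaddle>=1.8.0"
pip install onnx==1.6.0
```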
## Installation
### Installation Method 1 (Recommended)
......@@ -63,6 +63,7 @@ x2paddle --framework=paddle2onnx --model=paddle_infer_model_dir --save_dir=onnx_
|--params_merge | **[Optional]** When this flag is specified, all model parameters in inference_model are merged and saved as a single file __params__ after conversion |
## Using the Converted Model
The converted model comprises two directories: `model_with_code` and `inference_model`.
`model_with_code` contains the model parameters and the converted Python model code.
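A minimal sketch of loading and running the converted `inference_model` with the PaddlePaddle 1.x fluid API (the directory name and input shape are illustrative; if `--params_merge` was used, also pass `params_filename='__params__'`):

```
import numpy as np
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
# load the converted model; 'inference_model' is the save_dir used at conversion time
program, feed_names, fetch_targets = fluid.io.load_inference_model(
    dirname='inference_model', executor=exe)

data = np.random.rand(1, 3, 224, 224).astype('float32')  # example NCHW input
results = exe.run(program,
                  feed={feed_names[0]: data},
                  fetch_list=fetch_targets)
```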
......
......@@ -175,13 +175,15 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
from x2paddle.op_mapper.onnx_op_mapper import ONNXOpMapper
from x2paddle.decoder.onnx_decoder import ONNXDecoder
from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer
import onnxruntime
model = ONNXDecoder(model_path)
mapper = ONNXOpMapper(model, save_dir)
mapper = ONNXOpMapper(model)
print("Model optimizing ...")
optimizer = ONNXOptimizer(mapper)
print("Model optimized.")
optimizer.delete_redundance_code()
print("Paddle model and code generating ...")
mapper.save_inference_model(save_dir, params_merge)
print("Paddle model and code generated.")
def paddle2onnx(model_path, save_dir):
......@@ -211,18 +213,6 @@ def main():
assert args.framework is not None, "--framework is not defined(support tensorflow/caffe/onnx)"
assert args.save_dir is not None, "--save_dir is not defined"
if args.framework == "onnx":
try:
import onnxruntime as rt
version = rt.__version__
if version != '1.0.0':
print("[ERROR] onnxruntime==1.0.0 is required")
return
except:
print(
"[ERROR] onnxruntime is not installed, use \"pip install onnxruntime==1.0.0\"."
)
try:
import paddle
v0, v1, v2 = paddle.__version__.split('.')
......@@ -261,6 +251,7 @@ def main():
elif args.framework == "onnx":
assert args.model is not None, "--model should be defined while translating onnx model"
params_merge = False
if args.params_merge:
params_merge = True
onnx2paddle(args.model, args.save_dir, params_merge)
......
......@@ -25,6 +25,7 @@ class Layer(object):
self.inputs = dict()
self.output = None
self.is_custom_layer = False
self.use_fluid = False
def get_code(self):
layer_code = ""
......@@ -38,6 +39,8 @@ class Layer(object):
layer_code = layer_code + self.op + "("
elif self.op == "=":
layer_code = layer_code
elif self.use_fluid:
layer_code = layer_code + "fluid." + self.op + "("
else:
layer_code = layer_code + "fluid.layers." + self.op + "("
......@@ -108,9 +111,11 @@ class FluidCode(object):
inputs,
output,
param_attr=None,
use_fluid=False,
is_custom_layer=False):
layer = Layer()
layer.op = op
layer.use_fluid = use_fluid
layer.is_custom_layer = is_custom_layer
if inputs is not None:
layer.inputs = inputs
......
......@@ -29,11 +29,14 @@ def export_paddle_param(param, param_name, dir):
"bool": [framework_pb2.VarType.BOOL, None]
}
shape = param.shape
if str(param.dtype) in ['uint8', 'uint_8', 'bool']:
param = param.astype('int64')
if len(shape) == 0:
assert param.size == 1, "Unexpected situation happened!"
shape = [1]
assert str(param.dtype) in dtype_map, "Unknown dtype of params."
assert str(
param.dtype) in dtype_map, "Unknown dtype {} of params: {}.".format(
str(param.dtype), param_name)
fp = open(os.path.join(dir, param_name), 'wb')
numpy.array([0], dtype='int32').tofile(fp)
numpy.array([0], dtype='int64').tofile(fp)
......
......@@ -14,6 +14,7 @@
from x2paddle.core.graph import GraphNode, Graph
from x2paddle.core.fluid_code import FluidCode
from x2paddle.decoder.onnx_shape_inference import SymbolicShapeInference
from onnx.checker import ValidationError
from onnx.checker import check_model
from onnx.utils import polish_model
......@@ -53,7 +54,7 @@ class ONNXGraphNode(GraphNode):
convert ONNX node attributes to dict
"""
return {
attr.name: self.get_attribute_value2(attr)
attr.name: self.get_attribute_value(attr)
for attr in self.layer.attribute
}
......@@ -64,7 +65,7 @@ class ONNXGraphNode(GraphNode):
return None
return self.attr_map['value']
def get_attribute_value2(self, attr):
def get_attribute_value(self, attr):
"""
get_attribute_value enhanced
"""
......@@ -130,43 +131,90 @@ class ONNXGraphDataNode(GraphNode):
class ONNXGraph(Graph):
def __init__(self, onnx_model):
super(ONNXGraph, self).__init__(onnx_model.graph)
self.onnx_model = onnx_model
super(ONNXGraph, self).__init__(onnx_model)
self.fixed_input_shape = {}
self.initializer = {}
self.place_holder_nodes = list()
self.value_infos = {}
self.graph = onnx_model.graph
self.get_place_holder_nodes()
self.value_infos = self.inferred_model_value_info(self.model)
self.results_of_inference = dict()
print("shape inferencing ...")
self.graph = SymbolicShapeInference.infer_shapes(
onnx_model, fixed_input_shape=self.fixed_input_shape)
print("shape inferenced.")
self.build()
self.collect_value_infos()
self.allocate_shapes()
def get_inner_nodes(self):
"""
generate inner node of ONNX model
"""
inner_nodes = []
if not isinstance(self.model, onnx.GraphProto):
if not isinstance(self.graph, onnx.GraphProto):
logger.error('graph is not a GraphProto instance')
return
for initializer in self.model.initializer:
for initializer in self.graph.initializer:
name = initializer.name
inner_nodes.append(name)
return inner_nodes
def get_symbolic_shape(self, dims):
shape = []
for dim in dims:
if dim.HasField('dim_param'):
shape.append(dim.dim_param)
else:
shape.append(dim.dim_value)
return shape
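# illustrative: for a value_info created via
#   onnx.helper.make_tensor_value_info('x', TensorProto.FLOAT, ['batch', 3, 224, 224])
# this returns the mixed list ['batch', 3, 224, 224]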
def check_input_shape(self, vi):
if vi.type.HasField('tensor_type'):
for dim in vi.type.tensor_type.shape.dim:
if dim.HasField(
'dim_param') and vi.name not in self.fixed_input_shape:
shape = self.get_symbolic_shape(
vi.type.tensor_type.shape.dim)
print(
"Unknown shape for input tensor[tensor name: '{}'] -> shape: {}, Please define shape of input here,\nNote:you can use visualization tools like Netron to check input shape."
.format(vi.name, shape))
right_shape_been_input = False
while not right_shape_been_input:
try:
shape = raw_input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
)
except:
shape = input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
)
if shape.count("-1") > 1:
print("Only 1 dimension can be -1, type again:)")
else:
right_shape_been_input = True
if shape == 'N':
break
shape = [int(dim) for dim in shape.strip().split(',')]
assert shape.count(-1) <= 1, "Only one dimension can be -1"
self.fixed_input_shape[vi.name] = shape
break
def get_place_holder_nodes(self):
"""
generate place_holder node of ONNX model
"""
inner_nodes = self.get_inner_nodes()
input_nodes = [value.name for value in self.model.input]
for ipt_data in input_nodes:
if ipt_data not in inner_nodes:
self.place_holder_nodes.append(ipt_data)
for ipt_vi in self.graph.input:
if ipt_vi.name not in inner_nodes:
self.check_input_shape(ipt_vi)
self.place_holder_nodes.append(ipt_vi.name)
def get_output_nodes(self):
"""
generate output_nodes node of ONNX model
"""
inner_nodes = self.get_inner_nodes()
output_nodes = [value.name for value in self.model.output]
output_nodes = [value.name for value in self.graph.output]
for opt_data in output_nodes:
if opt_data not in inner_nodes:
self.output_nodes.append(opt_data)
......@@ -183,11 +231,11 @@ class ONNXGraph(Graph):
"""
build topo_sort of ONNX model
"""
for layer in self.model.node:
for layer in self.graph.node:
node = ONNXGraphNode(layer)
self.node_map[layer.name] = node
for layer in self.model.input:
for layer in self.graph.input:
if layer.name not in self.node_map:
is_place_holder = self.is_place_holder_nodes(layer.name)
self.node_map[layer.name] = ONNXGraphDataNode(
......@@ -196,7 +244,7 @@ class ONNXGraph(Graph):
is_global_input=is_place_holder)
#set data node's weight
for initializer in self.model.initializer:
for initializer in self.graph.initializer:
name = initializer.name
weight = to_array(initializer)
if name in self.node_map:
......@@ -228,7 +276,7 @@ class ONNXGraph(Graph):
continue
if in_node not in self.node_map:
flag = 0
for nd in self.model.node:
for nd in self.graph.node:
for idx, opt in enumerate(nd.output):
if opt == in_node:
self.connect(nd.name, layer_name)
......@@ -256,81 +304,68 @@ class ONNXGraph(Graph):
ipt_node.index = node.which_child[ipt_node.layer_name]
return ipt_node
def graph_weights(self, graph):
def graph_weights(self):
"""
generator for weights
"""
if not isinstance(graph, onnx.GraphProto):
if not isinstance(self.graph, onnx.GraphProto):
logger.error('graph is not a GraphProto instance')
return
for initializer in graph.initializer:
for initializer in self.graph.initializer:
name = initializer.name
weight = to_array(initializer)
yield name, weight
def inferred_model_value_info(self, graph):
def collect_value_infos(self):
"""
collect value/type info for an ONNX model
"""
assert isinstance(graph,
assert isinstance(self.graph,
onnx.GraphProto), 'graph is not a GraphProto instance'
value_info = Dict()
for item in graph.value_info:
value_info[item.name] = {
for item in self.graph.value_info:
self.value_infos[item.name] = {
'dtype':
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
'shape':
[dim.dim_value for dim in item.type.tensor_type.shape.dim],
'external': False
}
for item in graph.input:
assert item.name not in value_info
value_info[item.name] = {
'dtype':
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
'shape':
[dim.dim_value for dim in item.type.tensor_type.shape.dim],
'external': True
}
for item in graph.output:
assert item.name not in value_info
value_info[item.name] = {
'dtype':
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
'shape':
[dim.dim_value for dim in item.type.tensor_type.shape.dim],
'external': True
}
return value_info
def allocate_shapes(self):
"""
run shape inference
"""
for layer in self.graph.node:
node = self.node_map[layer.name]
for opt in layer.output:
if opt in self.value_infos:
value_info = self.value_infos[opt]
#if len(value_info['shape']) == 0 or value_info[
# 'dtype'] is None or 0 in value_info['shape']:
# #TODO add node shape inference
node.dtype = value_info['dtype']
node.out_shapes.append(value_info['shape'])
else:
node.out_shapes.append([])
class ONNXDecoder(object):
def __init__(self, onnx_model):
model = onnx.load(onnx_model)
onnx_model = onnx.load(onnx_model)
print('model ir_version: {}, op version: {}'.format(
model.ir_version, model.opset_import[0].version))
if model.opset_import[0].version < 9:
_logger.warning(
'Now, onnx2paddle supports converting ONNX models with opset_version == 9; '
'the opset_version of your ONNX model is %d < 9, '
'so some operators may fail to convert.',
model.opset_import[0].version)
check_model(model)
self.check_model_running_state(onnx_model)
model = onnx.shape_inference.infer_shapes(model)
model = self.optimize_model_skip_op_for_inference(model)
model = self.optimize_model_strip_initializer(model)
self.standardize_variable_name(model.graph)
self.model = model
graph = model.graph
self.onnx_graph = ONNXGraph(model)
self.onnx_graph.build()
onnx_model.ir_version, onnx_model.opset_import[0].version))
self.op_set = onnx_model.opset_import[0].version
check_model(onnx_model)
onnx_model = self.optimize_model_skip_op(onnx_model)
onnx_model = self.optimize_model_strip_initializer(onnx_model)
onnx_model = self.optimize_node_name(onnx_model)
self.graph = ONNXGraph(onnx_model)
#self.onnx_model = onnx_model
def build_value_refs(self, nodes):
"""
......@@ -373,14 +408,13 @@ class ONNXDecoder(object):
processed += 1
return processed
def optimize_model_skip_op_for_inference(self, model, op_list=None):
def optimize_model_skip_op(self, model, op_list=None):
"""
skip ops can be bypassed for inference
"""
nodes = model.graph.node
if op_list is None:
op_list = ['Dropout']
nodes = model.graph.node
input_refs, output_refs = self.build_value_refs(nodes)
ret = type(model)()
ret.CopyFrom(model)
......@@ -473,38 +507,11 @@ class ONNXDecoder(object):
name = name.replace(s, '_')
return 'x2paddle_' + name
def check_model_running_state(self, model_path):
import onnxruntime as rt
model = onnx.load(model_path)
model = onnx.shape_inference.infer_shapes(model)
if len(model.graph.value_info) < len(model.graph.node) - 1:
_logger.warning(
'During conversion of your model, some operators will be assigned node.out_shape == None, '
'refer to https://github.com/onnx/onnx/blob/master/docs/ShapeInference.md'
)
try:
datatype_map = {
'tensor(int64)': 'int',
'tensor(float)': 'float32',
'tensor(int32)': 'int32'
}
input_dict = {}
sess = rt.InferenceSession(model_path)
for ipt in sess.get_inputs():
datatype = datatype_map[ipt.type]
input_dict[ipt.name] = np.random.random(ipt.shape).astype(
datatype)
res = sess.run(None, input_feed=input_dict)
except:
raise Exception(
"onnxruntime failed to run inference with the ONNX model. Please confirm the correctness of the ONNX model using onnxruntime; if the model is correct, please file an issue on GitHub."
)
def standardize_variable_name(self, graph):
def optimize_node_name(self, model):
"""
standardize variable name for paddle's code
"""
graph = model.graph
for initializer in graph.initializer:
initializer.name = self.make_variable_name(initializer.name)
for ipt in graph.input:
......@@ -523,3 +530,4 @@ class ONNXDecoder(object):
node.input[i] = self.make_variable_name(node.input[i])
for i in range(len(node.output)):
node.output[i] = self.make_variable_name(node.output[i])
return model
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict as _dict
import numpy as _np
default_op_mapping_field_values = _dict()
default_op_mapping_field_values['FLUID_OP'] = ''
default_op_mapping_field_values['FLUID_INPUT_ARGS'] = None
default_op_mapping_field_values['FLUID_OUTPUT_ARGS'] = None
default_op_mapping_field_values['ATTR_MAPPING'] = dict()
default_op_mapping_field_values['DEFAULTS'] = dict()
default_op_mapping_field_values['INPUT_PERM'] = None
default_op_mapping_field_values['OUTPUT_PERM'] = None
default_op_mapping_field_values['FILL_NAME_FIELD'] = True
default_op_mapping = {
'Shape': ['shape', ['X'], ['Out']],
'Clip': [
'clip', ['X'], ['Out'], dict(), dict(
min=(_np.asarray(
[255, 255, 127, 255], dtype=_np.uint8).view(_np.float32)[0]),
max=(_np.asarray(
[255, 255, 127, 127], dtype=_np.uint8).view(_np.float32)[0]), )
],
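# note: the uint8 -> float32 views above decode IEEE-754 byte patterns:
# [255, 255, 127, 255] -> -3.4028235e+38 (float32 lowest),
# [255, 255, 127, 127] -> +3.4028235e+38 (float32 max)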
'Erf': ['erf', ['X'], ['Out']],
'Ceil': ['ceil', ['X'], ['Out']],
'ReduceMean': [
'reduce_mean', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
'ReduceSum': [
'reduce_sum', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
'ReduceMin': [
'reduce_min', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
'ReduceMax': [
'reduce_max', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
#activation functions
'Relu': ['relu', ['X'], ['Out']],
'LeakyRelu': ['leaky_relu', ['X'], ['Out'], dict(), dict(alpha=.01)],
'Elu': ['elu', ['X'], ['Out'], dict(), dict(alpha=1.)],
'ThresholdedRelu': [
'thresholded_relu', ['X'], ['Out'], dict(alpha='threshold'),
dict(alpha=1.)
],
'Tanh': ['tanh', ['X'], ['Out']],
'Sigmoid': ['sigmoid', ['X'], ['Out']],
'HardSigmoid': [
'hard_sigmoid', ['X'], ['Out'], dict(
alpha='slope', beta='offset'), dict(
slope=.2, offset=.5)
],
'Softsign': ['softsign', ['X'], ['Out']],
'Softplus': ['softplus', ['X'], ['Out']],
'Exp': ['exp', ['X'], ['Out']],
'Softmax': ['softmax', ['X'], ['Out'], dict(), dict(axis=1)],
'Sqrt': ['sqrt', ['X'], ['Out']],
'Floor': ['floor', ['X'], ['Out']],
'Abs': ['abs', ['X'], ['Out']],
}
default_ioa_constraint = {
'Gather': [(lambda i, o, a: a.get('axis', 0) == 0,
'only axis = 0 is supported')],
}
......@@ -2,6 +2,8 @@ import onnx
import numpy as np
from onnx import onnx_pb, helper
MAX_FLOAT = np.asarray([255, 255, 127, 127], dtype=np.uint8).view(np.float32)[0]
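# MAX_FLOAT above is bytes 0xFF 0xFF 0x7F 0x7F viewed as a little-endian float32,
# i.e. np.finfo(np.float32).max (~3.4028235e+38)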
def get_old_name(arg, name_prefix=''):
prefix_index = arg.find(name_prefix)
......@@ -747,36 +749,53 @@ def yolo_box(op, block):
outputs_pred_box_x2_clip = [model_name + "@pred_box_x2_clip"]
outputs_pred_box_y2_clip = [model_name + "@pred_box_y2_clip"]
min_const_name = model_name + "@pred_box_min_const"
max_const_name = model_name + "@pred_box_max_const"
min_const = onnx.helper.make_node(
'Constant',
inputs=[],
outputs=[min_const_name],
value=onnx.helper.make_tensor(
name=min_const_name,
data_type=onnx.TensorProto.FLOAT,
dims=(),
vals=[0.0]))
node_list.append(min_const)
max_const = onnx.helper.make_node(
'Constant',
inputs=[],
outputs=[max_const_name],
value=onnx.helper.make_tensor(
name=max_const_name,
data_type=onnx.TensorProto.FLOAT,
dims=(),
vals=[MAX_FLOAT]))
node_list.append(max_const)
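# ONNX opset 11 changed Clip to take min/max as optional inputs rather than
# attributes, so the two Constant nodes above are appended to each Clip's
# input list below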
node_pred_box_x1_clip = onnx.helper.make_node(
'Clip',
inputs=outputs_pred_box_x1_decode,
outputs=outputs_pred_box_x1_clip,
min=0.0,
max=float(np.inf))
inputs=outputs_pred_box_x1_decode + [min_const_name, max_const_name],
outputs=outputs_pred_box_x1_clip)
node_list.append(node_pred_box_x1_clip)
node_pred_box_y1_clip = onnx.helper.make_node(
'Clip',
inputs=outputs_pred_box_y1_decode,
outputs=outputs_pred_box_y1_clip,
min=0.0,
max=float(np.inf))
inputs=outputs_pred_box_y1_decode + [min_const_name, max_const_name],
outputs=outputs_pred_box_y1_clip)
node_list.append(node_pred_box_y1_clip)
node_pred_box_x2_clip = onnx.helper.make_node(
'Clip',
inputs=outputs_pred_box_x2_sub_w,
outputs=outputs_pred_box_x2_clip,
min=0.0,
max=float(np.inf))
inputs=outputs_pred_box_x2_sub_w + [min_const_name, max_const_name],
outputs=outputs_pred_box_x2_clip)
node_list.append(node_pred_box_x2_clip)
node_pred_box_y2_clip = onnx.helper.make_node(
'Clip',
inputs=outputs_pred_box_y2_sub_h,
outputs=outputs_pred_box_y2_clip,
min=0.0,
max=float(np.inf))
inputs=outputs_pred_box_y2_sub_h + [min_const_name, max_const_name],
outputs=outputs_pred_box_y2_clip)
node_list.append(node_pred_box_y2_clip)
outputs_pred_box_x2_res = [model_name + "@box_x2_res"]
......
......@@ -13,7 +13,6 @@
# limitations under the License.
# TODO useless node remove
from x2paddle.op_mapper.onnx_op_mapper import ONNXOpMapper
class ONNXOptimizer(object):
......
......@@ -49,8 +49,8 @@
## ONNX
**Note:** Some of these models come from PyTorch; for converting PyTorch models to ONNX, see [pytorch_to_onnx.md](pytorch_to_onnx.md)
| Model | Source | operator version|
|-------|--------|---------|
| Model | Source | operator version| Notes |
|-------|--------|---------|---------|
| ResNet18 | [torchvision.models.resnet18](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) |9|
| ResNet34 | [torchvision.models.resnet34](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) |9|
| ResNet50 | [torchvision.models.resnet50](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) |9|
......@@ -66,4 +66,6 @@
| mNASNet | [pytorch(personal practice)](https://github.com/rwightman/gen-efficientnet-pytorch) |9|
| EfficientNet | [pytorch(personal practice)](https://github.com/rwightman/gen-efficientnet-pytorch) |9|
| SqueezeNet | [onnx official](https://s3.amazonaws.com/download.onnx/models/opset_9/squeezenet.tar.gz) |9|
|Ultra-Light-Fast-Generic-Face-Detector-1MB| [onnx_model](https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/tree/master/models/onnx)| |
|Ultra-Light-Fast-Generic-Face-Detector-1MB| [onnx_model](https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/tree/master/models/onnx)|9 |
|BERT| [pytorch(huggingface)](https://github.com/huggingface/transformers/blob/master/notebooks/04-onnx-export.ipynb)|11|The input shape must be specified during conversion; see [FAQ Q3](FAQ.md)|
|GPT2| [pytorch(huggingface)](https://github.com/huggingface/transformers/blob/master/notebooks/04-onnx-export.ipynb)|11|The input shape must be specified during conversion; see [FAQ Q3](FAQ.md)|