Unverified · Commit 40f54537 authored by huangxu96, committed by GitHub

Quant nn2.0 (#28764)

* Implement 2.0 API version Conv2D and Linear layer quantization in imperative mode.

* Use cuDNN softmax in the static LeNet.

* Modify the channel-wise QAT unit test for the 2.0 API.

* For CI Python coverage.
Parent b2c8a007
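The tests in this diff exercise the new flow end to end: build an ImperativeQuantAware object, call quantize() on a dygraph model composed of 2.0 paddle.nn layers, then train as before. A minimal sketch of that flow, based on the test code further down (the toy Sequential model here is illustrative; the tests use the larger ImperativeLenet network):

import numpy as np
import paddle.fluid as fluid
from paddle.nn import Linear
from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

# Sketch only: quantize a 2.0-API dygraph model in place and run one forward pass.
imperative_qat = ImperativeQuantAware(
    weight_quantize_type='abs_max',
    activation_quantize_type='moving_average_abs_max')

with fluid.dygraph.guard():
    model = Sequential(Linear(in_features=400, out_features=10))
    imperative_qat.quantize(model)   # wraps the Linear sublayer with QuantizedLinear
    data = np.random.uniform(-1, 1, [4, 400]).astype('float32')
    out = model(fluid.dygraph.to_variable(data))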
@@ -20,7 +20,8 @@ import paddle
 from paddle.fluid import dygraph, core, framework
 from paddle.fluid.executor import Executor
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
-from paddle.fluid.dygraph.nn import Conv2D, Linear, BatchNorm, Pool2D, Conv2DTranspose
+from paddle.nn import Linear, Conv2D
+from paddle.fluid.dygraph.nn import BatchNorm, Pool2D, Conv2DTranspose
 from paddle.fluid.io import load_inference_model, save_inference_model
 from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6, Tanh, Softmax, PReLU
 from paddle.fluid.log_helper import get_logger
@@ -142,6 +143,8 @@ class ImperativeQuantAware(object):
         self._weight_bits = weight_bits
         self._activation_bits = activation_bits
         self._moving_rate = moving_rate
+        self._activation_quantize_type = activation_quantize_type
+        self._weight_quantize_type = weight_quantize_type
         self._weight_pre_layer = weight_preprocess_layer
         self._act_pre_layer = act_preprocess_layer
@@ -172,8 +175,6 @@ class ImperativeQuantAware(object):
                 "Unknown weight_quantize_type: '%s'. It can only be "
                 "'abs_max' or 'moving_average_abs_max' or 'channel_wise_abs_max' now."
                 % (str(weight_quantize_type)))
-        self._activation_quantize_type = activation_quantize_type
-        self._weight_quantize_type = weight_quantize_type
         self._quant_layers_map = {'Conv2D': Conv2D, 'Linear': Linear}
         self._quantizable_layer_type = tuple(
......
@@ -21,6 +21,7 @@ from paddle.fluid.framework import _varbase_creator
 from paddle.fluid.framework import in_dygraph_mode
 from paddle.fluid.initializer import Constant
 from paddle.fluid.data_feeder import check_variable_and_dtype
+from paddle.nn import functional as F
 
 __all__ = [
     'FakeQuantMovingAverage', 'FakeQuantAbsMax', 'QuantizedConv2D',
@@ -144,7 +145,6 @@ class FakeQuantAbsMax(layers.Layer):
                  quant_on_weight=False):
         super(FakeQuantAbsMax, self).__init__()
         self._quant_bits = quant_bits
-        self._dtype = dtype
         self._name = name
         scale_prefix = "{}.scale".format(
             name) if name else 'quant_dequant.scale'
@@ -342,16 +342,17 @@ class QuantizedConv2D(layers.Layer):
         self._groups = getattr(layer, '_groups')
         self._stride = getattr(layer, '_stride')
         self._padding = getattr(layer, '_padding')
+        self._padding_mode = getattr(layer, '_padding_mode')
+        if self._padding_mode != 'zeros':
+            self._reversed_padding_repeated_twice = getattr(
+                layer, '_reversed_padding_repeated_twice')
         self._dilation = getattr(layer, '_dilation')
-        self._act = getattr(layer, '_act')
-        self._use_cudnn = getattr(layer, '_use_cudnn')
-        self._dtype = getattr(layer, '_dtype')
-        self._l_type = getattr(layer, '_l_type')
+        self._data_format = getattr(layer, '_data_format')
         self.weight = getattr(layer, 'weight')
         self.bias = getattr(layer, 'bias')
         # For FakeQuant
         self._conv2d_quant_axis = 0
         if weight_quant_layer is not None:
             self._fake_quant_weight = weight_quant_layer()
         else:
@@ -390,52 +391,22 @@ class QuantizedConv2D(layers.Layer):
             weight = self._weight_preprocess(self.weight)
         quant_weight = self._fake_quant_weight(weight)
-        if in_dygraph_mode() and self._l_type == 'conv2d':
-            attrs = ('strides', self._stride, 'paddings', self._padding,
-                     'dilations', self._dilation, 'groups', self._groups
-                     if self._groups else 1, 'use_cudnn', self._use_cudnn)
-            pre_bias = core.ops.conv2d(quant_input, quant_weight, *attrs)
-            pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, self.bias,
-                                                            1)
-            return dygraph_utils._append_activation_in_dygraph(pre_act,
-                                                               self._act)
-        check_variable_and_dtype(quant_input, 'input',
-                                 ['float16', 'float32', 'float64'],
-                                 'QuantizedConv2D')
-        attrs = {
-            'strides': self._stride,
-            'paddings': self._padding,
-            'dilations': self._dilation,
-            'groups': self._groups if self._groups else 1,
-            'use_cudnn': self._use_cudnn,
-            'use_mkldnn': False,
-        }
-        pre_bias = self._helper.create_variable_for_type_inference(
-            dtype=self._dtype)
-        self._helper.append_op(
-            type=self._l_type,
-            inputs={
-                'Input': quant_input,
-                'Filter': quant_weight,
-            },
-            outputs={"Output": pre_bias},
-            attrs=attrs)
-        if self.bias is not None:
-            pre_act = self._helper.create_variable_for_type_inference(
-                dtype=self._dtype)
-            self._helper.append_op(
-                type='elementwise_add',
-                inputs={'X': [pre_bias],
-                        'Y': [self.bias]},
-                outputs={'Out': [pre_act]},
-                attrs={'axis': 1})
-        else:
-            pre_act = pre_bias
-        return self._helper.append_activation(pre_act, act=self._act)
+        if self._padding_mode != 'zeros':
+            quant_input = F.pad(quant_input,
+                                self._reversed_padding_repeated_twice,
+                                mode=self._padding_mode,
+                                data_format=self._data_format)
+            self._padding = 0
+
+        return F.conv2d(
+            quant_input,
+            quant_weight,
+            bias=self.bias,
+            padding=self._padding,
+            stride=self._stride,
+            dilation=self._dilation,
+            groups=self._groups,
+            data_format=self._data_format)
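The rewritten forward no longer assembles conv2d and elementwise_add ops by hand: when the wrapped layer uses a non-'zeros' padding_mode, the fake-quantized input is pre-padded with F.pad and the convolution then runs with padding=0 through paddle.nn.functional.conv2d. A rough standalone sketch of that behavior (shapes and the padding list are illustrative, not taken from the PR):

import paddle
import paddle.nn.functional as F

x = paddle.rand([1, 3, 8, 8])    # NCHW input
w = paddle.rand([4, 3, 3, 3])    # 4 output channels, 3 input channels, 3x3 kernel

# A non-'zeros' padding mode is applied explicitly to the spatial dims first,
# then conv2d is called with padding=0.
x_padded = F.pad(x, [1, 1, 1, 1], mode='replicate', data_format='NCHW')
y = F.conv2d(x_padded, w, bias=None, stride=1, padding=0)
print(y.shape)  # [1, 4, 8, 8]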
class QuantizedLinear(layers.Layer):
@@ -457,10 +428,9 @@ class QuantizedLinear(layers.Layer):
                  act_quant_layer=None):
         super(QuantizedLinear, self).__init__()
         # For Linear
-        self._act = getattr(layer, '_act')
-        self._dtype = getattr(layer, '_dtype')
         self.weight = getattr(layer, 'weight')
         self.bias = getattr(layer, 'bias')
+        self.name = getattr(layer, 'name')
         # For FakeQuant
         self._linear_quant_axis = 1
@@ -503,44 +473,9 @@ class QuantizedLinear(layers.Layer):
             weight = self._weight_preprocess(self.weight)
         quant_weight = self._fake_quant_weight(weight)
-        if in_dygraph_mode():
-            pre_bias = _varbase_creator(dtype=input.dtype)
-            core.ops.matmul(quant_input, quant_weight, pre_bias, 'transpose_X',
-                            False, 'transpose_Y', False, "alpha", 1)
-            pre_act = dygraph_utils._append_bias_in_dygraph(
-                pre_bias, self.bias, axis=len(input.shape) - 1)
-            return dygraph_utils._append_activation_in_dygraph(pre_act,
-                                                               self._act)
-        check_variable_and_dtype(input, 'input',
-                                 ['float16', 'float32', 'float64'],
-                                 "QuantizedLinear")
-        attrs = {
-            "transpose_X": False,
-            "transpose_Y": False,
-            "alpha": 1,
-        }
-        inputs = {"X": [quant_input], "Y": [quant_weight]}
-        mul_out = self._helper.create_variable_for_type_inference(self._dtype)
-        self._helper.append_op(
-            type="matmul",
-            inputs=inputs,
-            outputs={"Out": [mul_out]},
-            attrs=attrs)
-        if self.bias is not None:
-            pre_activation = self._helper.create_variable_for_type_inference(
-                dtype=self._dtype)
-            self._helper.append_op(
-                type='elementwise_add',
-                inputs={'X': [mul_out],
-                        'Y': [self.bias]},
-                outputs={'Out': [pre_activation]},
-                attrs={'axis': len(input.shape) - 1})
-        else:
-            pre_activation = mul_out
-        return self._helper.append_activation(pre_activation, act=self._act)
+        out = F.linear(
+            x=quant_input, weight=quant_weight, bias=self.bias, name=self.name)
+        return out
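QuantizedLinear is simplified the same way: the hand-built matmul, elementwise_add and activation ops collapse into a single paddle.nn.functional.linear call on the fake-quantized input and weight. A minimal sketch (shapes are illustrative):

import paddle
import paddle.nn.functional as F

x = paddle.rand([2, 400])           # a batch of flattened features
weight = paddle.rand([400, 120])    # [in_features, out_features]
bias = paddle.zeros([120])

# Equivalent to x @ weight + bias; in QuantizedLinear both x and weight
# would first pass through the fake-quant layers.
out = F.linear(x, weight, bias)
print(out.shape)  # [2, 120]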
class MovingAverageAbsMaxScale(layers.Layer):
......
@@ -27,11 +27,11 @@ from paddle.fluid.framework import IrGraph
 from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
 from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
 from paddle.fluid.dygraph.container import Sequential
-from paddle.fluid.dygraph.nn import Conv2D
+from paddle.nn import Linear, Conv2D, Softmax
 from paddle.fluid.dygraph.nn import Pool2D
-from paddle.fluid.dygraph.nn import Linear
 from paddle.fluid.log_helper import get_logger
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
+from paddle.fluid.contrib.slim.quantization.imperative.quant_nn import QuantizedConv2D
 
 paddle.enable_static()
@@ -43,7 +43,7 @@ _logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
 
 
-def StaticLenet(data, num_classes=10, classifier_activation='softmax'):
+def StaticLenet(data, num_classes=10):
     conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
     conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
     fc_w1_attr = fluid.ParamAttr(name="fc_w_1")
@@ -85,15 +85,15 @@ def StaticLenet(data, num_classes=10, classifier_activation='softmax'):
                      bias_attr=fc_b2_attr)
     fc3 = fluid.layers.fc(input=fc2,
                           size=num_classes,
-                          act=classifier_activation,
                           param_attr=fc_w3_attr,
                           bias_attr=fc_b3_attr)
-    return fc3
+    fc4 = fluid.layers.softmax(fc3, use_cudnn=True)
+    return fc4
 
 
 class ImperativeLenet(fluid.dygraph.Layer):
-    def __init__(self, num_classes=10, classifier_activation='softmax'):
+    def __init__(self, num_classes=10):
         super(ImperativeLenet, self).__init__()
         conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
         conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
@@ -107,47 +107,46 @@ class ImperativeLenet(fluid.dygraph.Layer):
         fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
         self.features = Sequential(
             Conv2D(
-                num_channels=1,
-                num_filters=6,
-                filter_size=3,
+                in_channels=1,
+                out_channels=6,
+                kernel_size=3,
                 stride=1,
                 padding=1,
-                param_attr=conv2d_w1_attr,
+                weight_attr=conv2d_w1_attr,
                 bias_attr=conv2d_b1_attr),
             Pool2D(
                 pool_size=2, pool_type='max', pool_stride=2),
             Conv2D(
-                num_channels=6,
-                num_filters=16,
-                filter_size=5,
+                in_channels=6,
+                out_channels=16,
+                kernel_size=5,
                 stride=1,
                 padding=0,
-                param_attr=conv2d_w2_attr,
+                weight_attr=conv2d_w2_attr,
                 bias_attr=conv2d_b2_attr),
             Pool2D(
                 pool_size=2, pool_type='max', pool_stride=2))
 
         self.fc = Sequential(
             Linear(
-                input_dim=400,
-                output_dim=120,
-                param_attr=fc_w1_attr,
+                in_features=400,
+                out_features=120,
+                weight_attr=fc_w1_attr,
                 bias_attr=fc_b1_attr),
             Linear(
-                input_dim=120,
-                output_dim=84,
-                param_attr=fc_w2_attr,
+                in_features=120,
+                out_features=84,
+                weight_attr=fc_w2_attr,
                 bias_attr=fc_b2_attr),
            Linear(
-                input_dim=84,
-                output_dim=num_classes,
-                act=classifier_activation,
-                param_attr=fc_w3_attr,
-                bias_attr=fc_b3_attr))
+                in_features=84,
+                out_features=num_classes,
+                weight_attr=fc_w3_attr,
+                bias_attr=fc_b3_attr),
+            Softmax())
 
     def forward(self, inputs):
         x = self.features(inputs)
         x = fluid.layers.flatten(x, 1)
         x = self.fc(x)
         return x
@@ -162,8 +161,19 @@ class TestImperativeQat(unittest.TestCase):
         imperative_qat = ImperativeQuantAware(
             weight_quantize_type='abs_max',
             activation_quantize_type='moving_average_abs_max')
 
         with fluid.dygraph.guard():
+            # For CI coverage
+            conv1 = Conv2D(
+                in_channels=3,
+                out_channels=2,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                padding_mode='replicate')
+            quant_conv1 = QuantizedConv2D(conv1)
+            data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
+            quant_conv1(fluid.dygraph.to_variable(data))
             lenet = ImperativeLenet()
             imperative_qat.quantize(lenet)
             adam = AdamOptimizer(
@@ -286,7 +296,7 @@ class TestImperativeQat(unittest.TestCase):
         activation_quant_type = 'moving_average_abs_max'
         param_init_map = {}
         seed = 1000
-        lr = 0.1
+        lr = 0.01
         # imperative train
         _logger.info(
......
@@ -27,9 +27,8 @@ from paddle.fluid.framework import IrGraph
 from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
 from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
 from paddle.fluid.dygraph.container import Sequential
-from paddle.fluid.dygraph.nn import Conv2D
+from paddle.nn import Linear, Conv2D, Softmax
 from paddle.fluid.dygraph.nn import Pool2D
-from paddle.fluid.dygraph.nn import Linear
 from paddle.fluid.log_helper import get_logger
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
@@ -43,7 +42,7 @@ _logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
 
 
-def StaticLenet(data, num_classes=10, classifier_activation='softmax'):
+def StaticLenet(data, num_classes=10):
     conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
     conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
     fc_w1_attr = fluid.ParamAttr(name="fc_w_1")
@@ -85,15 +84,15 @@ def StaticLenet(data, num_classes=10, classifier_activation='softmax'):
                      bias_attr=fc_b2_attr)
     fc3 = fluid.layers.fc(input=fc2,
                           size=num_classes,
-                          act=classifier_activation,
                           param_attr=fc_w3_attr,
                           bias_attr=fc_b3_attr)
-    return fc3
+    fc4 = fluid.layers.softmax(fc3, use_cudnn=True)
+    return fc4
 
 
 class ImperativeLenet(fluid.dygraph.Layer):
-    def __init__(self, num_classes=10, classifier_activation='softmax'):
+    def __init__(self, num_classes=10):
         super(ImperativeLenet, self).__init__()
         conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
         conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
@@ -107,53 +106,52 @@ class ImperativeLenet(fluid.dygraph.Layer):
         fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
         self.features = Sequential(
             Conv2D(
-                num_channels=1,
-                num_filters=6,
-                filter_size=3,
+                in_channels=1,
+                out_channels=6,
+                kernel_size=3,
                 stride=1,
                 padding=1,
-                param_attr=conv2d_w1_attr,
+                weight_attr=conv2d_w1_attr,
                 bias_attr=conv2d_b1_attr),
             Pool2D(
                 pool_size=2, pool_type='max', pool_stride=2),
             Conv2D(
-                num_channels=6,
-                num_filters=16,
-                filter_size=5,
+                in_channels=6,
+                out_channels=16,
+                kernel_size=5,
                 stride=1,
                 padding=0,
-                param_attr=conv2d_w2_attr,
+                weight_attr=conv2d_w2_attr,
                 bias_attr=conv2d_b2_attr),
             Pool2D(
                 pool_size=2, pool_type='max', pool_stride=2))
 
         self.fc = Sequential(
             Linear(
-                input_dim=400,
-                output_dim=120,
-                param_attr=fc_w1_attr,
+                in_features=400,
+                out_features=120,
+                weight_attr=fc_w1_attr,
                 bias_attr=fc_b1_attr),
             Linear(
-                input_dim=120,
-                output_dim=84,
-                param_attr=fc_w2_attr,
+                in_features=120,
+                out_features=84,
+                weight_attr=fc_w2_attr,
                 bias_attr=fc_b2_attr),
             Linear(
-                input_dim=84,
-                output_dim=num_classes,
-                act=classifier_activation,
-                param_attr=fc_w3_attr,
-                bias_attr=fc_b3_attr))
+                in_features=84,
+                out_features=num_classes,
+                weight_attr=fc_w3_attr,
+                bias_attr=fc_b3_attr),
+            Softmax())
 
     def forward(self, inputs):
         x = self.features(inputs)
         x = fluid.layers.flatten(x, 1)
         x = self.fc(x)
         return x
 
 
-class TestImperativeQat(unittest.TestCase):
+class TestImperativeQatChannelWise(unittest.TestCase):
     """
     QAT = quantization-aware training
     """
@@ -286,7 +284,7 @@ class TestImperativeQat(unittest.TestCase):
         activation_quant_type = 'moving_average_abs_max'
         param_init_map = {}
         seed = 1000
-        lr = 0.1
+        lr = 0.001
         # imperative train
         _logger.info(
......