Commit c4750264 authored by Zhang, Guoming

enable conv/sum fusion

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import time
import sys
import paddle
import paddle.fluid as fluid
import models
import reader
import argparse
import functools
from models.learning_rate import cosine_decay
from utility import add_arguments, print_arguments
import math
import paddle.fluid.core as core
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('class_dim', int, 1000, "Class number.")
add_arg('image_shape', str, "3,224,224", "Input image size")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
# yapf: enable
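# Example invocation (illustrative only; the script name and model directory
# below are assumptions, not part of this commit):
#   python eval_int8_calibration.py --model=ResNet50 --use_gpu=False \
#       --batch_size=32 --pretrained_model=./pretrained/ResNet50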
model_list = [m for m in dir(models) if "__" not in m]
def get_quantization_op_pos(program):
conv_op_index = [index for index, value in enumerate(program.global_block().ops) if value.type == 'conv2d']
if len(conv_op_index) < 2:
return None
return [conv_op_index[1]]
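# Note: get_quantization_op_pos returns the index of the second conv2d op in
# the block, i.e. the position where a quantize op would be inserted so that
# the first convolution stays in FP32 while later ones can consume INT8 data.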
def get_dequantization_op_pos(program):
conv_op_index = [index for index, value in enumerate(program.global_block().ops) if value.type == 'conv2d']
if len(conv_op_index) < 2:
return None
res = []
support_int8_op_type = ["pool2d"]
for index, value in enumerate(conv_op_index[:-1]):
if index == 0: continue
if value + 1 == conv_op_index[index + 1]:
continue
else:
start_index = index + 1
end_index = conv_op_index[index + 1]
while start_index < end_index:
if program.global_block().ops[start_index].type not in support_int8_op_type:
res.append(start_index)
break
else:
start_index += 1
res.append(conv_op_index[-1]) #need to fix
return res
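# Note: get_dequantization_op_pos walks the ops between consecutive conv2d ops
# and records the first op that is not INT8-capable (only pool2d is treated as
# INT8-capable here), i.e. the positions where a dequantize op would have to be
# inserted; the last conv2d index is appended as a placeholder (see the
# "need to fix" remark above).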
def get_requantization_op_pos(program):
pass
# def create_op(program, op_name, data_type):
def eval(args):
# parameters from arguments
class_dim = args.class_dim
model_name = args.model
pretrained_model = args.pretrained_model
with_memory_optimization = args.with_mem_opt
image_shape = [int(m) for m in args.image_shape.split(",")]
assert model_name in model_list, "{} is not in lists: {}".format(args.model,
model_list)
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# model definition
model = models.__dict__[model_name]()
if model_name == "GoogleNet":
out0, out1, out2 = model.net(input=image, class_dim=class_dim)
cost0 = fluid.layers.cross_entropy(input=out0, label=label)
cost1 = fluid.layers.cross_entropy(input=out1, label=label)
cost2 = fluid.layers.cross_entropy(input=out2, label=label)
avg_cost0 = fluid.layers.mean(x=cost0)
avg_cost1 = fluid.layers.mean(x=cost1)
avg_cost2 = fluid.layers.mean(x=cost2)
avg_cost = avg_cost0 + 0.3 * avg_cost1 + 0.3 * avg_cost2
acc_top1 = fluid.layers.accuracy(input=out0, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out0, label=label, k=5)
else:
out = model.net(input=image, class_dim=class_dim)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
test_program = fluid.default_main_program().clone(for_test=True)
if with_memory_optimization:
fluid.memory_optimize(fluid.default_main_program())
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
print(120, pretrained_model)
t = fluid.transpiler.InferenceTranspiler()
t.transpile(test_program, fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace())
# for i in test_program.current_block().ops:
# print i
# sys.exit(0)
conv_op_index = [index for index, value in enumerate(test_program.global_block().ops) if value.type == 'conv2d']
print (conv_op_index)
weights_var_name = []
conv_input_var_name = []
conv_output_var_name = []
weights_channel = {}
for i in conv_op_index[1:]:
weights_var_name.append(test_program.current_block().ops[i].input('Filter')[0])
conv_input_var_name.append(test_program.current_block().ops[i].input('Input')[0])
conv_output_var_name.append(test_program.current_block().ops[i].output('Output')[0])
for i in test_program.list_vars():
if i.name in weights_var_name:
weights_channel[i.name] = i.shape[0]
# print weights_var_name
# print '-------'
# print conv_input_var_name
# print '-------'
# print conv_output_var_name
# for i in test_program.current_block().ops:
# print ('-----------')
# print (i.input_names, i.output_names)
# if i.type == 'conv2d':
# print i.input('Filter')
# print (i.input_arg_names)
# print (i.output_arg_names)
# # print (i.block_attr)
# print (dir(i))
# print (i.attr_names)
# print ((i.attr))
# for j in i.attr_names:
# print ((i.attr(j)))
# print (i.blocks_attr)
# sys.exit(0)
# for i in test_program.list_vars():
# print (i.name)
# # print dir(i)
# print i.shape, i.type, i.dtype
# if i.name == "batch_norm_52.b_0_fuse_bn":
# i.dtype = fluid.core.VarDesc.VarType.INT8;
# print (test_program.global_block().ops[23].type)
# for i in conv_op_index:
# op = test_program.current_block().ops[i]
# print (op)
# print (op.input_names, op.input_arg_names, op.output_arg_names)
not_persistable_vars = (i for i in test_program.list_vars() if not i.persistable)
for i in not_persistable_vars:
# # print (i.name, i.persistable)
i.persistable = True
# int8_prog = test_program.clone()
var_name = [i.name for i in test_program.list_vars()]
# get_dequantization_op_pos(int8_prog)
# print var_name
# sys.exit(0)
val_reader = paddle.batch(reader.val(), batch_size=args.batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
fetch_list = [avg_cost.name, acc_top1.name, acc_top5.name]
test_info = [[], [], []]
cnt = 0
var_max = {}
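# Calibration: after each forward pass, record the maximum value of every
# variable in the program (per output channel for conv weights). These maxima
# are later materialized as scale tensors (e.g. conv2_scale_in) for the
# quantize/dequantize ops.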
for batch_id, data in enumerate(val_reader()):
t1 = time.time()
loss, acc1, acc5 = exe.run(test_program,
fetch_list=fetch_list,
feed=feeder.feed(data))
for i in var_name:
# print (np.array(fluid.global_scope().find_var(i).get_tensor()).shape)
np_data = np.array(fluid.global_scope().find_var(i).get_tensor())
if i in weights_var_name:
max_value = [float(np.amax(np_data[j])) for j in range(np_data.shape[0])]
else:
max_value = [float(np.amax(np_data))]
var_max[i] = []
var_max[i].append(max_value)
# print var_max
t2 = time.time()
period = t2 - t1
loss = np.mean(loss)
acc1 = np.mean(acc1)
acc5 = np.mean(acc5)
test_info[0].append(loss * len(data))
test_info[1].append(acc1 * len(data))
test_info[2].append(acc5 * len(data))
cnt += len(data)
if batch_id % 10 == 0:
print("Testbatch {0},loss {1}, "
"acc1 {2},acc5 {3},time {4}".format(batch_id, \
loss, acc1, acc5, \
"%2.2f sec" % period))
sys.stdout.flush()
break
test_loss = np.sum(test_info[0]) / cnt
test_acc1 = np.sum(test_info[1]) / cnt
test_acc5 = np.sum(test_info[2]) / cnt
print("Test_loss {0}, test_acc1 {1}, test_acc5 {2}".format(
test_loss, test_acc1, test_acc5))
sys.stdout.flush()
#insert quantization op
infer_prog = test_program.clone()
pos = get_quantization_op_pos(infer_prog)
print(pos)
print(infer_prog.current_block().ops[1].output('Out')[0])
conv2_scale_in = infer_prog.global_block().create_var(
name="conv2_scale_in",
dtype="float32",
persistable=True,
)
# conv2_weights_in = infer_prog.global_block().create_var(
# name="conv2_weights_in",
# dtype="float32",
# persistable=True,
# )
conv2_int8_tmp = infer_prog.global_block().create_var(
name="conv2_int8_tmp",
dtype="int8",
persistable=True,
shape= (np.array(fluid.global_scope().find_var('pool2d_0.tmp_0').get_tensor())).shape
)
# print ((np.array(fluid.global_scope().find_var('pool2d_0.tmp_0').get_tensor())).shape)
# sys.exit(0)
# fluid.initializer.Constant(value=1.0)(conv2_int8_tmp, infer_prog.global_block())
infer_prog.current_block().append_op(
type='assign_value',
outputs={'Out': [conv2_scale_in]},
attrs={
'dtype':core.VarDesc.VarType.FP32,
'shape': [1,1],
'fp32_values': var_max[var_name[1]][0]
}
)
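# The assign_value op above materializes one of the calibrated maxima
# (var_max[var_name[1]]) as the persistable FP32 tensor conv2_scale_in; it is
# intended to feed the Scale input of the quantize op sketched in the
# commented-out _insert_op call below.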
# infer_prog.current_block().append_op(
# type='assign_value',
# outputs={'Out': [conv2_int8_tmp]},
# attrs={
# 'dtype':core.VarDesc.VarType.UINT8,
# 'shape': (np.array(fluid.global_scope().find_var('pool2d_0.tmp_0').get_tensor())).shape,
# # 'fp32_values': var_max[var_name[1]][0]
# }
# )
# op = infer_prog.current_block()._insert_op(
# index=pos[0],
# type= "quantize",
# inputs={"Input": infer_prog.current_block().ops[1].output('Out')[0],
# "Scale": conv2_scale_in},
# outputs={"Output":conv2_int8_tmp},
# # attrs= {
# # "data_format": "NCHW"
# # }
# )
# op.set_attr("data_format", "NCHW")
# op.set_attr("use_mkldnn", 1)
# infer_prog.current_block().ops[3].set_input("Input", ['conv2_int8_tmp'])
# infer_prog.current_block().append_op(
# type='assign_value',
# outputs={'Out': [conv2_weights_in]},
# attrs={
# 'dtype':core.VarDesc.VarType.FP32,
# 'shape': [1,1],
# 'fp32_values': [3.12]
# }
# )
# for i in infer_prog.current_block().ops[:4]:
# print (i)
# sys.exit(0)
# with open("/home/guomingz/__model_xiaoli_quantize__", "wb") as f:
# f.write(infer_prog.desc.serialize_to_string())
infer_prog.current_block().append_op(
type = 'save',
inputs={'X': 'conv2_scale_in'},
outputs={},
attrs={"file_path": "{}/conv2_scale_in".format(pretrained_model)}
)
# infer_prog.current_block().append_op(
# type = 'save',
# inputs={'X': 'conv2_int8_tmp'},
# outputs={},
# attrs={"file_path": "{}/conv2_int8_tmp".format(pretrained_model)}
# )
# val_reader = paddle.batch(reader.val(), batch_size=args.batch_size)
for batch_id, data in enumerate(val_reader()):
# print (feeder.feed(data))
# print (fetch_list)
loss, acc1, acc5 = exe.run(infer_prog,
fetch_list=fetch_list,
feed=feeder.feed(data))
sys.exit(0)
# infer_prog.current_block().append_op(
# type = 'save',
# inputs={'X': 'conv2_weights_in'},
# outputs={},
# attrs={"file_path": "{}/conv2_weights_in".format(pretrained_model)}
# )
#insert dequantization op
#rerun to save variable
# for batch_id, data in enumerate(val_reader()):
# t1 = time.time()
# loss, acc1, acc5 = exe.run(test_program,
# fetch_list=fetch_list,
# feed=feeder.feed(data))
# with open("/home/guomingz/__model__", "wb") as f:
# f.write(test_program.desc.serialize_to_string())
def main():
args = parser.parse_args()
print_arguments(args)
eval(args)
if __name__ == '__main__':
main()
......@@ -131,21 +131,29 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false) {
bool is_persistent = false,
bool is_INT8 = false,
std::vector<float> scale_data = {1.0f},
int mask = 0) {
auto user_weights_pd = user_weights_memory_p->get_primitive_desc();
auto weights_pd = conv_pd_->weights_primitive_desc();
return this->AcquireMemory(weights_pd, user_weights_pd,
user_weights_memory_p, "@weights_mem_p",
pipeline, is_persistent);
pipeline, is_persistent,
is_INT8, scale_data, mask);
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_bias_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
std::vector<mkldnn::primitive>& pipeline,
bool is_INT8 = false,
std::vector<float> scale_data = {1.0f},
int mask = 0) { // NOLINT
auto user_bias_pd = user_bias_memory_p->get_primitive_desc();
auto bias_pd = conv_pd_->bias_primitive_desc();
return this->AcquireMemory(bias_pd, user_bias_pd, user_bias_memory_p,
"@bias_mem_p", pipeline);
"@bias_mem_p", pipeline,
false, is_INT8, scale_data, mask);
}
std::shared_ptr<mkldnn::convolution_forward> AcquireConvolution(
......@@ -278,6 +286,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
auto* output = ctx.Output<Tensor>("Output");
bool is_INT8 = ctx.HasInput("Scale_in");
auto* scale_in = ctx.HasInput("Scale_in") ? ctx.Input<Tensor>("Scale_in") : nullptr;
auto* scale_in_eltwise = ctx.HasInput("Scale_in_eltwise")? ctx.Input<Tensor>("Scale_in_eltwise") : nullptr;
auto* scale_weights = ctx.HasInput("Scale_weights")? ctx.Input<Tensor>("Scale_weights") : nullptr;
auto* scale_out = ctx.HasInput("Scale_out")? ctx.Input<Tensor>("Scale_out") : nullptr;
bool is_multi_channel = is_INT8 && scale_weights->numel() > 1;
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
input->format() != memory::format::format_undef,
"Wrong layout/format set for Input tensor");
......@@ -329,6 +344,29 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
std::vector<T> output_shift_scale;
T sum_scale = 1.0f;
if (is_INT8) {
int count = is_multi_channel ? (g > 1 ? weights_tz[1] * weights_tz[0] : weights_tz[0]) : 1;
T scale_in_data = *(scale_in->data<T>());
T scale_out_data = *(scale_out->data<T>());
std::vector<T> scale_weights_data(count);
for (int i = 0; i < count; i++) {
scale_weights_data[i] = *(scale_weights->data<T>() + i);
}
output_shift_scale.resize(count);
for (int i = 0; i < count; i++) {
if (scale_weights_data[i] == 0.0)
output_shift_scale[i] = scale_out_data;
else
output_shift_scale[i] = scale_out_data / (scale_in_data * scale_weights_data[i]);
}
// Scale_in_eltwise is only present when a residual (sum) input is fused.
if (scale_in_eltwise) {
T scale_in_eltwise_data = *(scale_in_eltwise->data<T>());
sum_scale = scale_out_data / scale_in_eltwise_data;
}
}
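// Worked example with assumed, illustrative scales: for scale_in = 64,
// scale_weights[i] = 127 and scale_out = 32, the int32 convolution result
// carries a factor of scale_in * scale_weights[i] = 8128, so
// output_shift_scale[i] = 32 / 8128 maps it back into the INT8 output range.
// sum_scale = scale_out / scale_in_eltwise likewise re-expresses the fused
// residual input in the output's quantization domain before append_sum.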
// Get unique name for storing MKLDNN primitives
const std::string key = ConvMKLDNNHandler::GetHash(
src_tz, weights_tz, strides, paddings, dilations, groups,
......@@ -367,13 +405,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine,
fuse_relu, fuse_residual_conn);
if(is_INT8){
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine,
fuse_relu, fuse_residual_conn,
output_shift_scale, sum_scale);
} else{
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine,
fuse_relu, fuse_residual_conn);
}
} else {
conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn);
if(is_INT8){
conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn,
output_shift_scale, sum_scale);
} else{
conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn);
}
}
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd);
......@@ -411,6 +463,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
std::shared_ptr<mkldnn::memory> weights_memory_p;
if (is_INT8) {
// For INT8, the weights reorder is also expected to convert the FP32 weights
// into the INT8 format required by conv_pd, so pass the calibrated weight
// scales and the reorder mask to the handler.
int mask_reorder = is_multi_channel ? ((g != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0;
int count = is_multi_channel ? (g > 1 ? weights_tz[1] * weights_tz[0] : weights_tz[0]) : 1;
std::vector<T> scale_weights_data(count);
for (int i = 0; i < count; i++) {
scale_weights_data[i] = *(scale_weights->data<T>() + i);
}
weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test, is_INT8, scale_weights_data,
mask_reorder);
} else {
weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test);
}
auto dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
......@@ -422,9 +484,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
{bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x);
auto user_bias_memory_p =
handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));
std::shared_ptr<mkldnn::memory> bias_memory_p;
if (is_INT8) {
// Bias is rescaled by scale_in * scale_weights so that it matches the
// integer accumulation domain of the INT8 convolution.
int mask_reorder = is_multi_channel ? 1 << 0 : 0;
int count = is_multi_channel ? (g > 1 ? weights_tz[1] * weights_tz[0] : weights_tz[0]) : 1;
std::vector<T> scale_bias_data(count);
for (int i = 0; i < count; i++) {
scale_bias_data[i] = (*scale_in->data<T>()) * (*(scale_weights->data<T>() + i));
}
bias_memory_p = handler.AcquireBiasMemoryFromPrimitive(
user_bias_memory_p, pipeline, is_INT8, scale_bias_data, mask_reorder);
} else {
bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
}
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
bias_memory_p, dst_memory_p);
} else {
......@@ -441,79 +512,154 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
private:
mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
bool fuse_residual_conn) const {
mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations;
mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_residual_conn,
const std::vector<T> output_shift_scale, T sum_scale) const {
mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations;
// Fusion with Elementwise layer relies on adding a sum post-operation with
// the scale parameter. It is assumed that when fuse_residual_connection is
// true, the output tensor contains the data coming from residual
// connection. The result of this post_op is:
// Output = scale * Output + Conv_Out.
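// For the INT8 path the residual input is assumed to already be quantized,
// so it is accumulated with sum_scale = scale_out / scale_in_eltwise instead
// of a fixed scale of 1.0f.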
if (fuse_residual_conn) {
post_operations.append_sum(1.0f);
int mask = output_shift_scale.size() > 1 ? 1<<1 : 0;
conv_attr.set_output_scales(mask, output_shift_scale);
if (fuse_residual_conn) {
post_operations.append_sum(sum_scale);
}
if (fuse_relu) {
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 0.0f; //beta
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
}
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
if (fuse_relu) {
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 0.0f;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_eltwise) const {
mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations;
// Fusion with Elementwise layer relies on adding a sum post-operation with
// the scale parameter. It is assumed that when fuse_eltwise is true, the
// Output tensor contains the data coming from residual connection. The
// result of this post_op is: Output = scale * Output + Conv_Out.
if (fuse_eltwise) {
post_operations.append_sum(1.0f);
}
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
if (fuse_relu) {
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 0.0f;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
}
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn,
const std::vector<T> output_shift_scale, const T sum_scale) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto conv_desc = mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale, sum_scale);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
}
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto conv_desc = mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
}
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn) const{
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto conv_desc = mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_residual_conn);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
}
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& bias, const memory::desc& dst,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto conv_desc = mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
}
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& bias, const memory::desc& dst,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn,
const std::vector<T> output_shift_scale, const T sum_scale) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto conv_desc = mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale, sum_scale);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
}
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& bias, const memory::desc& dst,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn) const{
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto conv_desc = mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_residual_conn);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
}
};
template <typename T>
......
......@@ -128,6 +128,21 @@ void Conv2DOpMaker::Make() {
"The format of output tensor is X (one-dimensional) of size equal"
"to the number of output channels. Only used with MKL-DNN.")
.AsDispensable();
AddInput("Scale_in",
"(Tensor) Scale_in to be used for int8 input data. Only used with INT8.")
.AsDispensable();
AddInput("Scale_in_eltwise",
"(Tensor) Scale_in_eltwise to be used for int8 eltwise input data."
"Only used with MKL-DNN.")
.AsDispensable();
AddInput("Scale_weights",
"(Tensor) Scale_weights to be used for int8 weights data."
"Only used with MKL-DNN.")
.AsDispensable();
AddInput("Scale_out",
"(Tensor) Scale_out to be used for int8 output data."
"Only used with MKL-DNN.")
.AsDispensable();
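// The four Scale_* inputs above are optional calibration tensors; when
// Scale_in is present the MKL-DNN conv kernel switches to the INT8 path and
// uses them to compute output_shift_scale and sum_scale.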
AddOutput("Output",
"(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW.")
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/operators/dequantize_op.h"
#include "paddle/fluid/framework/data_layout_transform.h"
namespace paddle {
namespace operators {
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using platform::to_void_cast;
using Tensor = framework::Tensor;
using framework::DataLayout;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
//using MKLDNNDataType = mkldnn::memory::data_type;
template <typename DeviceContext, typename T>
class DeQuantOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto* scale = ctx.Input<Tensor>("Scale");
auto* output = ctx.Output<Tensor>("Output");
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& engine = dev_ctx.GetEngine();
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
//T scale_data = *(scale->data<T>());
std::vector<T> scale_data = {*(scale->data<T>())};
std::vector<primitive> pipeline;
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type());
mkldnn::memory::format src_fmt = memory::format::nhwc;//input->format();
mkldnn::primitive_attr attri;
int mask = 0;
attri.set_output_scales(mask, scale_data);
auto src_md = platform::MKLDNNMemDesc(
{src_tz}, src_dt, src_fmt);
auto src_pd = mkldnn::memory::primitive_desc{src_md, engine};
auto src_memory = std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data));
std::shared_ptr<primitive::at> src_memory_p = std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
auto dst_md = platform::MKLDNNMemDesc(
{dst_tz}, memory::data_type::f32, memory::format::nchw);
auto dst_pd = mkldnn::memory::primitive_desc{dst_md, engine};
auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<T>(output_data));
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(dst_pd, src_pd, attri));
auto reorder_p= std::shared_ptr<reorder>(new reorder(*reorder_pd, *src_memory_p, dst_memory));
pipeline.push_back(*reorder_p);
// Submit the reorder; without executing the pipeline the dequantized output
// would never be written.
stream(stream::kind::eager).submit(pipeline).wait();
}
};
framework::OpKernelType DeQuantOp::GetExpectedKernelType(const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
}
void DeQuantOpMaker::Make() {
AddInput("Input","input");
AddInput("Scale","scale...");
AddOutput("Output","output");
AddComment(R"DOC(
This op will quantize data from INT8 to FP32
)DOC");
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(dequantize, ops::DeQuantOpKernel<paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::OpKernelType;
using framework::Tensor;
class DeQuantOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class DeQuantOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
class DeQuantGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/operators/quantize_op.h"
namespace paddle {
namespace operators {
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using platform::to_void_cast;
using Tensor = framework::Tensor;
using framework::DataLayout;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
template <typename DeviceContext, typename T>
class QuantOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto* scale = ctx.Input<Tensor>("Scale");
auto* output = ctx.Output<Tensor>("Output");
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& engine = dev_ctx.GetEngine();
std::vector<primitive> pipeline;
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
std::vector<T> scale_data = {*(scale->data<T>())};
mkldnn::primitive_attr attri;
int mask = 0;
attri.set_output_scales(mask, scale_data);
auto src_md = platform::MKLDNNMemDesc(
{src_tz}, memory::data_type::f32, input->format());
auto src_pd = mkldnn::memory::primitive_desc{src_md, engine};
auto src_memory = std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data));
std::shared_ptr<primitive::at> src_memory_p = std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
auto dst_md = platform::MKLDNNMemDesc(
{dst_tz}, memory::data_type::u8, memory::format::nhwc);
auto dst_pd = mkldnn::memory::primitive_desc{dst_md, engine};
auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<T>(output_data));
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(dst_pd, src_pd, attri));
auto reorder_p= std::shared_ptr<reorder>(new reorder(*reorder_pd, *src_memory_p, dst_memory));
pipeline.push_back(*reorder_p);
// Submit the reorder; without executing the pipeline the quantized output
// would never be written.
stream(stream::kind::eager).submit(pipeline).wait();
}
};
framework::OpKernelType QuantOp::GetExpectedKernelType(const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
//ctx.device_context());
}
void QuantOpMaker::Make() {
AddInput("Input","input");
AddInput("Scale","scale...");
AddOutput("Output","output");
AddComment(R"DOC(
This op will quantize data from FP32 to INT8
)DOC");
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(quantize, ops::QuantOp, ops::QuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(quantize, ops::QuantOpKernel<paddle::platform::CPUDeviceContext, float>);
//REGISTER_OP_KERNEL(quantization, MKLDNN, paddle::platform::CPUPlace, ops::QuantOpKernel<paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::OpKernelType;
using framework::Tensor;
class QuantOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override{}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class QuantOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
//void Make() {
// AddInput("Input","input");
// AddInput("Scale","scale...");
// AddOutput("Output","output");
//}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/operators/requantize_op.h"
#include "paddle/fluid/framework/data_layout_transform.h"
namespace paddle {
namespace operators {
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using platform::to_void_cast;
using Tensor = framework::Tensor;
using framework::DataLayout;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
template <typename DeviceContext, typename T>
class ReQuantOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto* scale = ctx.Input<Tensor>("Scale");
auto* output = ctx.Output<Tensor>("Output");
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& engine = dev_ctx.GetEngine();
std::vector<primitive> pipeline;
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type());
mkldnn::memory::data_type dst_dt = paddle::framework::ToMKLDNNDataType(output->type());
mkldnn::memory::format src_fmt = memory::format::nhwc;//input->format();
mkldnn::memory::format dst_fmt = memory::format::nhwc;//output->format();
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
//T scale_data = *(scale->data<T>());
std::vector<T> scale_data = {*(scale->data<T>())};
mkldnn::primitive_attr attri;
int mask = 0;
attri.set_output_scales(mask, scale_data);
//attri.set_int_output_round_mode(round_nearest); //FIX ME
auto src_md = platform::MKLDNNMemDesc(
{src_tz}, src_dt, src_fmt); //FIX ME WITH S8
auto src_pd = mkldnn::memory::primitive_desc{src_md, engine};
auto src_memory = std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data));
std::shared_ptr<primitive::at> src_memory_p = std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
auto dst_md = platform::MKLDNNMemDesc(
{dst_tz}, dst_dt, dst_fmt);
auto dst_pd = mkldnn::memory::primitive_desc{dst_md, engine};
auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<T>(output_data));
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(dst_pd, src_pd, attri));
auto reorder_p= std::shared_ptr<reorder>(new reorder(*reorder_pd, *src_memory_p, dst_memory));
pipeline.push_back(*reorder_p);
// Submit the reorder; without executing the pipeline the re-quantized output
// would never be written.
stream(stream::kind::eager).submit(pipeline).wait();
}
};
framework::OpKernelType ReQuantOp::GetExpectedKernelType(const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
}
void ReQuantOpMaker::Make() {
AddInput("Input","input");
AddInput("Scale","scale...");
AddOutput("Output","output");
AddComment(R"DOC(
This op will requantize data from INT8 to INT8
)DOC");
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(requantize, ops::ReQuantOpKernel<paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::OpKernelType;
using framework::Tensor;
class ReQuantOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class ReQuantOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
......@@ -193,7 +193,10 @@ class MKLDNNHandler {
const std::shared_ptr<mkldnn::memory> user_memory_p,
const std::string& suffix,
std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false) {
bool is_persistent = false,
bool is_INT8 = false,
std::vector<float> scale_data = {1.0f},
int mask = 0) {
// create reorder primitive if the input format is not the preferred one
auto local_key = key_ + suffix;
auto key_reorder_p = key_ + suffix + "reorder_p";
......@@ -207,9 +210,17 @@ class MKLDNNHandler {
std::shared_ptr<mkldnn::primitive> reorder_p;
if (mpd != user_mpd) {
target_memory_p = std::make_shared<mkldnn::memory>(mpd);
if (is_INT8) {
// The INT8 reorder applies the output scales while converting the data, so
// build it from a reorder primitive_desc carrying the scale attribute and
// assign it to the outer reorder_p (no shadowing local variable).
mkldnn::primitive_attr attri;
attri.set_output_scales(mask, scale_data);
auto reorder_pd = std::shared_ptr<mkldnn::reorder::primitive_desc>(
new mkldnn::reorder::primitive_desc(mpd, user_mpd, attri));
reorder_p = std::shared_ptr<mkldnn::reorder>(
new mkldnn::reorder(*reorder_pd, *user_memory_p, *target_memory_p));
} else {
reorder_p =
std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
}
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
pipeline.push_back(*reorder_p);
}
......
......@@ -657,7 +657,10 @@ class Operator(object):
def _has_kernel(self, op_type):
return op_type not in self.OP_WITHOUT_KERNEL_SET
def set_input(self, name, value):
self.desc.set_input(name, value)
def to_string(self, throw_on_error):
"""
Get debug string.
......