未验证 提交 aacba6f5 编写于 作者: H hong19860320 提交者: GitHub

[LITE][NPU] Add fusion_elementwise_add_activation bridge, remove the...

[LITE][NPU] Add fusion_elementwise_add_activation bridge, remove the limitation of the dimensions of input tensors in the graph compute kernel, and refine log message (#2395)

test=develop
上级 f51c4891
......@@ -67,7 +67,7 @@ std::string UniqueName(const std::string& prefix) {
return prefix + "_" + std::to_string(counter);
}
ge::DataType PrecisionConverter(PrecisionType itype) {
ge::DataType CvtPrecisionType(PrecisionType itype) {
ge::DataType otype = ge::DT_FLOAT;
switch (itype) {
case PRECISION(kFloat):
......@@ -80,14 +80,14 @@ ge::DataType PrecisionConverter(PrecisionType itype) {
otype = ge::DT_INT32;
break;
default:
LOG(FATAL) << "Can not convert precision type(" << PrecisionToStr(itype)
<< ") from Lite to NPU";
LOG(FATAL) << "[NPU] Can not convert precision type("
<< PrecisionToStr(itype) << ") from Lite to NPU";
break;
}
return otype;
}
ge::Format DataLayoutConverter(DataLayoutType itype) {
ge::Format CvtDataLayoutType(DataLayoutType itype) {
ge::Format otype = ge::FORMAT_NCHW;
switch (itype) {
case DATALAYOUT(kNCHW):
......@@ -95,17 +95,17 @@ ge::Format DataLayoutConverter(DataLayoutType itype) {
break;
// TODO(hong19860320) support more data layout type
default:
LOG(FATAL) << "Can not convert data layout type("
LOG(FATAL) << "[NPU] Can not convert data layout type("
<< DataLayoutToStr(itype) << ") from Lite to NPU";
break;
}
return otype;
}
ge::TensorPtr CvtFromLiteTensor(lite::Tensor* in_tensor,
std::vector<int64_t> out_shape,
PrecisionType in_ptype,
DataLayoutType in_ltype) {
ge::TensorPtr CvtTensor(lite::Tensor* in_tensor,
std::vector<int64_t> out_shape,
PrecisionType in_ptype,
DataLayoutType in_ltype) {
uint8_t* in_data = nullptr;
auto in_size = in_tensor->dims().production();
auto in_shape = in_tensor->dims().Vectorize();
......@@ -123,10 +123,10 @@ ge::TensorPtr CvtFromLiteTensor(lite::Tensor* in_tensor,
in_data = reinterpret_cast<uint8_t*>(in_tensor->mutable_data<int8_t>());
in_bytes = in_size * sizeof(int8_t);
} else {
LOG(FATAL) << "Unknow precision type " << PrecisionToStr(in_ptype);
LOG(FATAL) << "[NPU] Unknow precision type " << PrecisionToStr(in_ptype);
}
ge::DataType out_ptype = PrecisionConverter(in_ptype);
ge::Format out_ltype = DataLayoutConverter(in_ltype);
ge::DataType out_ptype = CvtPrecisionType(in_ptype);
ge::Format out_ltype = CvtDataLayoutType(in_ltype);
ge::TensorDesc out_desc(ge::Shape(out_shape), out_ltype, out_ptype);
CHECK_EQ(out_ltype, ge::FORMAT_NCHW);
......@@ -140,6 +140,31 @@ ge::TensorPtr CvtFromLiteTensor(lite::Tensor* in_tensor,
return out_tensor;
}
int CvtActMode(std::string act_type) {
int act_mode = 1;
if (act_type == "sigmod") {
act_mode = 0;
} else if (act_type == "relu") {
act_mode = 1;
} else if (act_type == "tanh") {
act_mode = 2;
} else if (act_type == "elu") {
act_mode = 4;
} else if (act_type == "abs") {
act_mode = 6;
} else if (act_type == "softsign") {
act_mode = 8;
} else if (act_type == "softplus") {
act_mode = 9;
} else if (act_type == "hardsigmoid") {
act_mode = 10;
} else {
// TODO(hong19860320) support more activation mode
LOG(FATAL) << "[NPU] Unsupported activation type " << act_type;
}
return act_mode;
}
bool HasInputArg(const OpInfo* op_info,
const Scope* scope,
const std::string& argname) {
......
......@@ -192,14 +192,14 @@ bool BuildModel(std::vector<ge::Operator>& inputs, // NOLINT
std::string UniqueName(const std::string& prefix);
ge::DataType PrecisionConverter(PrecisionType itype);
ge::DataType CvtPrecisionType(PrecisionType itype);
ge::Format DataLayoutConverter(DataLayoutType itype);
ge::Format CvtDataLayoutType(DataLayoutType itype);
ge::TensorPtr CvtFromLiteTensor(Tensor* in_tensor,
std::vector<int64_t> out_shape = {},
PrecisionType in_ptype = PRECISION(kFloat),
DataLayoutType in_ltype = DATALAYOUT(kNCHW));
ge::TensorPtr CvtTensor(Tensor* in_tensor,
std::vector<int64_t> out_shape = {},
PrecisionType in_ptype = PRECISION(kFloat),
DataLayoutType in_ltype = DATALAYOUT(kNCHW));
template <typename T>
ge::TensorPtr CreateTensorAndFillData(std::vector<T> data,
......@@ -214,7 +214,7 @@ ge::TensorPtr CreateTensorAndFillData(std::vector<T> data,
} else if (info == typeid(int32_t)) {
type = ge::DT_INT32;
} else {
LOG(FATAL) << "Unknow value type " << info.name();
LOG(FATAL) << "[NPU] Unknow value type " << info.name();
}
if (shape.empty()) {
shape = {static_cast<int64_t>(data.size())};
......@@ -245,6 +245,8 @@ ge::TensorPtr CreateTensorAndFillData(T value,
return CreateTensorAndFillData(data, shape, format);
}
int CvtActMode(std::string act_type);
bool HasInputArg(const OpInfo* op_info,
const Scope* scope,
const std::string& argname);
......
......@@ -35,7 +35,7 @@ std::shared_ptr<ge::Operator> GenerateNPUProgramPass::CvtVarNode(
lite::mir::Node* var_node, const Scope* scope) {
CHECK(var_node->IsArg());
const auto& arg = var_node->AsArg();
VLOG(4) << "Convert var node " << arg.name;
VLOG(4) << "[NPU] Convert var node " << arg.name;
auto* var = scope->FindVar(arg.name);
CHECK(var);
......@@ -44,13 +44,13 @@ std::shared_ptr<ge::Operator> GenerateNPUProgramPass::CvtVarNode(
auto dims = tensor->dims();
if (arg.is_weight) {
auto wgt = std::make_shared<ge::op::Const>(arg.name);
LOG(INFO) << "in convert const:" << arg.name;
LOG(INFO) << "[NPU] Convert const var node " << arg.name;
VLOG(4) << dims;
wgt->set_attr_value(lite::npu::CvtFromLiteTensor(tensor));
wgt->set_attr_value(lite::npu::CvtTensor(tensor));
return wgt;
} else {
CHECK_EQ(dims.size(), 4);
LOG(INFO) << "in convert data:" << arg.name;
LOG(INFO) << "[NPU] Convert data var node " << arg.name;
LOG(INFO) << dims;
// TODO(xxx): support more types and dims size
ge::TensorDesc desc(ge::Shape(dims.Vectorize()),
......@@ -128,10 +128,10 @@ std::string GenerateNPUProgramPass::BuildNPUGraph(
// persistable=true, Sothat the model parser can recognize it and save it to
// param files
if (!lite::npu::BuildModel(inputs, outputs, weight)) {
LOG(WARNING) << "Build NPU failed subgraph " << sub_id;
throw std::runtime_error("Build NPU failed subgraph.");
LOG(WARNING) << "[NPU] Build NPU graph failed (subgraph=" << sub_id << ")";
throw std::runtime_error("Build NPU graph failed.");
}
LOG(INFO) << "[NPU] Build NPU Client success subgraph " << sub_id;
LOG(INFO) << "[NPU] Build NPU graph success (subgraph=" << sub_id << ")";
return weight_var_name;
}
......@@ -166,12 +166,12 @@ void GenerateNPUProgramPass::GenNPUSubgraph(
}
void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
LOG(INFO) << "Before NPU Pass \n" << Visualize(graph.get());
LOG(INFO) << "[NPU] Before NPU Pass \n" << Visualize(graph.get());
const auto& bridges = lite::kernels::npu::bridges::Factory::Instance();
const auto& op_map = bridges.AllFunctions();
std::vector<std::string> supported_op_types;
for (auto& i : op_map) {
LOG(INFO) << "Supported type: " << i.first;
LOG(INFO) << "[NPU] Supported type: " << i.first;
supported_op_types.push_back(i.first);
}
......@@ -182,15 +182,15 @@ void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
CHECK_EQ(op_nodes_all.size(), num_subgraph);
int id = 1;
for (auto& op_nodes : op_nodes_all) {
LOG(INFO) << "Converting subgraph_id:" << id;
LOG(INFO) << "[NPU] Converting Subgraph " << id;
GenNPUSubgraph(graph, op_nodes.second, id);
LOG(INFO) << "After NPU Pass Subgraph " << id << "\n"
LOG(INFO) << "[NPU] After NPU Pass Subgraph " << id << "\n"
<< Visualize(graph.get());
id++;
}
} catch (...) {
LOG(WARNING) << "Build NPU graph failed";
throw std::runtime_error("Build NPU graph failed");
LOG(WARNING) << "[NPU] Build NPU graph failed.";
throw std::runtime_error("[NPU] Build NPU graph failed.");
}
for (auto& item : graph->StmtTopologicalOrder()) {
......@@ -203,7 +203,7 @@ void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
}
std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() {
LOG(INFO) << "insts.size " << insts_.size();
LOG(INFO) << "[NPU] program insts.size " << insts_.size();
std::unique_ptr<RuntimeProgram> program(
new RuntimeProgram(std::move(insts_)));
return program;
......
......@@ -27,7 +27,7 @@ node_map_type ActConverter(const std::shared_ptr<lite::OpLite> act_op,
auto op_info = act_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
LOG(INFO) << "[NPU] Converting " + op_type + "...";
// create act node and set input node from inputs_map
auto x_var_name = op_info->Input("X").front();
......@@ -37,30 +37,9 @@ node_map_type ActConverter(const std::shared_ptr<lite::OpLite> act_op,
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(act_node);
// parse and set activation type
int act_mode = 1;
if (op_type == "sigmod") {
act_mode = 0;
} else if (op_type == "relu") {
act_mode = 1;
} else if (op_type == "tanh") {
act_mode = 2;
} else if (op_type == "elu") {
act_mode = 4;
} else if (op_type == "abs") {
act_mode = 6;
} else if (op_type == "softsign") {
act_mode = 8;
} else if (op_type == "softplus") {
act_mode = 9;
} else if (op_type == "hardsigmoid") {
act_mode = 10;
} else {
// TODO(hong19860320) add more activation mode, and set the coef value
// clipped ReLU, LEAKY_RELU, relu1, threshold, selu and linear
LOG(FATAL) << "Unsupported activation type " << op_type;
}
act_node->set_attr_mode(act_mode);
// TODO(hong19860320) set the coef value for act Ops, such as leaky_relu,
// clipped_relu etc.
act_node->set_attr_mode(lite::npu::CvtActMode(op_type));
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = act_node;
......
......@@ -28,7 +28,7 @@ node_map_type BatchNormConverter(
auto op_info = batch_norm_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
LOG(INFO) << "[NPU] Converting " + op_type + "...";
std::shared_ptr<ge::op::BatchNorm> batch_norm_node =
std::make_shared<ge::op::BatchNorm>(unique_op_type);
......@@ -37,26 +37,26 @@ node_map_type BatchNormConverter(
auto scale_var_name = op_info->Input("Scale").front();
lite::Tensor* scale = scope->FindVar(scale_var_name)->GetMutable<Tensor>();
auto npu_scale = std::make_shared<ge::op::Const>(scale_var_name);
npu_scale->set_attr_value(lite::npu::CvtFromLiteTensor(scale));
npu_scale->set_attr_value(lite::npu::CvtTensor(scale));
lite::npu::OpList::Global().add(npu_scale);
auto bias_var_name = op_info->Input("Bias").front();
lite::Tensor* bias = scope->FindVar(bias_var_name)->GetMutable<Tensor>();
auto npu_bias = std::make_shared<ge::op::Const>(bias_var_name);
npu_bias->set_attr_value(lite::npu::CvtFromLiteTensor(bias));
npu_bias->set_attr_value(lite::npu::CvtTensor(bias));
lite::npu::OpList::Global().add(npu_bias);
auto mean_var_name = op_info->Input("Mean").front();
lite::Tensor* mean = scope->FindVar(mean_var_name)->GetMutable<Tensor>();
auto npu_mean = std::make_shared<ge::op::Const>(mean_var_name);
npu_mean->set_attr_value(lite::npu::CvtFromLiteTensor(mean));
npu_mean->set_attr_value(lite::npu::CvtTensor(mean));
lite::npu::OpList::Global().add(npu_mean);
auto variance_var_name = op_info->Input("Variance").front();
lite::Tensor* variance =
scope->FindVar(variance_var_name)->GetMutable<Tensor>();
auto npu_variance = std::make_shared<ge::op::Const>(variance_var_name);
npu_variance->set_attr_value(lite::npu::CvtFromLiteTensor(variance));
npu_variance->set_attr_value(lite::npu::CvtTensor(variance));
lite::npu::OpList::Global().add(npu_variance);
float npu_momentum = op_info->GetAttr<float>("momentum");
......
......@@ -27,7 +27,7 @@ node_map_type ConcatConverter(const std::shared_ptr<lite::OpLite> concat_op,
const lite::OpInfo* op_info = concat_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "converting " << op_type << " ... ";
LOG(INFO) << "[NPU] Converting " << op_type << " ... ";
auto x_var_names = op_info->Input("X");
auto axis = op_info->GetAttr<int>("axis");
......@@ -46,7 +46,7 @@ node_map_type ConcatConverter(const std::shared_ptr<lite::OpLite> concat_op,
} else {
auto consty = std::make_shared<ge::op::Const>(x_var_name);
auto* x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
consty->set_attr_value(lite::npu::CvtFromLiteTensor(x));
consty->set_attr_value(lite::npu::CvtTensor(x));
output_node->set_dynamic_input_x(index + 1, *consty);
lite::npu::OpList::Global().add(consty);
}
......
......@@ -27,7 +27,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
auto op_info = conv_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " << op_type << "... ";
LOG(INFO) << "[NPU] Converting " << op_type << "... ";
// get input, filter and op attributes
auto input_var_name = op_info->Input("Input").front();
......@@ -64,10 +64,10 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
!((groups == 1 || groups >= 5) && dilations[0] == 1 &&
dilations[1] == 1)) {
use_depthwise_conv = true;
LOG(WARNING) << "For depthwise mode, dilation = 1 and groups >= 5 (or "
"groups = 1) is only supported in "
"Convolution Op, so force to use ConvolutionDepthwise Op, "
"but may lead poor performance.";
LOG(WARNING) << "[NPU] For depthwise mode, dilation = 1 and groups >= 5 "
"(or groups = 1) is only supported in Convolution Op, so "
"force to use ConvolutionDepthwise Op, but may lead poor "
"performance.";
}
// check input
......@@ -77,7 +77,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
// create filter node
CHECK(!inputs_map.count(filter_var_name));
auto filter_const_node = std::make_shared<ge::op::Const>(filter_var_name);
filter_const_node->set_attr_value(lite::npu::CvtFromLiteTensor(filter));
filter_const_node->set_attr_value(lite::npu::CvtTensor(filter));
lite::npu::OpList::Global().add(filter_const_node);
// create bias node if has bias
......@@ -115,8 +115,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
} else {
// bias node with const data
auto bias_const_node = std::make_shared<ge::op::Const>(bias_var_name);
bias_const_node->set_attr_value(
lite::npu::CvtFromLiteTensor(bias, bias_shape));
bias_const_node->set_attr_value(lite::npu::CvtTensor(bias, bias_shape));
bias_node = bias_const_node;
}
lite::npu::OpList::Global().add(bias_node);
......@@ -193,7 +192,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
auto relu_node =
std::make_shared<ge::op::Activation>(unique_op_type + "/relu");
relu_node->set_input_x(*conv_node);
relu_node->set_attr_mode(1);
relu_node->set_attr_mode(lite::npu::CvtActMode("relu"));
lite::npu::OpList::Global().add(relu_node);
outputs_map[op_info->Output("Output").front()] = relu_node;
} else {
......
......@@ -28,7 +28,7 @@ node_map_type ConvTransposeConverter(
auto op_info = conv_transpose_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " << op_type << "... ";
LOG(INFO) << "[NPU] Converting " << op_type << "... ";
// get input, output and op attributes
auto input_var_name = op_info->Input("Input").front();
......@@ -72,7 +72,7 @@ node_map_type ConvTransposeConverter(
// create filter node
CHECK(!inputs_map.count(filter_var_name));
auto filter_const_node = std::make_shared<ge::op::Const>(filter_var_name);
filter_const_node->set_attr_value(lite::npu::CvtFromLiteTensor(filter));
filter_const_node->set_attr_value(lite::npu::CvtTensor(filter));
conv_transpose_node->set_input_filter(*filter_const_node);
lite::npu::OpList::Global().add(filter_const_node);
......@@ -107,7 +107,7 @@ node_map_type ConvTransposeConverter(
CHECK_EQ(channel_size, filter_shape[1] * groups);
auto bias_const_node = std::make_shared<ge::op::Const>(bias_var_name);
bias_const_node->set_attr_value(
lite::npu::CvtFromLiteTensor(bias, {1, channel_size, 1, 1}));
lite::npu::CvtTensor(bias, {1, channel_size, 1, 1}));
lite::npu::OpList::Global().add(bias_const_node);
// append add node to add bias node
auto add_node = std::make_shared<ge::op::Add>(unique_op_type + "/add");
......@@ -123,7 +123,7 @@ node_map_type ConvTransposeConverter(
auto relu_node =
std::make_shared<ge::op::Activation>(unique_op_type + "/relu");
relu_node->set_input_x(*output_node);
relu_node->set_attr_mode(1);
relu_node->set_attr_mode(lite::npu::CvtActMode("relu"));
lite::npu::OpList::Global().add(relu_node);
outputs_map[op_info->Output("Output").front()] = relu_node;
} else {
......
......@@ -28,7 +28,7 @@ node_map_type ElementwiseConverter(
auto op_info = elementwise_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "converting elementwise...";
LOG(INFO) << "[NPU] Converting " + op_type + "...";
std::shared_ptr<ge::op::Eltwise> elementwise_node =
std::make_shared<ge::op::Eltwise>(unique_op_type);
......@@ -37,7 +37,7 @@ node_map_type ElementwiseConverter(
auto y_var_name = op_info->Input("Y").front();
CHECK_EQ(op_info->GetAttr<int>("axis"), -1)
<< "npu elementwise only support inputs with same size";
<< "[NPU] elementwise only support inputs with same size";
CHECK(inputs_map.find(x_var_name) != inputs_map.end());
elementwise_node->set_input_x1(*inputs_map.at(x_var_name));
......@@ -47,11 +47,11 @@ node_map_type ElementwiseConverter(
elementwise_node->set_input_x2(*inputs_map.at(y_var_name));
lite::npu::OpList::Global().add(inputs_map.at(y_var_name));
} else {
auto consty = std::make_shared<ge::op::Const>(y_var_name);
auto y_const_node = std::make_shared<ge::op::Const>(y_var_name);
auto* y = scope->FindVar(y_var_name)->GetMutable<Tensor>();
consty->set_attr_value(lite::npu::CvtFromLiteTensor(y));
elementwise_node->set_input_x2(*consty);
lite::npu::OpList::Global().add(consty);
y_const_node->set_attr_value(lite::npu::CvtTensor(y));
elementwise_node->set_input_x2(*y_const_node);
lite::npu::OpList::Global().add(y_const_node);
}
lite::npu::OpList::Global().add(elementwise_node);
......@@ -60,7 +60,19 @@ node_map_type ElementwiseConverter(
elementwise_node->set_attr_mode(1);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = elementwise_node;
if (op_type == "fusion_elementwise_add_activation") {
auto act_type = op_info->GetAttr<std::string>("act_type");
auto act_node =
std::make_shared<ge::op::Activation>(unique_op_type + "/act");
act_node->set_input_x(*elementwise_node);
// TODO(hong19860320) set the coef value for act Ops, such as leaky_relu,
// clipped_relu etc.
act_node->set_attr_mode(lite::npu::CvtActMode(act_type));
lite::npu::OpList::Global().add(act_node);
outputs_map[op_info->Output("Out").front()] = act_node;
} else {
outputs_map[op_info->Output("Out").front()] = elementwise_node;
}
return outputs_map;
}
......@@ -72,3 +84,5 @@ node_map_type ElementwiseConverter(
REGISTER_NPU_BRIDGE(elementwise_add,
paddle::lite::kernels::npu::bridges::ElementwiseConverter);
REGISTER_NPU_BRIDGE(fusion_elementwise_add_activation,
paddle::lite::kernels::npu::bridges::ElementwiseConverter);
......@@ -27,7 +27,7 @@ node_map_type FCConverter(const std::shared_ptr<lite::OpLite> fc_op,
auto op_info = fc_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
LOG(INFO) << "[NPU] Converting " + op_type + "...";
auto fc_node = std::make_shared<ge::op::FullConnection>(unique_op_type);
......@@ -47,7 +47,7 @@ node_map_type FCConverter(const std::shared_ptr<lite::OpLite> fc_op,
int k = x_dims.Slice(in_num_col_dims, x_dims.size()).production();
int n = w_dims[1];
CHECK_EQ(k * n, w_dims.production());
VLOG(3) << "x dims: " << x_dims << " w dims: " << w_dims << " m: " << m
VLOG(3) << "[NPU] x dims: " << x_dims << " w dims: " << w_dims << " m: " << m
<< " k: " << k << " n: " << n;
CHECK(inputs_map.count(x_var_name));
......@@ -92,8 +92,7 @@ node_map_type FCConverter(const std::shared_ptr<lite::OpLite> fc_op,
CHECK_EQ(bias_dims.production(), n);
auto bias_const_node = std::make_shared<ge::op::Const>(bias_var_name);
bias_const_node->set_attr_value(
lite::npu::CvtFromLiteTensor(bias, {1, n, 1, 1}));
bias_const_node->set_attr_value(lite::npu::CvtTensor(bias, {1, n, 1, 1}));
fc_node->set_input_b(*bias_const_node);
lite::npu::OpList::Global().add(bias_const_node);
}
......
......@@ -28,7 +28,7 @@ node_map_type InterpolateConverter(
auto op_info = interpolate_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
LOG(INFO) << "[NPU] Converting " + op_type + "...";
// get input, output and attributes from lite op
auto x_var_name = op_info->Input("X").front();
......@@ -45,8 +45,9 @@ node_map_type InterpolateConverter(
auto out_h = op_info->GetAttr<int>("out_h");
auto align_corners = op_info->GetAttr<bool>("align_corners");
int align_mode = op_info->GetAttr<int>("align_mode");
CHECK(!(align_mode == 0 && !align_corners))
<< "align_mode = 0 && align_corners = false isn't supported in NPU DDK";
CHECK(!(align_mode == 0 && !align_corners)) << "[NPU] align_mode = 0 && "
"align_corners = false isn't "
"supported in HiAI DDK";
// priority: OutSize > scale > out_h/out_w
if (scale > 0) {
......@@ -87,9 +88,9 @@ node_map_type InterpolateConverter(
const float largest_multiple = 7.0f;
float multiple = static_cast<float>(x_h * x_w) / (out_h * out_w);
CHECK_LT(multiple, largest_multiple)
<< "multiple=(ih*iw)/(oh*ow)=" << multiple
<< "[NPU] multiple=(ih*iw)/(oh*ow)=" << multiple
<< " is too large, should not exceed " << largest_multiple
<< " in NPU DDK";
<< " in HiAI DDK";
auto w_const_node =
std::make_shared<ge::op::Const>(unique_op_type + "/w");
w_const_node->set_attr_value(
......@@ -121,7 +122,7 @@ node_map_type InterpolateConverter(
interp_node->set_attr_align_corners(align_corners);
outputs_map[op_info->Output("Out").front()] = interp_node;
} else {
LOG(FATAL) << "unsupported interpolate method: " << interp_method;
LOG(FATAL) << "[NPU] Unsupported interpolate method: " << interp_method;
}
return outputs_map;
......
......@@ -25,11 +25,13 @@ namespace bridges {
// handle in this converter
node_map_type MulConverter(const std::shared_ptr<lite::OpLite> mul_op,
const node_map_type& inputs_map) {
LOG(INFO) << "converting mul...";
lite::Scope* scope = mul_op->scope();
const lite::OpInfo* op_info = mul_op->op_info();
auto output_node =
std::make_shared<ge::op::MatMul>(lite::npu::UniqueName("mul"));
auto scope = mul_op->scope();
auto op_info = mul_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "[NPU] Converting " + op_type + "...";
auto output_node = std::make_shared<ge::op::MatMul>(unique_op_type);
auto x_var_name = op_info->Input("X").front();
auto y_var_name = op_info->Input("Y").front();
......@@ -46,7 +48,7 @@ node_map_type MulConverter(const std::shared_ptr<lite::OpLite> mul_op,
int n = ytensor->dims()
.Slice(y_num_col_dims, ytensor->dims().size())
.production();
CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h";
CHECK_EQ(x_w, y_h) << "[NPU] x_w must be equal with y_h";
int k = x_w;
LOG(INFO) << "m:" << m << ",n:" << n << ",k:" << k;
LOG(INFO) << "x_var_name:" << x_var_name
......
......@@ -27,7 +27,7 @@ node_map_type Pad2dConverter(const std::shared_ptr<lite::OpLite> pad2d_op,
auto op_info = pad2d_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
LOG(INFO) << "[NPU] Converting " + op_type + "...";
std::shared_ptr<ge::op::Pad> pad2d_node =
std::make_shared<ge::op::Pad>(unique_op_type);
......@@ -40,10 +40,10 @@ node_map_type Pad2dConverter(const std::shared_ptr<lite::OpLite> pad2d_op,
if (mode == "constant") {
pad2d_node->set_attr_mode(0);
} else if (mode == "reflect") {
LOG(FATAL) << "NPU doesn't support this pad mod: " << mode;
LOG(FATAL) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
pad2d_node->set_attr_mode(1);
} else {
LOG(FATAL) << "NPU doesn't support this pad mod: " << mode;
LOG(FATAL) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
}
auto x_dims = scope->FindTensor(x_var_name)->dims();
......
......@@ -23,6 +23,7 @@ USE_NPU_BRIDGE(depthwise_conv2d);
USE_NPU_BRIDGE(pool2d);
USE_NPU_BRIDGE(relu);
USE_NPU_BRIDGE(elementwise_add);
USE_NPU_BRIDGE(fusion_elementwise_add_activation);
USE_NPU_BRIDGE(scale);
USE_NPU_BRIDGE(softmax);
USE_NPU_BRIDGE(concat);
......
......@@ -27,7 +27,7 @@ node_map_type PoolConverter(const std::shared_ptr<lite::OpLite> pool_op,
auto op_info = pool_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
LOG(INFO) << "[NPU] Converting " + op_type + "...";
std::shared_ptr<ge::op::Pooling> pool_node =
std::make_shared<ge::op::Pooling>(unique_op_type);
......@@ -39,9 +39,9 @@ node_map_type PoolConverter(const std::shared_ptr<lite::OpLite> pool_op,
} else if (pooling_type == "avg") {
npu_mode = 1;
CHECK(op_info->GetAttr<bool>("exclusive"))
<< "exclusive must be true when use npu";
<< "[NPU] exclusive must be true in HiAI DDK";
} else {
LOG(FATAL) << "Unsupported pooling type: " << pooling_type;
LOG(FATAL) << "[NPU] Unsupported pooling type: " << pooling_type;
}
bool npu_global_pooling = op_info->GetAttr<bool>("global_pooling");
auto ksize = op_info->GetAttr<std::vector<int>>("ksize");
......
......@@ -28,7 +28,7 @@ node_map_type ReshapeConverter(const std::shared_ptr<lite::OpLite> reshape_op,
auto op_info = reshape_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
LOG(INFO) << "[NPU] Converting " + op_type + "...";
// get input, output and op attributes
auto x_var_name = op_info->Input("X").front();
......@@ -55,9 +55,9 @@ node_map_type ReshapeConverter(const std::shared_ptr<lite::OpLite> reshape_op,
auto out_dims = operators::ValidateShape(shape, x_dims);
auto out_shape = out_dims.Vectorize();
if (out_shape.size() > 4) {
LOG(WARNING)
<< "NPU DDK only supports less than 4 dimensions, but Shape has "
<< out_shape.size();
LOG(WARNING) << "[NPU] HiAI DDK only supports less than 4 dimensions, "
"but Shape has "
<< out_shape.size();
}
auto actual_shape_const_node =
std::make_shared<ge::op::Const>(actual_shape_var_name);
......@@ -75,9 +75,9 @@ node_map_type ReshapeConverter(const std::shared_ptr<lite::OpLite> reshape_op,
auto out_dims = operators::ValidateShape(shape, x_dims);
auto out_shape = out_dims.Vectorize();
if (out_shape.size() > 4) {
LOG(WARNING)
<< "NPU DDK only supports less than 4 dimensions, but shape has "
<< out_shape.size();
LOG(WARNING) << "[NPU] HiAI DDK only supports less than 4 dimensions, "
"but shape has "
<< out_shape.size();
}
reshape_node->set_attr_shape(
ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
......@@ -93,9 +93,9 @@ node_map_type ReshapeConverter(const std::shared_ptr<lite::OpLite> reshape_op,
xshape_dims[i + 1] = x_dims[i];
}
if (xshape_dims.size() > 4) {
LOG(WARNING)
<< "NPU DDK only supports less than 4 dimensions, but XShape has "
<< xshape_dims.size();
LOG(WARNING) << "[NPU] HiAI DDK only supports less than 4 dimensions, "
"but XShape has "
<< xshape_dims.size();
}
auto xshape_node =
std::make_shared<ge::op::Reshape>(unique_op_type + "/xshape");
......
......@@ -27,7 +27,7 @@ node_map_type ScaleConverter(const std::shared_ptr<lite::OpLite> scale_op,
auto op_info = scale_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
LOG(INFO) << "[NPU] Converting " + op_type + "...";
// get input, output and op attributes
auto x_var_name = op_info->Input("X").front();
......
......@@ -28,7 +28,7 @@ node_map_type ShuffleChannelConverter(
auto op_info = shuffle_channel_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
LOG(INFO) << "[NPU] Converting " + op_type + "...";
std::shared_ptr<ge::op::ShuffleChannel> shuffle_channel_node =
std::make_shared<ge::op::ShuffleChannel>(unique_op_type);
......
......@@ -27,7 +27,7 @@ node_map_type SoftmaxConverter(const std::shared_ptr<lite::OpLite> softmax_op,
auto op_info = softmax_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
LOG(INFO) << "[NPU] Converting " + op_type + "...";
std::shared_ptr<ge::op::Softmax> softmax_node =
std::make_shared<ge::op::Softmax>(unique_op_type);
......@@ -37,7 +37,7 @@ node_map_type SoftmaxConverter(const std::shared_ptr<lite::OpLite> softmax_op,
auto axis = op_info->GetAttr<int>("axis");
if (x_dims.size() > 3) {
CHECK(!(axis == 2 && x_dims[3] > 1))
<< "unsupported npu softmax params: axis = " << axis
<< "[NPU] Unsupported softmax params: axis = " << axis
<< " :x_w = " << x_dims[3];
}
......
......@@ -73,7 +73,7 @@ void softmax_ref(const std::shared_ptr<operators::SoftmaxOp> op) {
}
}
void test_softmax(int bs, int ic, int ih, int iw, int axis) {
void test_softmax(const std::vector<int64_t>& input_shape, int axis) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
......@@ -82,7 +82,7 @@ void test_softmax(int bs, int ic, int ih, int iw, int axis) {
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
x->Resize(input_shape);
// initialize input&output data
FillTensor<float>(x);
......@@ -111,19 +111,36 @@ void test_softmax(int bs, int ic, int ih, int iw, int axis) {
}
TEST(NPUBridges, softmax) {
for (auto bs : {1, 4, 7}) {
for (auto ic : {1, 4, 7}) {
for (auto ih : {1, 4, 7}) {
for (auto iw : {1, 4, 7}) {
for (auto axis : {-3, -1, 0, 1, 2, 3}) {
// npu softmax exists bugs when axis is 2 and iw > 1
if (axis == 2 && iw > 1) continue;
test_softmax(bs, ic, ih, iw, axis);
}
}
}
}
}
test_softmax({1, 4}, -1);
// Bug exists in HiAI DDK when the number of items > 16500
// test_softmax({1, 16500}, -1);
test_softmax({1, 4}, 0);
test_softmax({1, 4}, 1);
test_softmax({3, 4}, -1);
test_softmax({3, 4}, 0);
test_softmax({3, 4}, 1);
test_softmax({1, 4, 7}, -1);
test_softmax({1, 4, 7}, 0);
// Bug exists in HiAI DDK when axis is 1 and iw > 1
// test_softmax({1, 4, 7}, 1);
test_softmax({1, 4, 1}, 1);
test_softmax({1, 4, 7}, 2);
test_softmax({3, 4, 7}, -1);
test_softmax({3, 4, 7}, 0);
test_softmax({3, 4, 1}, 1);
test_softmax({3, 4, 7}, 2);
test_softmax({1, 4, 7, 9}, -1);
test_softmax({1, 4, 7, 9}, 0);
test_softmax({1, 4, 7, 9}, 1);
// Bug exists in HiAI DDK when axis is 2 and iw > 1
// test_softmax({1, 4, 7, 9}, 2);
test_softmax({1, 4, 7, 1}, 2);
test_softmax({1, 4, 7, 9}, 3);
test_softmax({3, 4, 7, 9}, -1);
test_softmax({3, 4, 7, 9}, 0);
test_softmax({3, 4, 7, 9}, 1);
test_softmax({3, 4, 7, 1}, 2);
test_softmax({3, 4, 7, 9}, 3);
}
} // namespace bridges
......
......@@ -27,7 +27,7 @@ node_map_type SplitConverter(const std::shared_ptr<lite::OpLite> split_op,
const lite::OpInfo* op_info = split_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " << op_type << " ... ";
LOG(INFO) << "[NPU] Converting " << op_type << " ... ";
auto x_var_name = op_info->Input("X").front();
auto axis = op_info->GetAttr<int>("axis");
......
......@@ -28,7 +28,7 @@ node_map_type TransposeConverter(
auto op_info = transpose_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
LOG(INFO) << "[NPU] Converting " + op_type + "...";
std::shared_ptr<ge::op::Permute> transpose_node =
std::make_shared<ge::op::Permute>(unique_op_type);
......@@ -44,7 +44,7 @@ node_map_type TransposeConverter(
w_data[i] = 1.f;
}
auto npu_w = std::make_shared<ge::op::Const>(w_var_name);
npu_w->set_attr_value(lite::npu::CvtFromLiteTensor(w));
npu_w->set_attr_value(lite::npu::CvtTensor(w));
lite::npu::OpList::Global().add(npu_w);
auto axis = op_info->GetAttr<std::vector<int>>("axis");
......
......@@ -15,11 +15,6 @@
#include "lite/kernels/npu/graph_compute.h"
#include <sys/time.h>
#include <time.h>
#include <string>
#include <vector>
#include "ai_ddk_lib/include/HiAiModelManagerService.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
......@@ -30,6 +25,8 @@ void GraphCompute::PrepareForRun() {
auto& ctx = this->ctx_->template As<NPUContext>();
auto& param = this->Param<param_t>();
// Load HiAI model from the weight tensor and release its buffer
// to save memory
CHECK(param.weight);
CHECK(lite::npu::LoadModel(*param.weight, &model_client_, &model_name_));
// TODO(hong19860320): find an good way to free the model data.
......@@ -38,84 +35,75 @@ void GraphCompute::PrepareForRun() {
param.weight->Resize({1});
param.weight->mutable_data<int8_t>(TargetType::kARM);
CHECK(model_client_);
int ret =
model_client_->GetModelIOTensorDim(model_name_, npu_idims_, npu_odims_);
CHECK_EQ(ret, hiai::AI_SUCCESS) << "[NPU] Get dims failed.";
npu_itensors_.resize(npu_idims_.size());
npu_otensors_.resize(npu_odims_.size());
for (size_t i = 0; i < npu_idims_.size(); ++i) {
VLOG(3) << "npu_idims[" << i << "]: " << npu_idims_[i].GetNumber() << ","
<< npu_idims_[i].GetChannel() << "," << npu_idims_[i].GetHeight()
<< "," << npu_idims_[i].GetWidth();
VLOG(3) << "lite_idims[" << i << "]: " << param.inputs[i].second->dims();
CHECK_EQ(param.inputs[i].second->dims().production(),
npu_idims_[i].GetNumber() * npu_idims_[i].GetChannel() *
npu_idims_[i].GetHeight() * npu_idims_[i].GetWidth());
// Query the dimensions of NPU input and output tensors from HiAI model
std::vector<hiai::TensorDimension> npu_idims;
std::vector<hiai::TensorDimension> npu_odims;
int ret =
model_client_->GetModelIOTensorDim(model_name_, npu_idims, npu_odims);
CHECK_EQ(ret, hiai::AI_SUCCESS)
<< "[NPU] Get the dimensions of input and output tensors failed.";
// Check whether the data sizes of NPU input and output tensors are the
// same as CPU's, then create and initialize NPU input and output tensors.
npu_itensors_.resize(npu_idims.size());
npu_otensors_.resize(npu_odims.size());
npu_idatasizes_.resize(npu_idims.size());
npu_odatasizes_.resize(npu_odims.size());
for (size_t i = 0; i < npu_idims.size(); ++i) {
auto cpu_itensor = param.inputs[i].second;
CHECK(cpu_itensor);
VLOG(3) << "[NPU] CPU input dims[" << i << "]: " << cpu_itensor->dims();
VLOG(3) << "[NPU] NPU input dims[" << i << "]: {"
<< npu_idims[i].GetNumber() << "," << npu_idims[i].GetChannel()
<< "," << npu_idims[i].GetHeight() << "," << npu_idims[i].GetWidth()
<< "}";
npu_idatasizes_[i] = npu_idims[i].GetNumber() * npu_idims[i].GetChannel() *
npu_idims[i].GetHeight() * npu_idims[i].GetWidth();
CHECK_EQ(cpu_itensor->dims().production(), npu_idatasizes_[i]);
npu_itensors_[i].reset(new hiai::AiTensor);
npu_itensors_[i]->Init(&(npu_idims_[i]));
npu_itensors_[i]->Init(&(npu_idims[i]));
}
for (size_t i = 0; i < npu_odims_.size(); ++i) {
VLOG(3) << "npu_odims[" << i << "]: " << npu_odims_[i].GetNumber() << ","
<< npu_odims_[i].GetChannel() << "," << npu_odims_[i].GetHeight()
<< "," << npu_odims_[i].GetWidth();
VLOG(3) << "lite_odims[" << i << "]: " << param.outputs[i].second->dims();
auto out_size = npu_odims_[i].GetNumber() * npu_odims_[i].GetChannel() *
npu_odims_[i].GetHeight() * npu_odims_[i].GetWidth();
if (param.outputs[i].second->dims().production() != out_size) {
param.outputs[i].second->Resize({npu_odims_[i].GetNumber(),
npu_odims_[i].GetChannel(),
npu_odims_[i].GetHeight(),
npu_odims_[i].GetWidth()});
for (size_t i = 0; i < npu_odims.size(); ++i) {
auto cpu_otensor = param.outputs[i].second;
CHECK(cpu_otensor);
VLOG(3) << "[NPU] CPU output dims[" << i << "]: " << cpu_otensor->dims();
VLOG(3) << "[NPU] NPU output dims[" << i << "]: {"
<< npu_odims[i].GetNumber() << "," << npu_odims[i].GetChannel()
<< "," << npu_odims[i].GetHeight() << "," << npu_odims[i].GetWidth()
<< "}";
npu_odatasizes_[i] = npu_odims[i].GetNumber() * npu_odims[i].GetChannel() *
npu_odims[i].GetHeight() * npu_odims[i].GetWidth();
if (cpu_otensor->dims().production() != npu_odatasizes_[i]) {
cpu_otensor->Resize({npu_odims[i].GetNumber(),
npu_odims[i].GetChannel(),
npu_odims[i].GetHeight(),
npu_odims[i].GetWidth()});
}
LOG(INFO) << param.outputs[i].second->dims();
npu_otensors_[i].reset(new hiai::AiTensor);
npu_otensors_[i]->Init(&(npu_odims_[i]));
}
}
bool GraphCompute::input_dims_changed() const {
auto& param = this->Param<param_t>();
CHECK_EQ(param.inputs.size(), npu_idims_.size());
for (size_t i = 0; i < param.inputs.size(); ++i) {
auto param_idims = param.inputs[i].second->dims();
CHECK(!param_idims.empty());
CHECK_EQ(param_idims.size(), 4);
std::vector<int> idims{static_cast<int>(npu_idims_[i].GetNumber()),
static_cast<int>(npu_idims_[i].GetChannel()),
static_cast<int>(npu_idims_[i].GetHeight()),
static_cast<int>(npu_idims_[i].GetWidth())};
for (size_t i = 0; i < 4; ++i) {
if (param_idims[i] != idims[i]) {
return true;
}
}
npu_otensors_[i]->Init(&(npu_odims[i]));
}
return false;
}
void GraphCompute::Run() {
CHECK(!input_dims_changed())
<< "When NPU is enabled, the input shape could not be changed yet.";
auto& param = this->Param<param_t>();
// Check whether the data sizes of NPU input tensors are the same as
// CPU's, and copy the data of CPU input tensors to NPU's.
CHECK_EQ(param.inputs.size(), npu_itensors_.size());
CHECK_EQ(param.outputs.size(), npu_otensors_.size());
for (size_t i = 0; i < param.inputs.size(); ++i) {
auto* itensor = param.inputs[i].second;
CHECK(itensor);
const auto* i_data = itensor->data<float>();
std::memcpy(
npu_itensors_[i]->GetBuffer(),
i_data,
sizeof(float) * static_cast<size_t>(itensor->dims().production()));
auto cpu_itensor = param.inputs[i].second;
CHECK(cpu_itensor);
CHECK_EQ(cpu_itensor->dims().production(), npu_idatasizes_[i]);
std::memcpy(static_cast<float*>(npu_itensors_[i]->GetBuffer()),
cpu_itensor->data<float>(),
sizeof(float) * static_cast<size_t>(npu_idatasizes_[i]));
}
// Run HiAI model with model name
std::string key = "model_name"; // Note: key seems must be model_name
model_context_.AddPara(key, model_name_);
auto GetCurrentUS = []() -> double {
struct timeval time;
gettimeofday(&time, NULL);
......@@ -128,16 +116,15 @@ void GraphCompute::Run() {
model_context_, npu_itensors_, npu_otensors_, 1000, istamp));
VLOG(3) << "[NPU] Process cost " << GetCurrentUS() - start_time << " us";
// Check whether the data sizes of NPU output tensors are the same as
// CPU's, and copy the data of NPU output tensors to CPU's.
for (size_t i = 0; i < param.outputs.size(); ++i) {
auto* otensor = param.outputs[i].second;
CHECK(otensor);
auto* o_data = otensor->mutable_data<float>();
auto* npu_obuffer = static_cast<float*>(npu_otensors_[i]->GetBuffer());
std::memcpy(
o_data,
npu_obuffer,
sizeof(float) * static_cast<size_t>(otensor->dims().production()));
auto cpu_otensor = param.outputs[i].second;
CHECK(cpu_otensor);
CHECK_EQ(cpu_otensor->dims().production(), npu_odatasizes_[i]);
std::memcpy(cpu_otensor->mutable_data<float>(),
static_cast<float*>(npu_otensors_[i]->GetBuffer()),
sizeof(float) * static_cast<size_t>(npu_odatasizes_[i]));
}
}
......
......@@ -37,16 +37,13 @@ class GraphCompute : public KernelLite<TARGET(kNPU), PRECISION(kFloat)> {
virtual ~GraphCompute() = default;
bool input_dims_changed() const;
private:
std::shared_ptr<hiai::AiModelMngerClient> model_client_;
std::string model_name_;
hiai::AiContext model_context_;
std::vector<hiai::TensorDimension> npu_idims_;
std::vector<hiai::TensorDimension> npu_odims_;
std::vector<int64_t> npu_idatasizes_;
std::vector<int64_t> npu_odatasizes_;
std::vector<std::shared_ptr<hiai::AiTensor>> npu_itensors_;
std::vector<std::shared_ptr<hiai::AiTensor>> npu_otensors_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册