Unverified commit aacba6f5, authored by hong19860320, committed by GitHub

[LITE][NPU] Add fusion_elementwise_add_activation bridge, remove the limitation of the dimensions of input tensors in the graph compute kernel, and refine log message (#2395)

test=develop

Parent: f51c4891
@@ -67,7 +67,7 @@ std::string UniqueName(const std::string& prefix) {
   return prefix + "_" + std::to_string(counter);
 }
 
-ge::DataType PrecisionConverter(PrecisionType itype) {
+ge::DataType CvtPrecisionType(PrecisionType itype) {
   ge::DataType otype = ge::DT_FLOAT;
   switch (itype) {
     case PRECISION(kFloat):
@@ -80,14 +80,14 @@ ge::DataType PrecisionConverter(PrecisionType itype) {
       otype = ge::DT_INT32;
       break;
     default:
-      LOG(FATAL) << "Can not convert precision type(" << PrecisionToStr(itype)
-                 << ") from Lite to NPU";
+      LOG(FATAL) << "[NPU] Can not convert precision type("
+                 << PrecisionToStr(itype) << ") from Lite to NPU";
       break;
   }
   return otype;
 }
 
-ge::Format DataLayoutConverter(DataLayoutType itype) {
+ge::Format CvtDataLayoutType(DataLayoutType itype) {
   ge::Format otype = ge::FORMAT_NCHW;
   switch (itype) {
     case DATALAYOUT(kNCHW):
@@ -95,17 +95,17 @@ ge::Format DataLayoutConverter(DataLayoutType itype) {
       break;
     // TODO(hong19860320) support more data layout type
     default:
-      LOG(FATAL) << "Can not convert data layout type("
-                 << DataLayoutToStr(itype) << ") from Lite to NPU";
+      LOG(FATAL) << "[NPU] Can not convert data layout type("
+                 << DataLayoutToStr(itype) << ") from Lite to NPU";
       break;
   }
   return otype;
 }
 
-ge::TensorPtr CvtFromLiteTensor(lite::Tensor* in_tensor,
-                                std::vector<int64_t> out_shape,
-                                PrecisionType in_ptype,
-                                DataLayoutType in_ltype) {
+ge::TensorPtr CvtTensor(lite::Tensor* in_tensor,
+                        std::vector<int64_t> out_shape,
+                        PrecisionType in_ptype,
+                        DataLayoutType in_ltype) {
   uint8_t* in_data = nullptr;
   auto in_size = in_tensor->dims().production();
   auto in_shape = in_tensor->dims().Vectorize();
@@ -123,10 +123,10 @@ ge::TensorPtr CvtFromLiteTensor(lite::Tensor* in_tensor,
     in_data = reinterpret_cast<uint8_t*>(in_tensor->mutable_data<int8_t>());
     in_bytes = in_size * sizeof(int8_t);
   } else {
-    LOG(FATAL) << "Unknow precision type " << PrecisionToStr(in_ptype);
+    LOG(FATAL) << "[NPU] Unknow precision type " << PrecisionToStr(in_ptype);
   }
-  ge::DataType out_ptype = PrecisionConverter(in_ptype);
-  ge::Format out_ltype = DataLayoutConverter(in_ltype);
+  ge::DataType out_ptype = CvtPrecisionType(in_ptype);
+  ge::Format out_ltype = CvtDataLayoutType(in_ltype);
   ge::TensorDesc out_desc(ge::Shape(out_shape), out_ltype, out_ptype);
   CHECK_EQ(out_ltype, ge::FORMAT_NCHW);
@@ -140,6 +140,31 @@ ge::TensorPtr CvtFromLiteTensor(lite::Tensor* in_tensor,
   return out_tensor;
 }
 
+int CvtActMode(std::string act_type) {
+  int act_mode = 1;
+  if (act_type == "sigmod") {
+    act_mode = 0;
+  } else if (act_type == "relu") {
+    act_mode = 1;
+  } else if (act_type == "tanh") {
+    act_mode = 2;
+  } else if (act_type == "elu") {
+    act_mode = 4;
+  } else if (act_type == "abs") {
+    act_mode = 6;
+  } else if (act_type == "softsign") {
+    act_mode = 8;
+  } else if (act_type == "softplus") {
+    act_mode = 9;
+  } else if (act_type == "hardsigmoid") {
+    act_mode = 10;
+  } else {
+    // TODO(hong19860320) support more activation mode
+    LOG(FATAL) << "[NPU] Unsupported activation type " << act_type;
+  }
+  return act_mode;
+}
+
 bool HasInputArg(const OpInfo* op_info,
                  const Scope* scope,
                  const std::string& argname) {
......
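For reference, the activation-type table that the new `CvtActMode` helper centralizes can be exercised on its own; below is a minimal standalone sketch (plain C++, no Lite or HiAI headers) of the same string-to-mode mapping, with the strings and mode numbers copied from the function above (including the literal "sigmod" spelling used in the source):

```cpp
#include <iostream>
#include <map>
#include <string>

// Standalone sketch of the activation type -> HiAI activation mode table
// encoded by CvtActMode; entries copied verbatim from the diff above.
int main() {
  const std::map<std::string, int> act_modes{
      {"sigmod", 0}, {"relu", 1},     {"tanh", 2},     {"elu", 4},
      {"abs", 6},    {"softsign", 8}, {"softplus", 9}, {"hardsigmoid", 10}};
  for (const auto& kv : act_modes) {
    std::cout << kv.first << " -> mode " << kv.second << "\n";
  }
  return 0;
}
```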
@@ -192,14 +192,14 @@ bool BuildModel(std::vector<ge::Operator>& inputs,  // NOLINT
 
 std::string UniqueName(const std::string& prefix);
 
-ge::DataType PrecisionConverter(PrecisionType itype);
+ge::DataType CvtPrecisionType(PrecisionType itype);
 
-ge::Format DataLayoutConverter(DataLayoutType itype);
+ge::Format CvtDataLayoutType(DataLayoutType itype);
 
-ge::TensorPtr CvtFromLiteTensor(Tensor* in_tensor,
-                                std::vector<int64_t> out_shape = {},
-                                PrecisionType in_ptype = PRECISION(kFloat),
-                                DataLayoutType in_ltype = DATALAYOUT(kNCHW));
+ge::TensorPtr CvtTensor(Tensor* in_tensor,
+                        std::vector<int64_t> out_shape = {},
+                        PrecisionType in_ptype = PRECISION(kFloat),
+                        DataLayoutType in_ltype = DATALAYOUT(kNCHW));
 
 template <typename T>
 ge::TensorPtr CreateTensorAndFillData(std::vector<T> data,
@@ -214,7 +214,7 @@ ge::TensorPtr CreateTensorAndFillData(std::vector<T> data,
   } else if (info == typeid(int32_t)) {
     type = ge::DT_INT32;
   } else {
-    LOG(FATAL) << "Unknow value type " << info.name();
+    LOG(FATAL) << "[NPU] Unknow value type " << info.name();
   }
   if (shape.empty()) {
     shape = {static_cast<int64_t>(data.size())};
@@ -245,6 +245,8 @@ ge::TensorPtr CreateTensorAndFillData(T value,
   return CreateTensorAndFillData(data, shape, format);
 }
 
+int CvtActMode(std::string act_type);
+
 bool HasInputArg(const OpInfo* op_info,
                  const Scope* scope,
                  const std::string& argname);
......
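With the default arguments declared above, most bridges can call `CvtTensor` with just the Lite tensor: an empty `out_shape` falls back to the tensor's own dims, and `kFloat`/`kNCHW` cover the common case. A hedged sketch of a typical call site (the variable names here are hypothetical; the pattern mirrors the bridges changed below):

```cpp
// Hypothetical call site: wrap a Lite weight tensor as a GE const operator.
// CvtTensor's defaults mean only the tensor is required; a custom shape
// such as {1, n, 1, 1} for a bias would be passed as the second argument.
auto w_const_node = std::make_shared<ge::op::Const>(w_var_name);
w_const_node->set_attr_value(lite::npu::CvtTensor(w_tensor));
lite::npu::OpList::Global().add(w_const_node);
```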
@@ -35,7 +35,7 @@ std::shared_ptr<ge::Operator> GenerateNPUProgramPass::CvtVarNode(
     lite::mir::Node* var_node, const Scope* scope) {
   CHECK(var_node->IsArg());
   const auto& arg = var_node->AsArg();
-  VLOG(4) << "Convert var node " << arg.name;
+  VLOG(4) << "[NPU] Convert var node " << arg.name;
   auto* var = scope->FindVar(arg.name);
   CHECK(var);
@@ -44,13 +44,13 @@ std::shared_ptr<ge::Operator> GenerateNPUProgramPass::CvtVarNode(
   auto dims = tensor->dims();
   if (arg.is_weight) {
     auto wgt = std::make_shared<ge::op::Const>(arg.name);
-    LOG(INFO) << "in convert const:" << arg.name;
+    LOG(INFO) << "[NPU] Convert const var node " << arg.name;
     VLOG(4) << dims;
-    wgt->set_attr_value(lite::npu::CvtFromLiteTensor(tensor));
+    wgt->set_attr_value(lite::npu::CvtTensor(tensor));
     return wgt;
   } else {
     CHECK_EQ(dims.size(), 4);
-    LOG(INFO) << "in convert data:" << arg.name;
+    LOG(INFO) << "[NPU] Convert data var node " << arg.name;
     LOG(INFO) << dims;
     // TODO(xxx): support more types and dims size
     ge::TensorDesc desc(ge::Shape(dims.Vectorize()),
@@ -128,10 +128,10 @@ std::string GenerateNPUProgramPass::BuildNPUGraph(
   // persistable=true, Sothat the model parser can recognize it and save it to
   // param files
   if (!lite::npu::BuildModel(inputs, outputs, weight)) {
-    LOG(WARNING) << "Build NPU failed subgraph " << sub_id;
-    throw std::runtime_error("Build NPU failed subgraph.");
+    LOG(WARNING) << "[NPU] Build NPU graph failed (subgraph=" << sub_id << ")";
+    throw std::runtime_error("Build NPU graph failed.");
   }
-  LOG(INFO) << "[NPU] Build NPU Client success subgraph " << sub_id;
+  LOG(INFO) << "[NPU] Build NPU graph success (subgraph=" << sub_id << ")";
   return weight_var_name;
 }
@@ -166,12 +166,12 @@ void GenerateNPUProgramPass::GenNPUSubgraph(
 }
 
 void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-  LOG(INFO) << "Before NPU Pass \n" << Visualize(graph.get());
+  LOG(INFO) << "[NPU] Before NPU Pass \n" << Visualize(graph.get());
   const auto& bridges = lite::kernels::npu::bridges::Factory::Instance();
   const auto& op_map = bridges.AllFunctions();
   std::vector<std::string> supported_op_types;
   for (auto& i : op_map) {
-    LOG(INFO) << "Supported type: " << i.first;
+    LOG(INFO) << "[NPU] Supported type: " << i.first;
     supported_op_types.push_back(i.first);
   }
@@ -182,15 +182,15 @@ void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
     CHECK_EQ(op_nodes_all.size(), num_subgraph);
     int id = 1;
     for (auto& op_nodes : op_nodes_all) {
-      LOG(INFO) << "Converting subgraph_id:" << id;
+      LOG(INFO) << "[NPU] Converting Subgraph " << id;
       GenNPUSubgraph(graph, op_nodes.second, id);
-      LOG(INFO) << "After NPU Pass Subgraph " << id << "\n"
+      LOG(INFO) << "[NPU] After NPU Pass Subgraph " << id << "\n"
                 << Visualize(graph.get());
       id++;
     }
   } catch (...) {
-    LOG(WARNING) << "Build NPU graph failed";
-    throw std::runtime_error("Build NPU graph failed");
+    LOG(WARNING) << "[NPU] Build NPU graph failed.";
+    throw std::runtime_error("[NPU] Build NPU graph failed.");
   }
 
   for (auto& item : graph->StmtTopologicalOrder()) {
@@ -203,7 +203,7 @@ void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 }
 
 std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() {
-  LOG(INFO) << "insts.size " << insts_.size();
+  LOG(INFO) << "[NPU] program insts.size " << insts_.size();
   std::unique_ptr<RuntimeProgram> program(
       new RuntimeProgram(std::move(insts_)));
   return program;
......
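The pass relies on exceptions for failure handling: `BuildNPUGraph` throws when the HiAI build fails, and `Apply` catches, logs, and rethrows. A minimal standalone sketch of that control flow (the subgraph ids and the failure trigger are invented for illustration):

```cpp
#include <iostream>
#include <stdexcept>

// Standalone sketch of the throw/catch flow between BuildNPUGraph and Apply:
// a failed build throws, the caller logs a warning and rethrows.
void BuildSubgraph(int sub_id) {
  bool build_ok = (sub_id != 2);  // pretend subgraph 2 fails to build
  if (!build_ok) throw std::runtime_error("Build NPU graph failed.");
  std::cout << "[NPU] Build NPU graph success (subgraph=" << sub_id << ")\n";
}

int main() {
  try {
    for (int id = 1; id <= 3; ++id) BuildSubgraph(id);
  } catch (const std::exception& e) {
    std::cout << "[NPU] " << e.what() << "\n";  // LOG(WARNING) in the pass
    // the real pass rethrows here so the caller can handle the failure
  }
  return 0;
}
```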
@@ -27,7 +27,7 @@ node_map_type ActConverter(const std::shared_ptr<lite::OpLite> act_op,
   auto op_info = act_op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_type = lite::npu::UniqueName(op_type);
-  LOG(INFO) << "Converting " + op_type + "...";
+  LOG(INFO) << "[NPU] Converting " + op_type + "...";
 
   // create act node and set input node from inputs_map
   auto x_var_name = op_info->Input("X").front();
@@ -37,30 +37,9 @@ node_map_type ActConverter(const std::shared_ptr<lite::OpLite> act_op,
   lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
   lite::npu::OpList::Global().add(act_node);
 
-  // parse and set activation type
-  int act_mode = 1;
-  if (op_type == "sigmod") {
-    act_mode = 0;
-  } else if (op_type == "relu") {
-    act_mode = 1;
-  } else if (op_type == "tanh") {
-    act_mode = 2;
-  } else if (op_type == "elu") {
-    act_mode = 4;
-  } else if (op_type == "abs") {
-    act_mode = 6;
-  } else if (op_type == "softsign") {
-    act_mode = 8;
-  } else if (op_type == "softplus") {
-    act_mode = 9;
-  } else if (op_type == "hardsigmoid") {
-    act_mode = 10;
-  } else {
-    // TODO(hong19860320) add more activation mode, and set the coef value
-    // clipped ReLU, LEAKY_RELU, relu1, threshold, selu and linear
-    LOG(FATAL) << "Unsupported activation type " << op_type;
-  }
-  act_node->set_attr_mode(act_mode);
+  // TODO(hong19860320) set the coef value for act Ops, such as leaky_relu,
+  // clipped_relu etc.
+  act_node->set_attr_mode(lite::npu::CvtActMode(op_type));
 
   node_map_type outputs_map;
   outputs_map[op_info->Output("Out").front()] = act_node;
......
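Since the HiAI mode is now derived from the op type string, a single `ActConverter` can back every activation that `CvtActMode` understands. This hunk does not show the registration lines, so the following is only an illustration of how such registrations would look with the `REGISTER_NPU_BRIDGE` macro used elsewhere in this commit:

```cpp
// Illustrative only: one converter serving several activation ops whose
// type strings CvtActMode accepts (actual registrations not shown above).
REGISTER_NPU_BRIDGE(relu, paddle::lite::kernels::npu::bridges::ActConverter);
REGISTER_NPU_BRIDGE(tanh, paddle::lite::kernels::npu::bridges::ActConverter);
```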
@@ -28,7 +28,7 @@ node_map_type BatchNormConverter(
   auto op_info = batch_norm_op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_type = lite::npu::UniqueName(op_type);
-  LOG(INFO) << "Converting " + op_type + "...";
+  LOG(INFO) << "[NPU] Converting " + op_type + "...";
 
   std::shared_ptr<ge::op::BatchNorm> batch_norm_node =
       std::make_shared<ge::op::BatchNorm>(unique_op_type);
@@ -37,26 +37,26 @@ node_map_type BatchNormConverter(
   auto scale_var_name = op_info->Input("Scale").front();
   lite::Tensor* scale = scope->FindVar(scale_var_name)->GetMutable<Tensor>();
   auto npu_scale = std::make_shared<ge::op::Const>(scale_var_name);
-  npu_scale->set_attr_value(lite::npu::CvtFromLiteTensor(scale));
+  npu_scale->set_attr_value(lite::npu::CvtTensor(scale));
   lite::npu::OpList::Global().add(npu_scale);
 
   auto bias_var_name = op_info->Input("Bias").front();
   lite::Tensor* bias = scope->FindVar(bias_var_name)->GetMutable<Tensor>();
   auto npu_bias = std::make_shared<ge::op::Const>(bias_var_name);
-  npu_bias->set_attr_value(lite::npu::CvtFromLiteTensor(bias));
+  npu_bias->set_attr_value(lite::npu::CvtTensor(bias));
   lite::npu::OpList::Global().add(npu_bias);
 
   auto mean_var_name = op_info->Input("Mean").front();
   lite::Tensor* mean = scope->FindVar(mean_var_name)->GetMutable<Tensor>();
   auto npu_mean = std::make_shared<ge::op::Const>(mean_var_name);
-  npu_mean->set_attr_value(lite::npu::CvtFromLiteTensor(mean));
+  npu_mean->set_attr_value(lite::npu::CvtTensor(mean));
   lite::npu::OpList::Global().add(npu_mean);
 
   auto variance_var_name = op_info->Input("Variance").front();
   lite::Tensor* variance =
       scope->FindVar(variance_var_name)->GetMutable<Tensor>();
   auto npu_variance = std::make_shared<ge::op::Const>(variance_var_name);
-  npu_variance->set_attr_value(lite::npu::CvtFromLiteTensor(variance));
+  npu_variance->set_attr_value(lite::npu::CvtTensor(variance));
   lite::npu::OpList::Global().add(npu_variance);
 
   float npu_momentum = op_info->GetAttr<float>("momentum");
......
@@ -27,7 +27,7 @@ node_map_type ConcatConverter(const std::shared_ptr<lite::OpLite> concat_op,
   const lite::OpInfo* op_info = concat_op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_type = lite::npu::UniqueName(op_type);
-  LOG(INFO) << "converting " << op_type << " ... ";
+  LOG(INFO) << "[NPU] Converting " << op_type << " ... ";
 
   auto x_var_names = op_info->Input("X");
   auto axis = op_info->GetAttr<int>("axis");
@@ -46,7 +46,7 @@ node_map_type ConcatConverter(const std::shared_ptr<lite::OpLite> concat_op,
     } else {
       auto consty = std::make_shared<ge::op::Const>(x_var_name);
       auto* x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
-      consty->set_attr_value(lite::npu::CvtFromLiteTensor(x));
+      consty->set_attr_value(lite::npu::CvtTensor(x));
       output_node->set_dynamic_input_x(index + 1, *consty);
       lite::npu::OpList::Global().add(consty);
     }
......
@@ -27,7 +27,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
   auto op_info = conv_op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_type = lite::npu::UniqueName(op_type);
-  LOG(INFO) << "Converting " << op_type << "... ";
+  LOG(INFO) << "[NPU] Converting " << op_type << "... ";
 
   // get input, filter and op attributes
   auto input_var_name = op_info->Input("Input").front();
@@ -64,10 +64,10 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
       !((groups == 1 || groups >= 5) && dilations[0] == 1 &&
         dilations[1] == 1)) {
     use_depthwise_conv = true;
-    LOG(WARNING) << "For depthwise mode, dilation = 1 and groups >= 5 (or "
-                    "groups = 1) is only supported in "
-                    "Convolution Op, so force to use ConvolutionDepthwise Op, "
-                    "but may lead poor performance.";
+    LOG(WARNING) << "[NPU] For depthwise mode, dilation = 1 and groups >= 5 "
+                    "(or groups = 1) is only supported in Convolution Op, so "
+                    "force to use ConvolutionDepthwise Op, but may lead poor "
+                    "performance.";
   }
 
   // check input
@@ -77,7 +77,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
   // create filter node
   CHECK(!inputs_map.count(filter_var_name));
   auto filter_const_node = std::make_shared<ge::op::Const>(filter_var_name);
-  filter_const_node->set_attr_value(lite::npu::CvtFromLiteTensor(filter));
+  filter_const_node->set_attr_value(lite::npu::CvtTensor(filter));
   lite::npu::OpList::Global().add(filter_const_node);
 
   // create bias node if has bias
@@ -115,8 +115,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
     } else {
       // bias node with const data
       auto bias_const_node = std::make_shared<ge::op::Const>(bias_var_name);
-      bias_const_node->set_attr_value(
-          lite::npu::CvtFromLiteTensor(bias, bias_shape));
+      bias_const_node->set_attr_value(lite::npu::CvtTensor(bias, bias_shape));
       bias_node = bias_const_node;
     }
     lite::npu::OpList::Global().add(bias_node);
@@ -193,7 +192,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
     auto relu_node =
         std::make_shared<ge::op::Activation>(unique_op_type + "/relu");
     relu_node->set_input_x(*conv_node);
-    relu_node->set_attr_mode(1);
+    relu_node->set_attr_mode(lite::npu::CvtActMode("relu"));
     lite::npu::OpList::Global().add(relu_node);
     outputs_map[op_info->Output("Output").front()] = relu_node;
   } else {
......
@@ -28,7 +28,7 @@ node_map_type ConvTransposeConverter(
   auto op_info = conv_transpose_op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_type = lite::npu::UniqueName(op_type);
-  LOG(INFO) << "Converting " << op_type << "... ";
+  LOG(INFO) << "[NPU] Converting " << op_type << "... ";
 
   // get input, output and op attributes
   auto input_var_name = op_info->Input("Input").front();
@@ -72,7 +72,7 @@ node_map_type ConvTransposeConverter(
   // create filter node
   CHECK(!inputs_map.count(filter_var_name));
   auto filter_const_node = std::make_shared<ge::op::Const>(filter_var_name);
-  filter_const_node->set_attr_value(lite::npu::CvtFromLiteTensor(filter));
+  filter_const_node->set_attr_value(lite::npu::CvtTensor(filter));
   conv_transpose_node->set_input_filter(*filter_const_node);
   lite::npu::OpList::Global().add(filter_const_node);
 
@@ -107,7 +107,7 @@ node_map_type ConvTransposeConverter(
     CHECK_EQ(channel_size, filter_shape[1] * groups);
     auto bias_const_node = std::make_shared<ge::op::Const>(bias_var_name);
     bias_const_node->set_attr_value(
-        lite::npu::CvtFromLiteTensor(bias, {1, channel_size, 1, 1}));
+        lite::npu::CvtTensor(bias, {1, channel_size, 1, 1}));
     lite::npu::OpList::Global().add(bias_const_node);
     // append add node to add bias node
     auto add_node = std::make_shared<ge::op::Add>(unique_op_type + "/add");
@@ -123,7 +123,7 @@ node_map_type ConvTransposeConverter(
     auto relu_node =
         std::make_shared<ge::op::Activation>(unique_op_type + "/relu");
     relu_node->set_input_x(*output_node);
-    relu_node->set_attr_mode(1);
+    relu_node->set_attr_mode(lite::npu::CvtActMode("relu"));
     lite::npu::OpList::Global().add(relu_node);
     outputs_map[op_info->Output("Output").front()] = relu_node;
   } else {
......
@@ -28,7 +28,7 @@ node_map_type ElementwiseConverter(
   auto op_info = elementwise_op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_type = lite::npu::UniqueName(op_type);
-  LOG(INFO) << "converting elementwise...";
+  LOG(INFO) << "[NPU] Converting " + op_type + "...";
 
   std::shared_ptr<ge::op::Eltwise> elementwise_node =
       std::make_shared<ge::op::Eltwise>(unique_op_type);
@@ -37,7 +37,7 @@ node_map_type ElementwiseConverter(
   auto y_var_name = op_info->Input("Y").front();
 
   CHECK_EQ(op_info->GetAttr<int>("axis"), -1)
-      << "npu elementwise only support inputs with same size";
+      << "[NPU] elementwise only support inputs with same size";
 
   CHECK(inputs_map.find(x_var_name) != inputs_map.end());
   elementwise_node->set_input_x1(*inputs_map.at(x_var_name));
@@ -47,11 +47,11 @@ node_map_type ElementwiseConverter(
     elementwise_node->set_input_x2(*inputs_map.at(y_var_name));
     lite::npu::OpList::Global().add(inputs_map.at(y_var_name));
   } else {
-    auto consty = std::make_shared<ge::op::Const>(y_var_name);
+    auto y_const_node = std::make_shared<ge::op::Const>(y_var_name);
     auto* y = scope->FindVar(y_var_name)->GetMutable<Tensor>();
-    consty->set_attr_value(lite::npu::CvtFromLiteTensor(y));
-    elementwise_node->set_input_x2(*consty);
-    lite::npu::OpList::Global().add(consty);
+    y_const_node->set_attr_value(lite::npu::CvtTensor(y));
+    elementwise_node->set_input_x2(*y_const_node);
+    lite::npu::OpList::Global().add(y_const_node);
   }
 
   lite::npu::OpList::Global().add(elementwise_node);
@@ -60,7 +60,19 @@ node_map_type ElementwiseConverter(
   elementwise_node->set_attr_mode(1);
 
   node_map_type outputs_map;
-  outputs_map[op_info->Output("Out").front()] = elementwise_node;
+  if (op_type == "fusion_elementwise_add_activation") {
+    auto act_type = op_info->GetAttr<std::string>("act_type");
+    auto act_node =
+        std::make_shared<ge::op::Activation>(unique_op_type + "/act");
+    act_node->set_input_x(*elementwise_node);
+    // TODO(hong19860320) set the coef value for act Ops, such as leaky_relu,
+    // clipped_relu etc.
+    act_node->set_attr_mode(lite::npu::CvtActMode(act_type));
+    lite::npu::OpList::Global().add(act_node);
+    outputs_map[op_info->Output("Out").front()] = act_node;
+  } else {
+    outputs_map[op_info->Output("Out").front()] = elementwise_node;
+  }
 
   return outputs_map;
 }
@@ -72,3 +84,5 @@ node_map_type ElementwiseConverter(
 
 REGISTER_NPU_BRIDGE(elementwise_add,
                     paddle::lite::kernels::npu::bridges::ElementwiseConverter);
+REGISTER_NPU_BRIDGE(fusion_elementwise_add_activation,
+                    paddle::lite::kernels::npu::bridges::ElementwiseConverter);
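Semantically, the new fused bridge computes `out = act(x + y)`: the Eltwise node performs the add, and the chained Activation node applies the activation selected by `act_type`. A scalar sketch of that semantics with relu as the fused activation:

```cpp
#include <algorithm>
#include <cstdio>

// Scalar sketch of fusion_elementwise_add_activation with act_type = "relu":
// the Eltwise node adds, the chained Activation node applies relu.
float fused_add_relu(float x, float y) { return std::max(x + y, 0.0f); }

int main() {
  std::printf("%g\n", fused_add_relu(-3.0f, 1.0f));  // prints 0
  std::printf("%g\n", fused_add_relu(2.0f, 1.0f));   // prints 3
  return 0;
}
```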
@@ -27,7 +27,7 @@ node_map_type FCConverter(const std::shared_ptr<lite::OpLite> fc_op,
   auto op_info = fc_op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_type = lite::npu::UniqueName(op_type);
-  LOG(INFO) << "Converting " + op_type + "...";
+  LOG(INFO) << "[NPU] Converting " + op_type + "...";
 
   auto fc_node = std::make_shared<ge::op::FullConnection>(unique_op_type);
@@ -47,7 +47,7 @@ node_map_type FCConverter(const std::shared_ptr<lite::OpLite> fc_op,
   int k = x_dims.Slice(in_num_col_dims, x_dims.size()).production();
   int n = w_dims[1];
   CHECK_EQ(k * n, w_dims.production());
-  VLOG(3) << "x dims: " << x_dims << " w dims: " << w_dims << " m: " << m
+  VLOG(3) << "[NPU] x dims: " << x_dims << " w dims: " << w_dims << " m: " << m
           << " k: " << k << " n: " << n;
 
   CHECK(inputs_map.count(x_var_name));
@@ -92,8 +92,7 @@ node_map_type FCConverter(const std::shared_ptr<lite::OpLite> fc_op,
     CHECK_EQ(bias_dims.production(), n);
 
     auto bias_const_node = std::make_shared<ge::op::Const>(bias_var_name);
-    bias_const_node->set_attr_value(
-        lite::npu::CvtFromLiteTensor(bias, {1, n, 1, 1}));
+    bias_const_node->set_attr_value(lite::npu::CvtTensor(bias, {1, n, 1, 1}));
     fc_node->set_input_b(*bias_const_node);
     lite::npu::OpList::Global().add(bias_const_node);
   }
......
@@ -28,7 +28,7 @@ node_map_type InterpolateConverter(
   auto op_info = interpolate_op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_type = lite::npu::UniqueName(op_type);
-  LOG(INFO) << "Converting " + op_type + "...";
+  LOG(INFO) << "[NPU] Converting " + op_type + "...";
 
   // get input, output and attributes from lite op
   auto x_var_name = op_info->Input("X").front();
@@ -45,8 +45,9 @@ node_map_type InterpolateConverter(
   auto out_h = op_info->GetAttr<int>("out_h");
   auto align_corners = op_info->GetAttr<bool>("align_corners");
   int align_mode = op_info->GetAttr<int>("align_mode");
-  CHECK(!(align_mode == 0 && !align_corners))
-      << "align_mode = 0 && align_corners = false isn't supported in NPU DDK";
+  CHECK(!(align_mode == 0 && !align_corners)) << "[NPU] align_mode = 0 && "
+                                                 "align_corners = false isn't "
+                                                 "supported in HiAI DDK";
 
   // priority: OutSize > scale > out_h/out_w
   if (scale > 0) {
@@ -87,9 +88,9 @@ node_map_type InterpolateConverter(
       const float largest_multiple = 7.0f;
       float multiple = static_cast<float>(x_h * x_w) / (out_h * out_w);
       CHECK_LT(multiple, largest_multiple)
-          << "multiple=(ih*iw)/(oh*ow)=" << multiple
+          << "[NPU] multiple=(ih*iw)/(oh*ow)=" << multiple
           << " is too large, should not exceed " << largest_multiple
-          << " in NPU DDK";
+          << " in HiAI DDK";
       auto w_const_node =
           std::make_shared<ge::op::Const>(unique_op_type + "/w");
       w_const_node->set_attr_value(
@@ -121,7 +122,7 @@ node_map_type InterpolateConverter(
     interp_node->set_attr_align_corners(align_corners);
     outputs_map[op_info->Output("Out").front()] = interp_node;
   } else {
-    LOG(FATAL) << "unsupported interpolate method: " << interp_method;
+    LOG(FATAL) << "[NPU] Unsupported interpolate method: " << interp_method;
   }
 
   return outputs_map;
......
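The reworded check above bounds how aggressively the HiAI DDK is allowed to downscale; as a scalar sketch of the constraint (with `largest_multiple` copied from the hunk):

```cpp
// Scalar sketch of the constraint checked above: the downscale ratio
// multiple = (ih * iw) / (oh * ow) must stay below 7.0 in the HiAI DDK.
bool InterpMultipleSupported(int x_h, int x_w, int out_h, int out_w) {
  const float largest_multiple = 7.0f;
  float multiple = static_cast<float>(x_h * x_w) / (out_h * out_w);
  return multiple < largest_multiple;
}
```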
@@ -25,11 +25,13 @@ namespace bridges {
 // handle in this converter
 node_map_type MulConverter(const std::shared_ptr<lite::OpLite> mul_op,
                            const node_map_type& inputs_map) {
-  LOG(INFO) << "converting mul...";
-  lite::Scope* scope = mul_op->scope();
-  const lite::OpInfo* op_info = mul_op->op_info();
-  auto output_node =
-      std::make_shared<ge::op::MatMul>(lite::npu::UniqueName("mul"));
+  auto scope = mul_op->scope();
+  auto op_info = mul_op->op_info();
+  auto op_type = op_info->Type();
+  auto unique_op_type = lite::npu::UniqueName(op_type);
+  LOG(INFO) << "[NPU] Converting " + op_type + "...";
+
+  auto output_node = std::make_shared<ge::op::MatMul>(unique_op_type);
 
   auto x_var_name = op_info->Input("X").front();
   auto y_var_name = op_info->Input("Y").front();
@@ -46,7 +48,7 @@ node_map_type MulConverter(const std::shared_ptr<lite::OpLite> mul_op,
   int n = ytensor->dims()
               .Slice(y_num_col_dims, ytensor->dims().size())
               .production();
-  CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h";
+  CHECK_EQ(x_w, y_h) << "[NPU] x_w must be equal with y_h";
   int k = x_w;
   LOG(INFO) << "m:" << m << ",n:" << n << ",k:" << k;
   LOG(INFO) << "x_var_name:" << x_var_name
......
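The `m`/`k`/`n` values come from flattening each tensor's dims around `x_num_col_dims`/`y_num_col_dims`: the product of the leading dims forms the rows and the product of the trailing dims forms the columns. A standalone sketch of that flattening rule (mirroring `DDim::Slice().production()`):

```cpp
#include <cstdio>
#include <vector>

// Dims [0, split) multiply into rows, dims [split, size) into columns --
// the rule mul uses to derive m, k from X and k, n from Y.
long long Production(const std::vector<long long>& dims, size_t begin,
                     size_t end) {
  long long p = 1;
  for (size_t i = begin; i < end; ++i) p *= dims[i];
  return p;
}

int main() {
  std::vector<long long> x_dims{2, 3, 4};  // with x_num_col_dims = 1
  std::printf("m=%lld k=%lld\n", Production(x_dims, 0, 1),
              Production(x_dims, 1, x_dims.size()));  // m=2 k=12
  return 0;
}
```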
@@ -27,7 +27,7 @@ node_map_type Pad2dConverter(const std::shared_ptr<lite::OpLite> pad2d_op,
   auto op_info = pad2d_op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_type = lite::npu::UniqueName(op_type);
-  LOG(INFO) << "Converting " + op_type + "...";
+  LOG(INFO) << "[NPU] Converting " + op_type + "...";
 
   std::shared_ptr<ge::op::Pad> pad2d_node =
       std::make_shared<ge::op::Pad>(unique_op_type);
@@ -40,10 +40,10 @@ node_map_type Pad2dConverter(const std::shared_ptr<lite::OpLite> pad2d_op,
   if (mode == "constant") {
     pad2d_node->set_attr_mode(0);
   } else if (mode == "reflect") {
-    LOG(FATAL) << "NPU doesn't support this pad mod: " << mode;
+    LOG(FATAL) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
     pad2d_node->set_attr_mode(1);
   } else {
-    LOG(FATAL) << "NPU doesn't support this pad mod: " << mode;
+    LOG(FATAL) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
   }
 
   auto x_dims = scope->FindTensor(x_var_name)->dims();
......
@@ -23,6 +23,7 @@ USE_NPU_BRIDGE(depthwise_conv2d);
 USE_NPU_BRIDGE(pool2d);
 USE_NPU_BRIDGE(relu);
 USE_NPU_BRIDGE(elementwise_add);
+USE_NPU_BRIDGE(fusion_elementwise_add_activation);
 USE_NPU_BRIDGE(scale);
 USE_NPU_BRIDGE(softmax);
 USE_NPU_BRIDGE(concat);
......
@@ -27,7 +27,7 @@ node_map_type PoolConverter(const std::shared_ptr<lite::OpLite> pool_op,
   auto op_info = pool_op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_type = lite::npu::UniqueName(op_type);
-  LOG(INFO) << "Converting " + op_type + "...";
+  LOG(INFO) << "[NPU] Converting " + op_type + "...";
 
   std::shared_ptr<ge::op::Pooling> pool_node =
       std::make_shared<ge::op::Pooling>(unique_op_type);
@@ -39,9 +39,9 @@ node_map_type PoolConverter(const std::shared_ptr<lite::OpLite> pool_op,
   } else if (pooling_type == "avg") {
     npu_mode = 1;
     CHECK(op_info->GetAttr<bool>("exclusive"))
-        << "exclusive must be true when use npu";
+        << "[NPU] exclusive must be true in HiAI DDK";
   } else {
-    LOG(FATAL) << "Unsupported pooling type: " << pooling_type;
+    LOG(FATAL) << "[NPU] Unsupported pooling type: " << pooling_type;
   }
   bool npu_global_pooling = op_info->GetAttr<bool>("global_pooling");
   auto ksize = op_info->GetAttr<std::vector<int>>("ksize");
......
@@ -28,7 +28,7 @@ node_map_type ReshapeConverter(const std::shared_ptr<lite::OpLite> reshape_op,
   auto op_info = reshape_op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_type = lite::npu::UniqueName(op_type);
-  LOG(INFO) << "Converting " + op_type + "...";
+  LOG(INFO) << "[NPU] Converting " + op_type + "...";
 
   // get input, output and op attributes
   auto x_var_name = op_info->Input("X").front();
@@ -55,9 +55,9 @@ node_map_type ReshapeConverter(const std::shared_ptr<lite::OpLite> reshape_op,
       auto out_dims = operators::ValidateShape(shape, x_dims);
       auto out_shape = out_dims.Vectorize();
       if (out_shape.size() > 4) {
-        LOG(WARNING)
-            << "NPU DDK only supports less than 4 dimensions, but Shape has "
-            << out_shape.size();
+        LOG(WARNING) << "[NPU] HiAI DDK only supports less than 4 dimensions, "
+                        "but Shape has "
+                     << out_shape.size();
       }
       auto actual_shape_const_node =
           std::make_shared<ge::op::Const>(actual_shape_var_name);
@@ -75,9 +75,9 @@ node_map_type ReshapeConverter(const std::shared_ptr<lite::OpLite> reshape_op,
     auto out_dims = operators::ValidateShape(shape, x_dims);
     auto out_shape = out_dims.Vectorize();
     if (out_shape.size() > 4) {
-      LOG(WARNING)
-          << "NPU DDK only supports less than 4 dimensions, but shape has "
-          << out_shape.size();
+      LOG(WARNING) << "[NPU] HiAI DDK only supports less than 4 dimensions, "
+                      "but shape has "
+                   << out_shape.size();
     }
     reshape_node->set_attr_shape(
         ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
@@ -93,9 +93,9 @@ node_map_type ReshapeConverter(const std::shared_ptr<lite::OpLite> reshape_op,
       xshape_dims[i + 1] = x_dims[i];
     }
     if (xshape_dims.size() > 4) {
-      LOG(WARNING)
-          << "NPU DDK only supports less than 4 dimensions, but XShape has "
-          << xshape_dims.size();
+      LOG(WARNING) << "[NPU] HiAI DDK only supports less than 4 dimensions, "
+                      "but XShape has "
+                   << xshape_dims.size();
     }
     auto xshape_node =
         std::make_shared<ge::op::Reshape>(unique_op_type + "/xshape");
......
@@ -27,7 +27,7 @@ node_map_type ScaleConverter(const std::shared_ptr<lite::OpLite> scale_op,
   auto op_info = scale_op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_type = lite::npu::UniqueName(op_type);
-  LOG(INFO) << "Converting " + op_type + "...";
+  LOG(INFO) << "[NPU] Converting " + op_type + "...";
 
   // get input, output and op attributes
   auto x_var_name = op_info->Input("X").front();
......
@@ -28,7 +28,7 @@ node_map_type ShuffleChannelConverter(
   auto op_info = shuffle_channel_op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_type = lite::npu::UniqueName(op_type);
-  LOG(INFO) << "Converting " + op_type + "...";
+  LOG(INFO) << "[NPU] Converting " + op_type + "...";
 
   std::shared_ptr<ge::op::ShuffleChannel> shuffle_channel_node =
       std::make_shared<ge::op::ShuffleChannel>(unique_op_type);
......
@@ -27,7 +27,7 @@ node_map_type SoftmaxConverter(const std::shared_ptr<lite::OpLite> softmax_op,
   auto op_info = softmax_op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_type = lite::npu::UniqueName(op_type);
-  LOG(INFO) << "Converting " + op_type + "...";
+  LOG(INFO) << "[NPU] Converting " + op_type + "...";
 
   std::shared_ptr<ge::op::Softmax> softmax_node =
       std::make_shared<ge::op::Softmax>(unique_op_type);
@@ -37,7 +37,7 @@ node_map_type SoftmaxConverter(const std::shared_ptr<lite::OpLite> softmax_op,
   auto axis = op_info->GetAttr<int>("axis");
   if (x_dims.size() > 3) {
     CHECK(!(axis == 2 && x_dims[3] > 1))
-        << "unsupported npu softmax params: axis = " << axis
+        << "[NPU] Unsupported softmax params: axis = " << axis
         << " :x_w = " << x_dims[3];
   }
......
@@ -73,7 +73,7 @@ void softmax_ref(const std::shared_ptr<operators::SoftmaxOp> op) {
   }
 }
 
-void test_softmax(int bs, int ic, int ih, int iw, int axis) {
+void test_softmax(const std::vector<int64_t>& input_shape, int axis) {
   // prepare input&output variables
   Scope scope;
   std::string x_var_name = "x";
@@ -82,7 +82,7 @@ void test_softmax(int bs, int ic, int ih, int iw, int axis) {
   auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
   auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
   auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
-  x->Resize({bs, ic, ih, iw});
+  x->Resize(input_shape);
 
   // initialize input&output data
   FillTensor<float>(x);
@@ -111,19 +111,36 @@ void test_softmax(int bs, int ic, int ih, int iw, int axis) {
 }
 
 TEST(NPUBridges, softmax) {
-  for (auto bs : {1, 4, 7}) {
-    for (auto ic : {1, 4, 7}) {
-      for (auto ih : {1, 4, 7}) {
-        for (auto iw : {1, 4, 7}) {
-          for (auto axis : {-3, -1, 0, 1, 2, 3}) {
-            // npu softmax exists bugs when axis is 2 and iw > 1
-            if (axis == 2 && iw > 1) continue;
-            test_softmax(bs, ic, ih, iw, axis);
-          }
-        }
-      }
-    }
-  }
+  test_softmax({1, 4}, -1);
+  // Bug exists in HiAI DDK when the number of items > 16500
+  // test_softmax({1, 16500}, -1);
+  test_softmax({1, 4}, 0);
+  test_softmax({1, 4}, 1);
+  test_softmax({3, 4}, -1);
+  test_softmax({3, 4}, 0);
+  test_softmax({3, 4}, 1);
+  test_softmax({1, 4, 7}, -1);
+  test_softmax({1, 4, 7}, 0);
+  // Bug exists in HiAI DDK when axis is 1 and iw > 1
+  // test_softmax({1, 4, 7}, 1);
+  test_softmax({1, 4, 1}, 1);
+  test_softmax({1, 4, 7}, 2);
+  test_softmax({3, 4, 7}, -1);
+  test_softmax({3, 4, 7}, 0);
+  test_softmax({3, 4, 1}, 1);
+  test_softmax({3, 4, 7}, 2);
+  test_softmax({1, 4, 7, 9}, -1);
+  test_softmax({1, 4, 7, 9}, 0);
+  test_softmax({1, 4, 7, 9}, 1);
+  // Bug exists in HiAI DDK when axis is 2 and iw > 1
+  // test_softmax({1, 4, 7, 9}, 2);
+  test_softmax({1, 4, 7, 1}, 2);
+  test_softmax({1, 4, 7, 9}, 3);
+  test_softmax({3, 4, 7, 9}, -1);
+  test_softmax({3, 4, 7, 9}, 0);
+  test_softmax({3, 4, 7, 9}, 1);
+  test_softmax({3, 4, 7, 1}, 2);
+  test_softmax({3, 4, 7, 9}, 3);
 }
 
 }  // namespace bridges
......
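For reference, the CPU baseline the test compares against (`softmax_ref`) applies a standard numerically stable softmax along the chosen axis; a minimal standalone sketch for a single slice:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax over one slice: subtract the max, exponentiate,
// normalize -- the per-axis computation a CPU softmax reference performs.
std::vector<float> Softmax(const std::vector<float>& x) {
  float max_v = x[0];
  for (float v : x) max_v = std::fmax(max_v, v);
  std::vector<float> y(x.size());
  float sum = 0.f;
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = std::exp(x[i] - max_v);
    sum += y[i];
  }
  for (float& v : y) v /= sum;
  return y;
}

int main() {
  for (float v : Softmax({1.f, 2.f, 3.f})) std::printf("%f ", v);
  std::printf("\n");
  return 0;
}
```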
@@ -27,7 +27,7 @@ node_map_type SplitConverter(const std::shared_ptr<lite::OpLite> split_op,
   const lite::OpInfo* op_info = split_op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_type = lite::npu::UniqueName(op_type);
-  LOG(INFO) << "Converting " << op_type << " ... ";
+  LOG(INFO) << "[NPU] Converting " << op_type << " ... ";
 
   auto x_var_name = op_info->Input("X").front();
   auto axis = op_info->GetAttr<int>("axis");
......
@@ -28,7 +28,7 @@ node_map_type TransposeConverter(
   auto op_info = transpose_op->op_info();
   auto op_type = op_info->Type();
   auto unique_op_type = lite::npu::UniqueName(op_type);
-  LOG(INFO) << "Converting " + op_type + "...";
+  LOG(INFO) << "[NPU] Converting " + op_type + "...";
 
   std::shared_ptr<ge::op::Permute> transpose_node =
       std::make_shared<ge::op::Permute>(unique_op_type);
@@ -44,7 +44,7 @@ node_map_type TransposeConverter(
     w_data[i] = 1.f;
   }
   auto npu_w = std::make_shared<ge::op::Const>(w_var_name);
-  npu_w->set_attr_value(lite::npu::CvtFromLiteTensor(w));
+  npu_w->set_attr_value(lite::npu::CvtTensor(w));
   lite::npu::OpList::Global().add(npu_w);
 
   auto axis = op_info->GetAttr<std::vector<int>>("axis");
......
@@ -15,11 +15,6 @@
 #include "lite/kernels/npu/graph_compute.h"
 #include <sys/time.h>
 #include <time.h>
-#include <string>
-#include <vector>
-#include "ai_ddk_lib/include/HiAiModelManagerService.h"
-#include "lite/core/op_registry.h"
-#include "lite/core/type_system.h"
 
 namespace paddle {
 namespace lite {
@@ -30,6 +25,8 @@ void GraphCompute::PrepareForRun() {
   auto& ctx = this->ctx_->template As<NPUContext>();
   auto& param = this->Param<param_t>();
 
+  // Load HiAI model from the weight tensor and release its buffer
+  // to save memory
   CHECK(param.weight);
   CHECK(lite::npu::LoadModel(*param.weight, &model_client_, &model_name_));
   // TODO(hong19860320): find an good way to free the model data.
@@ -38,84 +35,75 @@ void GraphCompute::PrepareForRun() {
   param.weight->Resize({1});
   param.weight->mutable_data<int8_t>(TargetType::kARM);
   CHECK(model_client_);
-  int ret =
-      model_client_->GetModelIOTensorDim(model_name_, npu_idims_, npu_odims_);
-  CHECK_EQ(ret, hiai::AI_SUCCESS) << "[NPU] Get dims failed.";
-  npu_itensors_.resize(npu_idims_.size());
-  npu_otensors_.resize(npu_odims_.size());
-  for (size_t i = 0; i < npu_idims_.size(); ++i) {
-    VLOG(3) << "npu_idims[" << i << "]: " << npu_idims_[i].GetNumber() << ","
-            << npu_idims_[i].GetChannel() << "," << npu_idims_[i].GetHeight()
-            << "," << npu_idims_[i].GetWidth();
-    VLOG(3) << "lite_idims[" << i << "]: " << param.inputs[i].second->dims();
-    CHECK_EQ(param.inputs[i].second->dims().production(),
-             npu_idims_[i].GetNumber() * npu_idims_[i].GetChannel() *
-                 npu_idims_[i].GetHeight() * npu_idims_[i].GetWidth());
+
+  // Query the dimensions of NPU input and output tensors from HiAI model
+  std::vector<hiai::TensorDimension> npu_idims;
+  std::vector<hiai::TensorDimension> npu_odims;
+  int ret =
+      model_client_->GetModelIOTensorDim(model_name_, npu_idims, npu_odims);
+  CHECK_EQ(ret, hiai::AI_SUCCESS)
+      << "[NPU] Get the dimensions of input and output tensors failed.";
+
+  // Check whether the data sizes of NPU input and output tensors are the
+  // same as CPU's, then create and initialize NPU input and output tensors.
+  npu_itensors_.resize(npu_idims.size());
+  npu_otensors_.resize(npu_odims.size());
+  npu_idatasizes_.resize(npu_idims.size());
+  npu_odatasizes_.resize(npu_odims.size());
+  for (size_t i = 0; i < npu_idims.size(); ++i) {
+    auto cpu_itensor = param.inputs[i].second;
+    CHECK(cpu_itensor);
+    VLOG(3) << "[NPU] CPU input dims[" << i << "]: " << cpu_itensor->dims();
+    VLOG(3) << "[NPU] NPU input dims[" << i << "]: {"
+            << npu_idims[i].GetNumber() << "," << npu_idims[i].GetChannel()
+            << "," << npu_idims[i].GetHeight() << ","
+            << npu_idims[i].GetWidth() << "}";
+    npu_idatasizes_[i] = npu_idims[i].GetNumber() * npu_idims[i].GetChannel() *
+                         npu_idims[i].GetHeight() * npu_idims[i].GetWidth();
+    CHECK_EQ(cpu_itensor->dims().production(), npu_idatasizes_[i]);
     npu_itensors_[i].reset(new hiai::AiTensor);
-    npu_itensors_[i]->Init(&(npu_idims_[i]));
+    npu_itensors_[i]->Init(&(npu_idims[i]));
   }
-
-  for (size_t i = 0; i < npu_odims_.size(); ++i) {
-    VLOG(3) << "npu_odims[" << i << "]: " << npu_odims_[i].GetNumber() << ","
-            << npu_odims_[i].GetChannel() << "," << npu_odims_[i].GetHeight()
-            << "," << npu_odims_[i].GetWidth();
-    VLOG(3) << "lite_odims[" << i << "]: " << param.outputs[i].second->dims();
-    auto out_size = npu_odims_[i].GetNumber() * npu_odims_[i].GetChannel() *
-                    npu_odims_[i].GetHeight() * npu_odims_[i].GetWidth();
-    if (param.outputs[i].second->dims().production() != out_size) {
-      param.outputs[i].second->Resize({npu_odims_[i].GetNumber(),
-                                       npu_odims_[i].GetChannel(),
-                                       npu_odims_[i].GetHeight(),
-                                       npu_odims_[i].GetWidth()});
+  for (size_t i = 0; i < npu_odims.size(); ++i) {
+    auto cpu_otensor = param.outputs[i].second;
+    CHECK(cpu_otensor);
+    VLOG(3) << "[NPU] CPU output dims[" << i << "]: " << cpu_otensor->dims();
+    VLOG(3) << "[NPU] NPU output dims[" << i << "]: {"
+            << npu_odims[i].GetNumber() << "," << npu_odims[i].GetChannel()
+            << "," << npu_odims[i].GetHeight() << ","
+            << npu_odims[i].GetWidth() << "}";
+    npu_odatasizes_[i] = npu_odims[i].GetNumber() * npu_odims[i].GetChannel() *
+                         npu_odims[i].GetHeight() * npu_odims[i].GetWidth();
+    if (cpu_otensor->dims().production() != npu_odatasizes_[i]) {
+      cpu_otensor->Resize({npu_odims[i].GetNumber(),
+                           npu_odims[i].GetChannel(),
+                           npu_odims[i].GetHeight(),
+                           npu_odims[i].GetWidth()});
     }
-    LOG(INFO) << param.outputs[i].second->dims();
     npu_otensors_[i].reset(new hiai::AiTensor);
-    npu_otensors_[i]->Init(&(npu_odims_[i]));
-  }
-}
-
-bool GraphCompute::input_dims_changed() const {
-  auto& param = this->Param<param_t>();
-  CHECK_EQ(param.inputs.size(), npu_idims_.size());
-  for (size_t i = 0; i < param.inputs.size(); ++i) {
-    auto param_idims = param.inputs[i].second->dims();
-    CHECK(!param_idims.empty());
-    CHECK_EQ(param_idims.size(), 4);
-    std::vector<int> idims{static_cast<int>(npu_idims_[i].GetNumber()),
-                           static_cast<int>(npu_idims_[i].GetChannel()),
-                           static_cast<int>(npu_idims_[i].GetHeight()),
-                           static_cast<int>(npu_idims_[i].GetWidth())};
-    for (size_t i = 0; i < 4; ++i) {
-      if (param_idims[i] != idims[i]) {
-        return true;
-      }
-    }
+    npu_otensors_[i]->Init(&(npu_odims[i]));
   }
-  return false;
 }
 
 void GraphCompute::Run() {
-  CHECK(!input_dims_changed())
-      << "When NPU is enabled, the input shape could not be changed yet.";
   auto& param = this->Param<param_t>();
+
+  // Check whether the data sizes of NPU input tensors are the same as
+  // CPU's, and copy the data of CPU input tensors to NPU's.
   CHECK_EQ(param.inputs.size(), npu_itensors_.size());
   CHECK_EQ(param.outputs.size(), npu_otensors_.size());
   for (size_t i = 0; i < param.inputs.size(); ++i) {
-    auto* itensor = param.inputs[i].second;
-    CHECK(itensor);
-    const auto* i_data = itensor->data<float>();
-    std::memcpy(
-        npu_itensors_[i]->GetBuffer(),
-        i_data,
-        sizeof(float) * static_cast<size_t>(itensor->dims().production()));
+    auto cpu_itensor = param.inputs[i].second;
+    CHECK(cpu_itensor);
+    CHECK_EQ(cpu_itensor->dims().production(), npu_idatasizes_[i]);
+    std::memcpy(static_cast<float*>(npu_itensors_[i]->GetBuffer()),
                cpu_itensor->data<float>(),
                sizeof(float) * static_cast<size_t>(npu_idatasizes_[i]));
   }
 
+  // Run HiAI model with model name
   std::string key = "model_name";  // Note: key seems must be model_name
   model_context_.AddPara(key, model_name_);
   auto GetCurrentUS = []() -> double {
     struct timeval time;
     gettimeofday(&time, NULL);
@@ -128,16 +116,15 @@ void GraphCompute::Run() {
       model_context_, npu_itensors_, npu_otensors_, 1000, istamp));
   VLOG(3) << "[NPU] Process cost " << GetCurrentUS() - start_time << " us";
 
+  // Check whether the data sizes of NPU output tensors are the same as
+  // CPU's, and copy the data of NPU output tensors to CPU's.
   for (size_t i = 0; i < param.outputs.size(); ++i) {
-    auto* otensor = param.outputs[i].second;
-    CHECK(otensor);
-    auto* o_data = otensor->mutable_data<float>();
-    auto* npu_obuffer = static_cast<float*>(npu_otensors_[i]->GetBuffer());
-
-    std::memcpy(
-        o_data,
-        npu_obuffer,
-        sizeof(float) * static_cast<size_t>(otensor->dims().production()));
+    auto cpu_otensor = param.outputs[i].second;
+    CHECK(cpu_otensor);
+    CHECK_EQ(cpu_otensor->dims().production(), npu_odatasizes_[i]);
+    std::memcpy(cpu_otensor->mutable_data<float>(),
+                static_cast<float*>(npu_otensors_[i]->GetBuffer()),
+                sizeof(float) * static_cast<size_t>(npu_odatasizes_[i]));
   }
 }
......
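The refactor replaces the cached `TensorDimension` vectors with cached element counts: each count is `N*C*H*W`, computed once in `PrepareForRun` and reused for both the `CHECK_EQ` size checks and the `memcpy` byte counts in `Run`. A standalone sketch of that caching pattern (`Dim4` is a hypothetical stand-in for `hiai::TensorDimension`):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for hiai::TensorDimension.
struct Dim4 {
  int64_t n, c, h, w;
};

int main() {
  // PrepareForRun: compute each tensor's element count (N*C*H*W) once...
  std::vector<Dim4> npu_idims{{1, 3, 224, 224}, {1, 1, 56, 56}};
  std::vector<int64_t> npu_idatasizes;
  for (const auto& d : npu_idims) {
    npu_idatasizes.push_back(d.n * d.c * d.h * d.w);
  }
  // Run: ...then reuse it for the size checks and memcpy byte counts.
  for (size_t i = 0; i < npu_idatasizes.size(); ++i) {
    std::printf("input[%zu]: %lld floats, %lld bytes\n", i,
                static_cast<long long>(npu_idatasizes[i]),
                static_cast<long long>(npu_idatasizes[i] * sizeof(float)));
  }
  return 0;
}
```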
@@ -37,16 +37,13 @@ class GraphCompute : public KernelLite<TARGET(kNPU), PRECISION(kFloat)> {
 
   virtual ~GraphCompute() = default;
 
-  bool input_dims_changed() const;
-
  private:
   std::shared_ptr<hiai::AiModelMngerClient> model_client_;
   std::string model_name_;
   hiai::AiContext model_context_;
-  std::vector<hiai::TensorDimension> npu_idims_;
-  std::vector<hiai::TensorDimension> npu_odims_;
+  std::vector<int64_t> npu_idatasizes_;
+  std::vector<int64_t> npu_odatasizes_;
   std::vector<std::shared_ptr<hiai::AiTensor>> npu_itensors_;
   std::vector<std::shared_ptr<hiai::AiTensor>> npu_otensors_;
 };
......