diff --git a/tools/onnx/onnx2ncnn.cpp b/tools/onnx/onnx2ncnn.cpp index 6696ee7981733aa69d486921a40ba1a13e43d633..9ce5ec403286c1dd47c4da175ebb7a33abc480dc 100644 --- a/tools/onnx/onnx2ncnn.cpp +++ b/tools/onnx/onnx2ncnn.cpp @@ -152,6 +152,75 @@ static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const return onnx::TensorProto(); } +static float get_node_attr_from_input_f(const onnx::TensorProto& tp) +{ + float v = 0.f; + + // float + if (tp.data_type() == 1) + { + const float* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const float*)tp.raw_data().data(); + } + else + { + shape_data = tp.float_data().data(); + } + v = shape_data[0]; + } + // double + else if (tp.data_type() == 11) + { + const double* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const double*)tp.raw_data().data(); + } + else + { + shape_data = tp.double_data().data(); + } + v = shape_data[0]; + } + // int64 + else if (tp.data_type() == 7) + { + const int64_t* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const int64_t*)tp.raw_data().data(); + } + else + { + shape_data = tp.int64_data().data(); + } + v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN); + } + // int32 + else if (tp.data_type() == 6) + { + const int32_t* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const int32_t*)tp.raw_data().data(); + } + else + { + shape_data = tp.int32_data().data(); + } + v = shape_data[0]; + } + else + { + fprintf(stderr, "Unknown data type %d\n", tp.data_type()); + abort(); + } + + return v; +} + static std::vector get_node_attr_from_input_ai(const onnx::TensorProto& tp) { int size = 0; @@ -288,25 +357,70 @@ static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp) } } -static void fuse_matmul(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, 
std::vector& reduced_binaryop_weights) +static void fuse_weight_reshape(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); - // MatMul <= Transpose(weight) - MatMul - if (node->op_type() == "Transpose") + // weight <= Reshape(weight) + if (node->op_type() == "Reshape") { // check weight if (weights.find(node->input(0)) == weights.end()) continue; - onnx::TensorProto& B = weights[node->input(0)]; - if (B.dims_size() != 2) + weights[node->output(0)] = weights[node->input(0)]; + + // set weight shape directly + std::vector shape; + if (node->input_size() == 1) + { + shape = get_node_attr_ai(*node, "shape"); + } + else if (node->input_size() == 2) + { + // opset 5 + shape = get_node_attr_from_input_ai(weights[node->input(1)]); + } + + weights[node->output(0)].clear_dims(); + for (int j = 0; j < shape.size(); j++) + { + weights[node->output(0)].add_dims(shape[j]); + } + + // reduce + node->set_op_type("noop_reducedncnn"); + + node_reference[node->input(0)] -= 1; + if (node->input_size() == 2) + { + node_reference[node->input(1)] -= 1; + } + + reduced_node_count += 1; + i += 1; + } + } +} + +static void fuse_weight_transpose(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // weight <= Transpose(weight) + if (node->op_type() == "Transpose") + { + // check weight + if (weights.find(node->input(0)) == weights.end()) continue; - if (node_reference.find(node->output(0)) == node_reference.end() || node_reference[node->output(0)] != 1) + if (weights[node->input(0)].dims_size() != 2) continue; // perm = (1, 0) @@ -316,24 +430,12 @@ static 
void fuse_matmul(onnx::GraphProto* mutable_graph, std::map= node_count) - continue; - - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - - if (node2->op_type() != "MatMul") - continue; - - // reduce - node->set_op_type("noop_reducedncnn"); - - node_reference.erase(node_reference.find(node->output(0))); - blob_names.erase(node->output(0)); - - node2->set_input(1, node->input(0)); + weights[node->output(0)] = weights[node->input(0)]; // permute weight { + onnx::TensorProto& B = weights[node->output(0)]; + const int h = B.dims(0); const int w = B.dims(1); @@ -364,13 +466,18 @@ static void fuse_matmul(onnx::GraphProto* mutable_graph, std::mapset_op_type("noop_reducedncnn"); + + node_reference[node->input(0)] -= 1; + reduced_node_count += 1; i += 1; } } } -static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) +static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) @@ -381,7 +488,7 @@ static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::mapop_type() == "Reshape") { - if (node_reference.find(node->output(0)) == node_reference.end() || node_reference[node->output(0)] != 1) + if (node_reference[node->output(0)] != 1) continue; std::vector shape; @@ -423,7 +530,7 @@ static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::mapop_type() != "Transpose" || node3->op_type() != "Reshape") continue; - if (node_reference.find(node2->output(0)) == node_reference.end() || node_reference[node2->output(0)] != 1) + if (node_reference[node2->output(0)] != 1) continue; // 0 2 1 3 4 @@ -467,8 +574,17 @@ static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, 
std::mapset_op_type("noop_reducedncnn"); node2->set_op_type("noop_reducedncnn"); - node_reference.erase(node_reference.find(node->output(0))); - node_reference.erase(node_reference.find(node2->output(0))); + if (node->input_size() == 2) + { + node_reference[node->input(1)] -= 1; + } + node_reference[node->output(0)] -= 1; + node_reference[node2->output(0)] -= 1; + if (node3->input_size() == 2) + { + node_reference[node3->input(1)] -= 1; + } + blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); @@ -489,7 +605,7 @@ static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) +static void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) @@ -546,8 +662,10 @@ static void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map< node2->set_op_type("noop_reducedncnn"); node_reference[node->output(0)] -= 1; + node_reference[node2->input(1)] -= 1; + node_reference[node2->output(0)] -= 1; + node_reference[node3->input(1)] -= 1; - node_reference.erase(node_reference.find(node2->output(0))); blob_names.erase(node2->output(0)); node3->set_op_type("Split"); @@ -567,7 +685,7 @@ static void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map< } } -static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) +static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; 
i++) @@ -581,20 +699,20 @@ static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::mapop_type() == "Add") { - if (node_reference.find(node->output(0)) == node_reference.end() || node_reference[node->output(0)] != 1) + if (node_reference[node->output(0)] != 1) continue; if (i + 3 >= node_count) continue; - if (binaryop_weights.find(node->input(1)) == binaryop_weights.end()) + if (weights.find(node->input(1)) == weights.end()) continue; - const onnx::TensorProto& add_three = binaryop_weights[node->input(1)]; + const onnx::TensorProto& add_three = weights[node->input(1)]; if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue; - float constant_add_three = add_three.has_raw_data() ? ((const float*)add_three.raw_data().data())[0] : add_three.float_data().data()[0]; + float constant_add_three = get_node_attr_from_input_f(add_three); if (constant_add_three != 3.f) continue; @@ -613,7 +731,7 @@ static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::mapop_type() != "Clip" || node3->op_type() != "Mul" || (node4->op_type() != "Div" && node4->op_type() != "Mul")) continue; - if (node_reference.find(node2->output(0)) == node_reference.end() || node_reference[node2->output(0)] != 1) + if (node_reference[node2->output(0)] != 1) continue; float relu6_min; @@ -627,29 +745,27 @@ static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::mapinput(1)]; const onnx::TensorProto& max_tp = weights[node2->input(2)]; - const float* min_data = min_tp.has_raw_data() ? (const float*)min_tp.raw_data().data() : min_tp.float_data().data(); - const float* max_data = max_tp.has_raw_data() ? 
(const float*)max_tp.raw_data().data() : max_tp.float_data().data(); - relu6_min = min_data[0]; - relu6_max = max_data[0]; + relu6_min = get_node_attr_from_input_f(min_tp); + relu6_max = get_node_attr_from_input_f(max_tp); } if (relu6_min != 0.f || relu6_max != 6.f) continue; - if (node_reference.find(node3->output(0)) == node_reference.end() || node_reference[node3->output(0)] != 1) + if (node_reference[node3->output(0)] != 1) continue; if (node3->input(0) != node->input(0) || node3->input(1) != node2->output(0)) continue; - if (binaryop_weights.find(node4->input(1)) == binaryop_weights.end()) + if (weights.find(node4->input(1)) == weights.end()) continue; - const onnx::TensorProto& div_six = binaryop_weights[node4->input(1)]; + const onnx::TensorProto& div_six = weights[node4->input(1)]; if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue; - float constant_div_six = div_six.has_raw_data() ? ((const float*)div_six.raw_data().data())[0] : div_six.float_data().data()[0]; + float constant_div_six = get_node_attr_from_input_f(div_six); if (node4->op_type() == "Div" && constant_div_six != 6.f) continue; if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f) @@ -661,17 +777,21 @@ static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::mapset_op_type("noop_reducedncnn"); node_reference[node->input(0)] -= 1; + node_reference[node->input(1)] -= 1; + node_reference[node->output(0)] -= 1; + if (node2->input_size() == 3) + { + node_reference[node2->input(1)] -= 1; + node_reference[node2->input(2)] -= 1; + } + node_reference[node2->output(0)] -= 1; + node_reference[node3->output(0)] -= 1; + node_reference[node4->input(1)] -= 1; - node_reference.erase(node_reference.find(node->output(0))); - node_reference.erase(node_reference.find(node2->output(0))); - node_reference.erase(node_reference.find(node3->output(0))); blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); blob_names.erase(node3->output(0)); - 
reduced_binaryop_weights.push_back(node->input(1)); - reduced_binaryop_weights.push_back(node4->input(1)); - node4->set_op_type("HardSwish"); node4->clear_input(); node4->add_input(node->input(0)); @@ -697,7 +817,7 @@ static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::mapop_type() == "HardSigmoid") { - if (node_reference.find(node->output(0)) == node_reference.end() || node_reference[node->output(0)] != 1) + if (node_reference[node->output(0)] != 1) continue; float alpha = get_node_attr_f(*node, "alpha", 0.2f); @@ -718,8 +838,8 @@ static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::mapset_op_type("noop_reducedncnn"); node_reference[node->input(0)] -= 1; + node_reference[node->output(0)] -= 1; - node_reference.erase(node_reference.find(node->output(0))); blob_names.erase(node->output(0)); node2->set_op_type("HardSwish"); @@ -740,7 +860,7 @@ static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) +static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) @@ -754,20 +874,20 @@ static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, std::mapop_type() == "Add") { - if (node_reference.find(node->output(0)) == node_reference.end() || node_reference[node->output(0)] != 1) + if (node_reference[node->output(0)] != 1) continue; if (i + 2 >= node_count) continue; - if (binaryop_weights.find(node->input(1)) == binaryop_weights.end()) + if (weights.find(node->input(1)) == weights.end()) continue; - const onnx::TensorProto& add_three = binaryop_weights[node->input(1)]; + const onnx::TensorProto& add_three = weights[node->input(1)]; if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) 
continue; - float constant_add_three = add_three.has_raw_data() ? ((const float*)add_three.raw_data().data())[0] : add_three.float_data().data()[0]; + float constant_add_three = get_node_attr_from_input_f(add_three); if (constant_add_three != 3.f) continue; @@ -785,7 +905,7 @@ static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, std::mapop_type() != "Clip" || (node3->op_type() != "Div" && node3->op_type() != "Mul")) continue; - if (node_reference.find(node2->output(0)) == node_reference.end() || node_reference[node2->output(0)] != 1) + if (node_reference[node2->output(0)] != 1) continue; float relu6_min; @@ -799,23 +919,21 @@ static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, std::mapinput(1)]; const onnx::TensorProto& max_tp = weights[node2->input(2)]; - const float* min_data = min_tp.has_raw_data() ? (const float*)min_tp.raw_data().data() : min_tp.float_data().data(); - const float* max_data = max_tp.has_raw_data() ? (const float*)max_tp.raw_data().data() : max_tp.float_data().data(); - relu6_min = min_data[0]; - relu6_max = max_data[0]; + relu6_min = get_node_attr_from_input_f(min_tp); + relu6_max = get_node_attr_from_input_f(max_tp); } if (relu6_min != 0.f || relu6_max != 6.f) continue; - if (binaryop_weights.find(node3->input(1)) == binaryop_weights.end()) + if (weights.find(node3->input(1)) == weights.end()) continue; - const onnx::TensorProto& div_six = binaryop_weights[node3->input(1)]; + const onnx::TensorProto& div_six = weights[node3->input(1)]; if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue; - float constant_div_six = div_six.has_raw_data() ? 
((const float*)div_six.raw_data().data())[0] : div_six.float_data().data()[0]; + float constant_div_six = get_node_attr_from_input_f(div_six); if (node3->op_type() == "Div" && constant_div_six != 6.f) continue; if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f) @@ -825,14 +943,19 @@ static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, std::mapset_op_type("noop_reducedncnn"); node2->set_op_type("noop_reducedncnn"); - node_reference.erase(node_reference.find(node->output(0))); - node_reference.erase(node_reference.find(node2->output(0))); + node_reference[node->input(1)] -= 1; + node_reference[node->output(0)] -= 1; + if (node2->input_size() == 3) + { + node_reference[node2->input(1)] -= 1; + node_reference[node2->input(2)] -= 1; + } + node_reference[node2->output(0)] -= 1; + node_reference[node3->input(1)] -= 1; + blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); - reduced_binaryop_weights.push_back(node->input(1)); - reduced_binaryop_weights.push_back(node3->input(1)); - node3->set_op_type("HardSigmoid"); node3->clear_input(); node3->add_input(node->input(0)); @@ -851,7 +974,7 @@ static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) +static void fuse_swish(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) @@ -862,7 +985,7 @@ static void fuse_swish(onnx::GraphProto* mutable_graph, std::mapop_type() == "Sigmoid") { - if (node_reference.find(node->output(0)) == node_reference.end() || node_reference[node->output(0)] != 1) + if (node_reference[node->output(0)] != 1) continue; if (i + 1 >= node_count) @@ -880,8 +1003,8 @@ static void fuse_swish(onnx::GraphProto* mutable_graph, 
std::mapset_op_type("noop_reducedncnn"); node_reference[node->input(0)] -= 1; + node_reference[node->output(0)] -= 1; - node_reference.erase(node_reference.find(node->output(0))); blob_names.erase(node->output(0)); node2->set_op_type("Swish"); @@ -894,7 +1017,7 @@ static void fuse_swish(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) +static void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) @@ -904,7 +1027,7 @@ static void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, // BatchNormalization <= Unsqueeze - BatchNormalization - Squeeze if (node->op_type() == "Unsqueeze") { - if (node_reference.find(node->output(0)) == node_reference.end() || node_reference[node->output(0)] != 1) + if (node_reference[node->output(0)] != 1) continue; if (i + 2 >= node_count) @@ -916,7 +1039,7 @@ static void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, if (node2->op_type() != "BatchNormalization" || node3->op_type() != "Squeeze") continue; - if (node_reference.find(node2->output(0)) == node_reference.end() || node_reference[node2->output(0)] != 1) + if (node_reference[node2->output(0)] != 1) continue; if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)) @@ -926,8 +1049,9 @@ static void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, node->set_op_type("noop_reducedncnn"); node3->set_op_type("noop_reducedncnn"); - node_reference.erase(node_reference.find(node->output(0))); - node_reference.erase(node_reference.find(node2->output(0))); + node_reference[node->output(0)] -= 1; + node_reference[node2->output(0)] -= 1; + blob_names.erase(node->output(0)); 
blob_names.erase(node2->output(0)); @@ -940,7 +1064,7 @@ static void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, } } -static void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) +static void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) @@ -958,7 +1082,7 @@ static void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::mapoutput(0)) == node_reference.end() || node_reference[node->output(0)] != 1) + if (node_reference[node->output(0)] != 1) continue; // axes = (1, 2) @@ -982,7 +1106,8 @@ static void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::mapset_op_type("noop_reducedncnn"); - node_reference.erase(node_reference.find(node->output(0))); + node_reference[node->output(0)] -= 1; + blob_names.erase(node->output(0)); node2->set_input(1, node->input(0)); @@ -993,7 +1118,7 @@ static void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) +static void fuse_normalize(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) @@ -1004,7 +1129,7 @@ static void fuse_normalize(onnx::GraphProto* mutable_graph, std::mapop_type() == "ReduceL2") { - if (node_reference.find(node->output(0)) == node_reference.end() || node_reference[node->output(0)] != 1) + if (node_reference[node->output(0)] != 1) continue; // axes = (1) @@ -1036,10 +1161,10 @@ static void 
fuse_normalize(onnx::GraphProto* mutable_graph, std::mapop_type() != "Clip" || node3->op_type() != "Expand" || node4->op_type() != "Div") continue; - if (node_reference.find(node2->output(0)) == node_reference.end() || node_reference[node2->output(0)] != 1) + if (node_reference[node2->output(0)] != 1) continue; - if (node_reference.find(node3->output(0)) == node_reference.end() || node_reference[node3->output(0)] != 1) + if (node_reference[node3->output(0)] != 1) continue; if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) @@ -1061,9 +1186,8 @@ static void fuse_normalize(onnx::GraphProto* mutable_graph, std::mapinput(1)]; - const float* min_data = min_tp.has_raw_data() ? (const float*)min_tp.raw_data().data() : min_tp.float_data().data(); - clip_min = min_data[0]; + clip_min = get_node_attr_from_input_f(min_tp); } // reduce @@ -1076,14 +1200,14 @@ static void fuse_normalize(onnx::GraphProto* mutable_graph, std::mapset_op_type("noop_reducedncnn"); node_reference[node->input(0)] -= has_shape_node ? 
2 : 1; - - node_reference.erase(node_reference.find(node->output(0))); - node_reference.erase(node_reference.find(node2->output(0))); + node_reference[node->output(0)] -= 1; + node_reference[node2->output(0)] -= 1; if (has_shape_node) { - node_reference.erase(node_reference.find(node_shape->output(0))); + node_reference[node_shape->output(0)] -= 1; } - node_reference.erase(node_reference.find(node3->output(0))); + node_reference[node3->output(0)] -= 1; + blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); if (has_shape_node) @@ -1106,7 +1230,7 @@ static void fuse_normalize(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) +static void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) @@ -1116,7 +1240,7 @@ static void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::mapop_type() == "Reshape") { - if (node_reference.find(node->output(0)) == node_reference.end() || node_reference[node->output(0)] != 1) + if (node_reference[node->output(0)] != 1) continue; std::vector shape; @@ -1153,13 +1277,13 @@ static void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::mapop_type() != "InstanceNormalization" || node3->op_type() != "Reshape" || node4->op_type() != "Mul" || node5->op_type() != "Add") continue; - if (node_reference.find(node2->output(0)) == node_reference.end() || node_reference[node2->output(0)] != 1) + if (node_reference[node2->output(0)] != 1) continue; - if (node_reference.find(node3->output(0)) == node_reference.end() || node_reference[node3->output(0)] != 1) + if (node_reference[node3->output(0)] != 1) continue; - if (node_reference.find(node4->output(0)) == node_reference.end() || node_reference[node4->output(0)] != 1) + 
if (node_reference[node4->output(0)] != 1) continue; if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) @@ -1213,8 +1337,8 @@ static void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::map affine_S = get_node_attr_from_input_af(binaryop_weights[node4->input(1)]); - std::vector affine_B = get_node_attr_from_input_af(binaryop_weights[node5->input(1)]); + std::vector affine_S = get_node_attr_from_input_af(weights[node4->input(1)]); + std::vector affine_B = get_node_attr_from_input_af(weights[node5->input(1)]); if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && affine_B[0] == 0.f) { affine = 0; @@ -1240,36 +1364,31 @@ static void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::mapset_op_type("noop_reducedncnn"); node4->set_op_type("noop_reducedncnn"); - node_reference.erase(node_reference.find(node->output(0))); - node_reference.erase(node_reference.find(node2->output(0))); - node_reference.erase(node_reference.find(node3->output(0))); - node_reference.erase(node_reference.find(node4->output(0))); - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); - blob_names.erase(node3->output(0)); - blob_names.erase(node4->output(0)); + if (node->input_size() == 2) + { + node_reference[node->input(1)] -= 1; + } + node_reference[node->output(0)] -= 1; + node_reference[node2->input(1)] -= 1; + node_reference[node2->input(2)] -= 1; + node_reference[node2->output(0)] -= 1; + if (node3->input_size() == 2) + { + node_reference[node3->input(1)] -= 1; + } + node_reference[node3->output(0)] -= 1; + node_reference[node4->output(0)] -= 1; std::string affine_scale = node4->input(1); std::string affine_bias = node5->input(1); - if (affine) - { - weights[affine_scale] = binaryop_weights[affine_scale]; - weights[affine_bias] = binaryop_weights[affine_bias]; - - binaryop_weights.erase(binaryop_weights.find(affine_scale)); - binaryop_weights.erase(binaryop_weights.find(affine_bias)); + node_reference[affine_scale] 
-= 1; + node_reference[affine_bias] -= 1; - node_reference.erase(node_reference.find(affine_scale)); - node_reference.erase(node_reference.find(affine_bias)); - blob_names.erase(affine_scale); - blob_names.erase(affine_bias); - } - else - { - reduced_binaryop_weights.push_back(affine_scale); - reduced_binaryop_weights.push_back(affine_bias); - } + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); + blob_names.erase(node3->output(0)); + blob_names.erase(node4->output(0)); node5->set_op_type("GroupNorm"); node5->clear_input(); @@ -1302,7 +1421,7 @@ static void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) +static void fuse_flatten(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) @@ -1312,7 +1431,7 @@ static void fuse_flatten(onnx::GraphProto* mutable_graph, std::mapop_type() == "Shape") { - if (node_reference.find(node->output(0)) == node_reference.end() || node_reference[node->output(0)] != 1) + if (node_reference[node->output(0)] != 1) continue; if (i + 6 >= node_count) @@ -1329,19 +1448,19 @@ static void fuse_flatten(onnx::GraphProto* mutable_graph, std::mapop_type() != "Concat" || node7->op_type() != "Reshape") continue; - if (node_reference.find(node2->output(0)) == node_reference.end() || node_reference[node2->output(0)] != 1) + if (node_reference[node2->output(0)] != 1) continue; - // if (node_reference.find(node3->output(0)) == node_reference.end() || node_reference[node3->output(0)] != 1) + // if (node_reference[node3->output(0)] != 1) // continue; - if (node_reference.find(node4->output(0)) == node_reference.end() || node_reference[node4->output(0)] != 1) + if (node_reference[node4->output(0)] != 1) continue; - if 
(node_reference.find(node5->output(0)) == node_reference.end() || node_reference[node5->output(0)] != 1) + if (node_reference[node5->output(0)] != 1) continue; - if (node_reference.find(node6->output(0)) == node_reference.end() || node_reference[node6->output(0)] != 1) + if (node_reference[node6->output(0)] != 1) continue; if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) || node5->input(0) != node3->output(0) @@ -1398,13 +1517,15 @@ static void fuse_flatten(onnx::GraphProto* mutable_graph, std::mapset_op_type("noop_reducedncnn"); node_reference[node->input(0)] -= 1; + node_reference[node->output(0)] -= 1; + node_reference[node2->input(1)] -= 1; + node_reference[node2->output(0)] -= 1; + // node_reference[node3->output(0)] -= 1; + node_reference[node4->output(0)] -= 1; + node_reference[node5->input(0)] -= 1; + node_reference[node5->output(0)] -= 1; + node_reference[node6->output(0)] -= 1; - node_reference.erase(node_reference.find(node->output(0))); - node_reference.erase(node_reference.find(node2->output(0))); - // node_reference.erase(node_reference.find(node3->output(0))); - node_reference.erase(node_reference.find(node4->output(0))); - node_reference.erase(node_reference.find(node5->output(0))); - node_reference.erase(node_reference.find(node6->output(0))); blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); // blob_names.erase(node3->output(0)); @@ -1422,7 +1543,7 @@ static void fuse_flatten(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) +static void fuse_pixelshuffle(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) @@ -1433,7 +1554,7 @@ static void fuse_pixelshuffle(onnx::GraphProto* mutable_graph, 
std::mapop_type() == "Reshape") { - if (node_reference.find(node->output(0)) == node_reference.end() || node_reference[node->output(0)] != 1) + if (node_reference[node->output(0)] != 1) continue; std::vector shape; @@ -1477,7 +1598,7 @@ static void fuse_pixelshuffle(onnx::GraphProto* mutable_graph, std::mapop_type() != "Transpose" || node3->op_type() != "Reshape") continue; - if (node_reference.find(node2->output(0)) == node_reference.end() || node_reference[node2->output(0)] != 1) + if (node_reference[node2->output(0)] != 1) continue; // 0 1 4 2 5 3 @@ -1516,8 +1637,17 @@ static void fuse_pixelshuffle(onnx::GraphProto* mutable_graph, std::mapset_op_type("noop_reducedncnn"); node2->set_op_type("noop_reducedncnn"); - node_reference.erase(node_reference.find(node->output(0))); - node_reference.erase(node_reference.find(node2->output(0))); + if (node->input_size() == 2) + { + node_reference[node->input(1)] -= 1; + } + node_reference[node->output(0)] -= 1; + node_reference[node2->output(0)] -= 1; + if (node3->input_size() == 2) + { + node_reference[node3->input(1)] -= 1; + } + blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); @@ -1534,7 +1664,7 @@ static void fuse_pixelshuffle(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) +static void fuse_reorg(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) @@ -1545,7 +1675,7 @@ static void fuse_reorg(onnx::GraphProto* mutable_graph, std::mapop_type() == "Reshape") { - if (node_reference.find(node->output(0)) == node_reference.end() || node_reference[node->output(0)] != 1) + if (node_reference[node->output(0)] != 1) continue; std::vector shape; @@ -1589,7 +1719,7 @@ static void fuse_reorg(onnx::GraphProto* 
mutable_graph, std::mapop_type() != "Transpose" || node3->op_type() != "Reshape") continue; - if (node_reference.find(node2->output(0)) == node_reference.end() || node_reference[node2->output(0)] != 1) + if (node_reference[node2->output(0)] != 1) continue; // 0 1 3 5 2 4 @@ -1628,8 +1758,17 @@ static void fuse_reorg(onnx::GraphProto* mutable_graph, std::mapset_op_type("noop_reducedncnn"); node2->set_op_type("noop_reducedncnn"); - node_reference.erase(node_reference.find(node->output(0))); - node_reference.erase(node_reference.find(node2->output(0))); + if (node->input_size() == 2) + { + node_reference[node->input(1)] -= 1; + } + node_reference[node->output(0)] -= 1; + node_reference[node2->output(0)] -= 1; + if (node3->input_size() == 2) + { + node_reference[node3->input(1)] -= 1; + } + blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); @@ -1646,7 +1785,7 @@ static void fuse_reorg(onnx::GraphProto* mutable_graph, std::map& weights, std::map& binaryop_weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count, std::vector& reduced_binaryop_weights) +static void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) @@ -1656,7 +1795,7 @@ static void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::mapop_type() == "Expand") { - if (node_reference.find(node->output(0)) == node_reference.end() || node_reference[node->output(0)] != 1) + if (node_reference[node->output(0)] != 1) continue; if (i + 1 >= node_count) @@ -1673,7 +1812,8 @@ static void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::mapset_op_type("noop_reducedncnn"); - node_reference.erase(node_reference.find(node->output(0))); + node_reference[node->output(0)] -= 1; + blob_names.erase(node->output(0)); node2->set_input(1, node->input(0)); @@ -1717,14 +1857,11 @@ int 
main(int argc, char** argv) // weight node and weight reshape node std::map weights; - // weight node before BinaryOp - std::map binaryop_weights; - for (int j = 0; j < graph.initializer_size(); j++) { const onnx::TensorProto& initializer = graph.initializer(j); - // fprintf(stderr, "weight = %s\n", initializer.name().c_str()); + // fprintf(stderr, "weight = %s %d\n", initializer.name().c_str(), initializer.data_type()); weights[initializer.name()] = initializer; } @@ -1748,145 +1885,12 @@ int main(int argc, char** argv) { onnx::TensorProto tensor = get_node_attr_tensor(node, "value"); weights[node.output(0)] = tensor; - continue; - } - else if (op == "Reshape") - { - if (node.input_size() == 1) - { - const std::string& input_name = node.input(0); - - // check weight - if (weights.find(input_name) != weights.end()) - { - weights[node.output(0)] = weights[input_name]; - continue; - } - } - else if (node.input_size() == 2) - { - // opset 5 - const std::string& input_name = node.input(0); - - // check weight - if (weights.find(input_name) != weights.end()) - { - weights[node.output(0)] = weights[input_name]; - - // set weight shape directly - const onnx::TensorProto& shape_tp = weights[node.input(1)]; - const int64_t* shape_data = shape_tp.int64_data().data(); - - weights[node.output(0)].clear_dims(); - for (int j = 0; j < shape_tp.int64_data_size(); j++) - { - weights[node.output(0)].add_dims(shape_data[j]); - } - - continue; - } - } - } - else if (op == "Gemm") - { - float alpha = get_node_attr_f(node, "alpha", 1.f); - float beta = get_node_attr_f(node, "beta", 1.f); - int transA = get_node_attr_i(node, "transA", 0); - int transB = get_node_attr_i(node, "transB", 0); - - if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) - { - // InnerProduct-like A * B + C - } - else - { - // gemm A - const std::string& A = node.input(0); - std::map::iterator itA = weights.find(A); - if (itA != weights.end()) - { - binaryop_weights[A] = itA->second; - 
weights.erase(itA); - } - - // gemm B - const std::string& B = node.input(1); - std::map::iterator itB = weights.find(B); - if (itB != weights.end()) - { - binaryop_weights[B] = itB->second; - weights.erase(itB); - } - - if (node.input_size() == 3) - { - // gemm C - const std::string& C = node.input(2); - std::map::iterator itC = weights.find(C); - if (itC != weights.end()) - { - binaryop_weights[C] = itC->second; - weights.erase(itC); - } - } - } - } - else if (op == "MatMul") - { - // gemm A - const std::string& A = node.input(0); - std::map::iterator itA = weights.find(A); - if (itA != weights.end()) - { - binaryop_weights[A] = itA->second; - weights.erase(itA); - } - - // gemm B can be weight when rank2 - const std::string& B = node.input(1); - std::map::iterator itB = weights.find(B); - if (itB != weights.end() && itB->second.dims_size() != 2) - { - binaryop_weights[B] = itB->second; - weights.erase(itB); - } - } - else - { - bool isBinaryOp = false; - if (op == "Add" || op == "Sub" || op == "Mul" || op == "Div" || op == "Max" || op == "Min" || op == "Pow") - { - isBinaryOp = true; - } - - if (isBinaryOp) - { - // check weights - for (int j = 0; j < node.input_size(); j++) - { - const std::string& input_name = node.input(j); - - std::map::iterator it = weights.find(input_name); - if (it != weights.end()) - { - // binary op with weight, insert MemoryData layer and const blob - binaryop_weights[input_name] = it->second; - weights.erase(it); - } - } - } } for (int j = 0; j < (int)node.input_size(); j++) { const std::string& input_name = node.input(j); - // check weight - if (weights.find(input_name) != weights.end()) - { - continue; - } - blob_names.insert(input_name); if (node_reference.find(input_name) == node_reference.end()) @@ -1909,6 +1913,7 @@ int main(int argc, char** argv) { const std::string& output_name = node.output(0); blob_names.insert(output_name); + node_reference[output_name] = 0; continue; } @@ -1916,6 +1921,7 @@ int main(int argc, char** argv) 
{ const std::string& output_name = node.output(0); blob_names.insert(output_name); + node_reference[output_name] = 0; continue; } @@ -1924,6 +1930,8 @@ int main(int argc, char** argv) const std::string& output_name = node.output(j); blob_names.insert(output_name); + + node_reference[output_name] = 0; } } @@ -1937,48 +1945,193 @@ int main(int argc, char** argv) if (weights.find(input_name) != weights.end()) continue; - // check weight before BinaryOp - if (binaryop_weights.find(input_name) != binaryop_weights.end()) - continue; - blob_names.insert(input_name); input_node_count++; } + // for (auto a: node_reference) + // { + // fprintf(stderr, "a = %s %d\n", a.first.c_str(), a.second); + // } + // op chain fusion int reduced_node_count = 0; - std::vector reduced_binaryop_weights; - fuse_matmul(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); - fuse_shufflechannel(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); - fuse_shufflechannel_split(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); - fuse_hardsigmoid(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); - fuse_hardswish(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); - fuse_swish(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); - fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); - fuse_unsqueeze_prelu(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); - fuse_normalize(mutable_graph, weights, binaryop_weights, node_reference, blob_names, 
reduced_node_count, reduced_binaryop_weights); - fuse_groupnorm(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); - fuse_flatten(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); - fuse_pixelshuffle(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); - fuse_reorg(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); - fuse_expand_broadcast(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); - - // remove node_reference entry with reference equals to one - int splitncnn_blob_count = 0; - std::map::iterator it = node_reference.begin(); - while (it != node_reference.end()) + fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_shufflechannel_split(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_hardsigmoid(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_hardswish(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_swish(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_unsqueeze_prelu(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_normalize(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_groupnorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_flatten(mutable_graph, weights, 
node_reference, blob_names, reduced_node_count); + fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + + // reduce common const weight node_reference + for (int i = 0; i < node_count; i++) { - if (it->second == 1) + const onnx::NodeProto& node = graph.node(i); + + const std::string& op = node.op_type(); + + if (op == "Add" || op == "Sub" || op == "Mul" || op == "Div" || op == "Max" || op == "Min" || op == "Pow") { - node_reference.erase(it++); + // binaryop with scalar + if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0) + { + node_reference[node.input(1)] -= 1; + } } - else + else if (op == "BatchNormalization") { - splitncnn_blob_count += it->second; - // fprintf(stderr, "%s %d\n", it->first.c_str(), it->second); - ++it; + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + node_reference[node.input(3)] -= 1; + node_reference[node.input(4)] -= 1; } + else if (op == "Clip") + { + if (node.input_size() == 3) + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + } + } + else if (op == "Conv") + { + node_reference[node.input(1)] -= 1; + if (node.input_size() == 3) + { + node_reference[node.input(2)] -= 1; + } + } + else if (op == "ConvTranspose") + { + node_reference[node.input(1)] -= 1; + if (node.input_size() == 3) + { + node_reference[node.input(2)] -= 1; + } + } + else if (op == "Gemm") + { + float alpha = get_node_attr_f(node, "alpha", 1.f); + float beta = get_node_attr_f(node, "beta", 1.f); + int transA = get_node_attr_i(node, "transA", 0); + int transB = get_node_attr_i(node, "transB", 0); + + if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) + { + // InnerProduct-like A * B + C + node_reference[node.input(1)] -= 1; + 
node_reference[node.input(2)] -= 1; + } + } + else if (op == "GroupNorm") + { + int affine = get_node_attr_i(node, "affine", 1); + if (affine) + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + } + } + else if (op == "InstanceNormalization") + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + } + else if (op == "LSTM") + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + node_reference[node.input(3)] -= 1; + } + else if (op == "MatMul") + { + if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2) + { + // InnerProduct + node_reference[node.input(1)] -= 1; + } + } + else if (op == "Pad") + { + if (node.input_size() >= 2) + { + node_reference[node.input(1)] -= 1; + } + } + else if (op == "PRelu") + { + node_reference[node.input(1)] -= 1; + } + else if (op == "Reshape") + { + if (node.input_size() >= 2) + { + node_reference[node.input(1)] -= 1; + } + } + else if (op == "Resize") + { + if (node.input_size() == 2) + { + // opset 10 + node_reference[node.input(1)] -= 1; + } + else + { + // opset 11+ + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + if (node.input_size() >= 4) + { + node_reference[node.input(3)] -= 1; + } + } + } + else if (op == "Slice") + { + if (node.input_size() >= 2) + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + if (node.input_size() >= 4) + node_reference[node.input(3)] -= 1; + if (node.input_size() >= 5) + node_reference[node.input(4)] -= 1; + } + } + else if (op == "Upsample") + { + if (node.input_size() >= 2) + { + node_reference[node.input(1)] -= 1; + } + } + } + + // for (auto a: node_reference) + // { + // fprintf(stderr, "b = %s %d\n", a.first.c_str(), a.second); + // } + + // count all weight node with zero reference + int zero_reference_weight_node_count = 0; + for (std::map::iterator it = weights.begin(); it != weights.end(); it++) + { + const 
std::string& input_name = it->first; + + int refcount = node_reference[input_name]; + if (refcount == 0) + zero_reference_weight_node_count++; } // we always treat constant node as weight or binaryop_weights @@ -1996,7 +2149,23 @@ int main(int argc, char** argv) } } - fprintf(pp, "%zu %zu\n", node_count - reduced_node_count + input_node_count + node_reference.size() + binaryop_weights.size() - reduced_binaryop_weights.size() - constant_node_count_moved_to_weight, blob_names.size() - reduced_binaryop_weights.size() + splitncnn_blob_count); + // remove node_reference entry with reference equals to one + int split_layer_count = 0; + int splitncnn_blob_count = 0; + // split node reference + std::map split_node_reference; + for (std::map::iterator it = node_reference.begin(); it != node_reference.end(); it++) + { + if (it->second > 1) + { + split_layer_count++; + splitncnn_blob_count += it->second; + + split_node_reference[it->first] = it->second; + } + } + + fprintf(pp, "%zu %zu\n", node_count - constant_node_count_moved_to_weight + weights.size() - zero_reference_weight_node_count - reduced_node_count + input_node_count + split_layer_count, blob_names.size() - zero_reference_weight_node_count + splitncnn_blob_count); int internal_split = 0; @@ -2009,18 +2178,8 @@ int main(int argc, char** argv) if (weights.find(input_name) != weights.end()) continue; - // check weight before BinaryOp - if (binaryop_weights.find(input_name) != binaryop_weights.end()) - continue; - fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str()); - // split the input - if (node_reference.find(input_name) == node_reference.end()) - { - continue; - } - int refcount = node_reference[input_name]; if (refcount <= 1) { @@ -2040,16 +2199,19 @@ int main(int argc, char** argv) } // place MemoryData next - for (std::map::iterator weight_it = binaryop_weights.begin(); weight_it != binaryop_weights.end(); weight_it++) + for (std::map::iterator weight_it = weights.begin(); 
weight_it != weights.end(); weight_it++) { const std::string& input_name = weight_it->first; - if (std::find(reduced_binaryop_weights.begin(), reduced_binaryop_weights.end(), input_name) != reduced_binaryop_weights.end()) + int refcount = node_reference[input_name]; + if (refcount == 0) + { continue; + } fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", input_name.c_str(), input_name.c_str()); - const onnx::TensorProto& M = binaryop_weights[input_name]; + const onnx::TensorProto& M = weights[input_name]; if (M.dims_size() == 0) { @@ -2079,13 +2241,6 @@ int main(int argc, char** argv) fwrite_tensor_proto_data(M, bp); - // split the input - if (node_reference.find(input_name) == node_reference.end()) - { - continue; - } - - int refcount = node_reference[input_name]; if (refcount <= 1) { continue; @@ -2133,7 +2288,7 @@ int main(int argc, char** argv) const std::string& input_name = node.input(j); // check weight - if (weights.find(input_name) != weights.end()) + if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0) { input_size--; } @@ -2378,16 +2533,6 @@ int main(int argc, char** argv) } else if (op == "Reshape") { - if (node.input_size() == 1 || node.input_size() == 2) - { - const std::string& input_name = node.input(0); - - // skip weight reshape - if (weights.find(input_name) != weights.end()) - { - continue; - } - } fprintf(pp, "%-16s", "Reshape"); } else if (op == "ShuffleChannel") @@ -2472,15 +2617,15 @@ int main(int argc, char** argv) std::string input_name = node.input(j); // check weight - if (weights.find(input_name) != weights.end()) + if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0) { continue; } - if (node_reference.find(input_name) != node_reference.end()) + if (split_node_reference.find(input_name) != split_node_reference.end()) { - int refidx = node_reference[input_name] - 1; - node_reference[input_name] = refidx; + int refidx = split_node_reference[input_name] - 1; + 
split_node_reference[input_name] = refidx; char splitsuffix[256]; sprintf(splitsuffix, "_splitncnn_%d", refidx); @@ -2511,6 +2656,13 @@ int main(int argc, char** argv) { int op_type = 0; fprintf(pp, " 0=%d", op_type); + + if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0) + { + float b = get_node_attr_from_input_f(weights[node.input(1)]); + fprintf(pp, " 1=1"); + fprintf(pp, " 2=%e", b); + } } else if (op == "Asin") { @@ -2639,11 +2791,9 @@ int main(int argc, char** argv) { const onnx::TensorProto& min_tp = weights[node.input(1)]; const onnx::TensorProto& max_tp = weights[node.input(2)]; - const float* min_data = min_tp.has_raw_data() ? (const float*)min_tp.raw_data().data() : min_tp.float_data().data(); - const float* max_data = max_tp.has_raw_data() ? (const float*)max_tp.raw_data().data() : max_tp.float_data().data(); - min = min_data[0]; - max = max_data[0]; + min = get_node_attr_from_input_f(min_tp); + max = get_node_attr_from_input_f(max_tp); } fprintf(pp, " 0=%e", min); @@ -2923,6 +3073,13 @@ int main(int argc, char** argv) { int op_type = 3; fprintf(pp, " 0=%d", op_type); + + if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0) + { + float b = get_node_attr_from_input_f(weights[node.input(1)]); + fprintf(pp, " 1=1"); + fprintf(pp, " 2=%e", b); + } } else if (op == "Dropout") { @@ -3315,16 +3472,37 @@ int main(int argc, char** argv) { int op_type = 4; fprintf(pp, " 0=%d", op_type); + + if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0) + { + float b = get_node_attr_from_input_f(weights[node.input(1)]); + fprintf(pp, " 1=1"); + fprintf(pp, " 2=%e", b); + } } else if (op == "Min") { int op_type = 5; fprintf(pp, " 0=%d", op_type); + + if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0) + { + float b = get_node_attr_from_input_f(weights[node.input(1)]); + fprintf(pp, " 1=1"); + fprintf(pp, " 2=%e", b);
+ } } else if (op == "Mul") { int op_type = 2; fprintf(pp, " 0=%d", op_type); + + if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0) + { + float b = get_node_attr_from_input_f(weights[node.input(1)]); + fprintf(pp, " 1=1"); + fprintf(pp, " 2=%e", b); + } } else if (op == "Neg") { @@ -3418,6 +3596,13 @@ int main(int argc, char** argv) { int op_type = 6; fprintf(pp, " 0=%d", op_type); + + if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0) + { + float b = get_node_attr_from_input_f(weights[node.input(1)]); + fprintf(pp, " 1=1"); + fprintf(pp, " 2=%e", b); + } } else if (op == "PixelShuffle") { @@ -3769,6 +3954,13 @@ int main(int argc, char** argv) { int op_type = 1; fprintf(pp, " 0=%d", op_type); + + if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0) + { + float b = get_node_attr_from_input_f(weights[node.input(1)]); + fprintf(pp, " 1=1"); + fprintf(pp, " 2=%e", b); + } } else if (op == "Sum") {