提交 9699e2b4 编写于 作者: A Alexander Alekhin

dnn(onnx): handle non-default ONNX domains

- re-enable quantized models tests
上级 3048188b
......@@ -15,6 +15,9 @@
#define CV_LOG_STRIP_LEVEL CV_LOG_LEVEL_VERBOSE + 1
#include <opencv2/core/utils/logger.hpp>
#include <opencv2/core/utils/configuration.private.hpp>
#ifdef HAVE_PROTOBUF
#include <iostream>
......@@ -23,6 +26,10 @@
#include <limits>
#include <algorithm>
#if defined _MSC_VER && _MSC_VER < 1910/*MSVS 2017*/
#pragma warning(push)
#pragma warning(disable: 4503) // decorated name length exceeded, name was truncated
#endif
#if defined(__GNUC__) && __GNUC__ >= 5
#pragma GCC diagnostic push
......@@ -41,8 +48,6 @@ CV__DNN_INLINE_NS_BEGIN
extern bool DNN_DIAGNOSTICS_RUN;
class ONNXLayerHandler;
class ONNXImporter
{
opencv_onnx::ModelProto model_proto;
......@@ -75,7 +80,7 @@ public:
void populateNet();
protected:
std::unique_ptr<ONNXLayerHandler> layerHandler;
std::unique_ptr<detail::LayerHandler> missingLayerHandler;
Net& dstNet;
opencv_onnx::GraphProto graph_proto;
......@@ -93,13 +98,16 @@ protected:
void handleNode(const opencv_onnx::NodeProto& node_proto);
private:
friend class ONNXLayerHandler;
typedef void (ONNXImporter::*ONNXImporterNodeParser)(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
typedef std::map<std::string, ONNXImporterNodeParser> DispatchMap;
typedef std::map<std::string, DispatchMap> DomainDispatchMap;
const DispatchMap dispatch;
static const DispatchMap buildDispatchMap();
DomainDispatchMap domain_dispatch_map;
void buildDispatchMap_ONNX_AI(int opset_version);
void buildDispatchMap_COM_MICROSOFT(int opset_version);
// Domain: 'ai.onnx' (default)
// URL: https://github.com/onnx/onnx/blob/master/docs/Operators.md
void parseArg (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseMaxUnpool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseMaxPool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
......@@ -148,6 +156,9 @@ private:
void parseSoftMax (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseDetectionOutput (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseCumSum (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
// Domain: com.microsoft
// URL: https://github.com/microsoft/onnxruntime/blob/master/docs/ContribOperators.md
void parseQuantDequant (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseQConv (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseQMatMul (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
......@@ -157,43 +168,20 @@ private:
void parseQAvgPool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseQConcat (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
// '???' domain or '???' layer type
void parseCustomLayer (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
int onnx_opset; // OperatorSetIdProto for 'onnx' domain
std::map<std::string, int> onnx_opset_map; // map from OperatorSetIdProto
void parseOperatorSet();
};
class ONNXLayerHandler : public detail::LayerHandler
{
public:
explicit ONNXLayerHandler(ONNXImporter* importer_);
void fillRegistry(const opencv_onnx::GraphProto& net);
protected:
ONNXImporter* importer;
const std::string str_domain_ai_onnx = "ai.onnx";
};
ONNXLayerHandler::ONNXLayerHandler(ONNXImporter* importer_) : importer(importer_){}
void ONNXLayerHandler::fillRegistry(const opencv_onnx::GraphProto &net)
{
int layersSize = net.node_size();
for (int li = 0; li < layersSize; li++) {
const opencv_onnx::NodeProto &node_proto = net.node(li);
const std::string& name = node_proto.output(0);
const std::string& type = node_proto.op_type();
if (importer->dispatch.find(type) == importer->dispatch.end())
{
addMissing(name, type);
}
}
printMissing();
}
ONNXImporter::ONNXImporter(Net& net, const char *onnxFile)
: layerHandler(DNN_DIAGNOSTICS_RUN ? new ONNXLayerHandler(this) : nullptr)
, dstNet(net), dispatch(buildDispatchMap())
: missingLayerHandler(DNN_DIAGNOSTICS_RUN ? new detail::LayerHandler() : nullptr)
, dstNet(net)
, onnx_opset(0)
{
hasDynamicShapes = false;
......@@ -215,8 +203,8 @@ ONNXImporter::ONNXImporter(Net& net, const char *onnxFile)
}
ONNXImporter::ONNXImporter(Net& net, const char* buffer, size_t sizeBuffer)
: layerHandler(DNN_DIAGNOSTICS_RUN ? new ONNXLayerHandler(this) : nullptr)
, dstNet(net), dispatch(buildDispatchMap())
: missingLayerHandler(DNN_DIAGNOSTICS_RUN ? new detail::LayerHandler() : nullptr)
, dstNet(net)
, onnx_opset(0)
{
hasDynamicShapes = false;
......@@ -638,20 +626,37 @@ void ONNXImporter::parseOperatorSet()
const ::opencv_onnx::OperatorSetIdProto& opset_entry = model_proto.opset_import(i);
const std::string& domain = opset_entry.has_domain() ? opset_entry.domain() : std::string();
int version = opset_entry.has_version() ? opset_entry.version() : -1;
if (domain.empty() || domain == "ai.onnx")
if (domain.empty() || domain == str_domain_ai_onnx)
{
// ONNX opset covered by specification: https://github.com/onnx/onnx/blob/master/docs/Operators.md
onnx_opset = std::max(onnx_opset, version);
onnx_opset_map[str_domain_ai_onnx] = onnx_opset;
}
else
{
// OpenCV don't know other opsets
// will fail later on unsupported node processing
CV_LOG_WARNING(NULL, "DNN/ONNX: unsupported opset[" << i << "]: domain='" << domain << "' version=" << version);
CV_LOG_DEBUG(NULL, "DNN/ONNX: using non-standard ONNX opset[" << i << "]: domain='" << domain << "' version=" << version);
onnx_opset_map[domain] = onnx_opset;
}
}
CV_LOG_INFO(NULL, "DNN/ONNX: ONNX opset version = " << onnx_opset);
buildDispatchMap_ONNX_AI(onnx_opset);
for (const auto& pair : onnx_opset_map)
{
if (pair.first == str_domain_ai_onnx)
{
continue; // done above
}
else if (pair.first == "com.microsoft")
{
buildDispatchMap_COM_MICROSOFT(pair.second);
}
else
{
CV_LOG_INFO(NULL, "DNN/ONNX: unknown domain='" << pair.first << "' version=" << pair.second << ". No dispatch map, you may need to register 'custom' layers.");
}
}
}
void ONNXImporter::handleQuantizedNode(LayerParams& layerParams,
......@@ -790,7 +795,6 @@ void ONNXImporter::populateNet()
if (DNN_DIAGNOSTICS_RUN) {
CV_LOG_INFO(NULL, "DNN/ONNX: start diagnostic run!");
layerHandler->fillRegistry(graph_proto);
}
for(int li = 0; li < layersSize; li++)
......@@ -805,22 +809,52 @@ void ONNXImporter::populateNet()
void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.output_size() >= 1);
std::string name = node_proto.output(0);
const std::string& name = node_proto.output(0);
const std::string& layer_type = node_proto.op_type();
const std::string& layer_type_domain = node_proto.has_domain() ? node_proto.domain() : std::string();
if (!layer_type_domain.empty() && layer_type_domain != "ai.onnx")
const std::string& layer_type_domain = [&]()
{
CV_LOG_WARNING(NULL, "DNN/ONNX: can't handle node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
<< cv::format("[%s@%s]:(%s)", layer_type.c_str(), layer_type_domain.c_str(), name.c_str())
);
if (DNN_DIAGNOSTICS_RUN)
return; // ignore error
CV_Error(Error::StsNotImplemented, cv::format("ONNX: unsupported domain: %s", layer_type_domain.c_str()));
}
if (!node_proto.has_domain())
return str_domain_ai_onnx;
const std::string& domain = node_proto.domain();
if (domain.empty())
return str_domain_ai_onnx;
return domain;
}();
const auto& dispatch = [&]()
{
if (layer_type_domain != str_domain_ai_onnx)
{
if (onnx_opset_map.find(layer_type_domain) == onnx_opset_map.end())
{
CV_LOG_WARNING(NULL, "DNN/ONNX: processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
<< " from undeclared domain='" << layer_type_domain << "'"
);
}
else
{
CV_LOG_DEBUG(NULL, "DNN/ONNX: processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
<< " from domain='" << layer_type_domain << "'"
);
}
auto it = domain_dispatch_map.find(layer_type_domain);
if (it == domain_dispatch_map.end())
{
CV_LOG_WARNING(NULL, "DNN/ONNX: missing dispatch map for domain='" << layer_type_domain << "'");
return DispatchMap();
}
return it->second;
}
else
{
CV_LOG_DEBUG(NULL, "DNN/ONNX: processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
);
return domain_dispatch_map[str_domain_ai_onnx];
}
}();
CV_LOG_DEBUG(NULL, "DNN/ONNX: processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
);
LayerParams layerParams;
try
{
......@@ -848,7 +882,9 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
if (DNN_DIAGNOSTICS_RUN)
{
CV_LOG_ERROR(NULL, "DNN/ONNX: Potential problem during processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str()) << "\n" << e.msg
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
<< " from domain='" << layer_type_domain << "'"
<< "\n" << e.msg
);
cv::AutoLock lock(getLayerFactoryMutex());
auto registeredLayers = getLayerFactoryImpl();
......@@ -869,6 +905,7 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
{
CV_LOG_ERROR(NULL, "DNN/ONNX: ERROR during processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
<< " from domain='" << layer_type_domain << "'"
);
}
for (int i = 0; i < node_proto.input_size(); i++)
......@@ -888,7 +925,7 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
}
}
else
CV_Error(Error::StsError, cv::format("Node [%s]:(%s) parse error: %s", layer_type.c_str(), name.c_str(), e.what()));
CV_Error(Error::StsError, cv::format("Node [%s@%s]:(%s) parse error: %s", layer_type.c_str(), layer_type_domain.c_str(), name.c_str(), e.what()));
}
}
......@@ -2836,6 +2873,28 @@ void ONNXImporter::parseCumSum(LayerParams& layerParams, const opencv_onnx::Node
void ONNXImporter::parseCustomLayer(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
const std::string& name = layerParams.name;
std::string& layer_type = layerParams.type;
const std::string& layer_type_domain = node_proto.has_domain() ? node_proto.domain() : std::string();
if (!layer_type_domain.empty() && layer_type_domain != str_domain_ai_onnx)
{
// append ONNX domain name
static bool DNN_CUSTOM_ONNX_TYPE_INCLUDE_DOMAIN_NAME = utils::getConfigurationParameterBool("OPENCV_DNN_CUSTOM_ONNX_TYPE_INCLUDE_DOMAIN_NAME", true);
if (DNN_CUSTOM_ONNX_TYPE_INCLUDE_DOMAIN_NAME)
{
layer_type = layer_type_domain + "." + layer_type;
}
}
CV_LOG_INFO(NULL, "DNN/ONNX: unknown node type, try using custom handler for node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
);
if (missingLayerHandler)
{
missingLayerHandler->addMissing(layerParams.name, layerParams.type);
}
for (int j = 0; j < node_proto.input_size(); j++) {
if (layer_id.find(node_proto.input(j)) == layer_id.end())
layerParams.blobs.push_back(getBlob(node_proto, j));
......@@ -3233,8 +3292,11 @@ void ONNXImporter::parseQConcat(LayerParams& layerParams, const opencv_onnx::Nod
addLayer(layerParams, node_proto);
}
const ONNXImporter::DispatchMap ONNXImporter::buildDispatchMap()
// Domain: ai.onnx (default)
// URL: https://github.com/onnx/onnx/blob/master/docs/Operators.md
void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
{
CV_UNUSED(opset_version);
DispatchMap dispatch;
dispatch["ArgMax"] = dispatch["ArgMin"] = &ONNXImporter::parseArg;
......@@ -3286,18 +3348,32 @@ const ONNXImporter::DispatchMap ONNXImporter::buildDispatchMap()
dispatch["SoftMax"] = dispatch["LogSoftmax"] = &ONNXImporter::parseSoftMax;
dispatch["DetectionOutput"] = &ONNXImporter::parseDetectionOutput;
dispatch["CumSum"] = &ONNXImporter::parseCumSum;
// ai.onnx: opset 10+
dispatch["QuantizeLinear"] = dispatch["DequantizeLinear"] = &ONNXImporter::parseQuantDequant;
dispatch["QLinearConv"] = &ONNXImporter::parseQConv;
dispatch["QLinearMatMul"] = &ONNXImporter::parseQMatMul;
domain_dispatch_map[str_domain_ai_onnx] = dispatch;
}
// Domain: com.microsoft
// URL: https://github.com/microsoft/onnxruntime/blob/master/docs/ContribOperators.md
void ONNXImporter::buildDispatchMap_COM_MICROSOFT(int opset_version)
{
CV_UNUSED(opset_version);
DispatchMap dispatch;
dispatch["QLinearAdd"] = dispatch["QLinearMul"] = &ONNXImporter::parseQEltwise;
dispatch["QLinearAveragePool"] = dispatch["QLinearGlobalAveragePool"] = &ONNXImporter::parseQAvgPool;
dispatch["QLinearLeakyRelu"] = &ONNXImporter::parseQLeakyRelu;
dispatch["QLinearSigmoid"] = &ONNXImporter::parseQSigmoid;
dispatch["QLinearAveragePool"] = dispatch["QLinearGlobalAveragePool"] = &ONNXImporter::parseQAvgPool;
dispatch["QLinearConcat"] = &ONNXImporter::parseQConcat;
return dispatch;
domain_dispatch_map["com.microsoft"] = dispatch;
}
Net readNetFromONNX(const String& onnxFile)
{
return detail::readNetDiagnostic<ONNXImporter>(onnxFile.c_str());
......
......@@ -1489,16 +1489,6 @@ TEST_P(Test_ONNX_layers, DivConst)
}
// FIXIT disabled due to non-standard ONNX model domains, need to add ONNX domains support
// Example:
// DNN/ONNX: unsupported opset[1]: domain='com.microsoft.experimental' version=1
// DNN/ONNX: unsupported opset[2]: domain='ai.onnx.preview.training' version=1
// DNN/ONNX: unsupported opset[3]: domain='com.microsoft.nchwc' version=1
// DNN/ONNX: unsupported opset[4]: domain='com.microsoft.mlfeaturizers' version=1
// DNN/ONNX: unsupported opset[5]: domain='ai.onnx.ml' version=2
// DNN/ONNX: unsupported opset[6]: domain='com.microsoft' version=1
// DNN/ONNX: unsupported opset[7]: domain='ai.onnx.training' version=1
#if 0
TEST_P(Test_ONNX_layers, Quantized_Convolution)
{
testONNXModels("quantized_conv_uint8_weights", npy, 0.004, 0.02);
......@@ -1604,7 +1594,6 @@ TEST_P(Test_ONNX_layers, Quantized_Constant)
{
testONNXModels("quantized_constant", npy, 0.002, 0.008);
}
#endif
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
......@@ -1749,8 +1738,7 @@ TEST_P(Test_ONNX_nets, ResNet50v1)
testONNXModels("resnet50v1", pb, default_l1, default_lInf, true, target != DNN_TARGET_MYRIAD);
}
// FIXIT missing ONNX domains support
TEST_P(Test_ONNX_nets, DISABLED_ResNet50_Int8)
TEST_P(Test_ONNX_nets, ResNet50_Int8)
{
testONNXModels("resnet50_int8", pb, default_l1, default_lInf, true);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册