Commit 8574a757 authored by Dmitry Kurtaev

Case-sensitive dnn layer types

Parent 7b82ad29
@@ -4626,16 +4626,15 @@ void LayerFactory::registerLayer(const String &type, Constructor constructor)
     CV_TRACE_ARG_VALUE(type, "type", type.c_str());
     cv::AutoLock lock(getLayerFactoryMutex());
-    String type_ = type.toLowerCase();
-    LayerFactory_Impl::iterator it = getLayerFactoryImpl().find(type_);
+    LayerFactory_Impl::iterator it = getLayerFactoryImpl().find(type);
     if (it != getLayerFactoryImpl().end())
     {
         if (it->second.back() == constructor)
-            CV_Error(cv::Error::StsBadArg, "Layer \"" + type_ + "\" already was registered");
+            CV_Error(cv::Error::StsBadArg, "Layer \"" + type + "\" already was registered");
         it->second.push_back(constructor);
     }
-    getLayerFactoryImpl().insert(std::make_pair(type_, std::vector<Constructor>(1, constructor)));
+    getLayerFactoryImpl().insert(std::make_pair(type, std::vector<Constructor>(1, constructor)));
 }
 void LayerFactory::unregisterLayer(const String &type)
@@ -4644,9 +4643,8 @@ void LayerFactory::unregisterLayer(const String &type)
     CV_TRACE_ARG_VALUE(type, "type", type.c_str());
     cv::AutoLock lock(getLayerFactoryMutex());
-    String type_ = type.toLowerCase();
-    LayerFactory_Impl::iterator it = getLayerFactoryImpl().find(type_);
+    LayerFactory_Impl::iterator it = getLayerFactoryImpl().find(type);
     if (it != getLayerFactoryImpl().end())
     {
         if (it->second.size() > 1)
@@ -4662,8 +4660,7 @@ Ptr<Layer> LayerFactory::createLayerInstance(const String &type, LayerParams& pa
     CV_TRACE_ARG_VALUE(type, "type", type.c_str());
     cv::AutoLock lock(getLayerFactoryMutex());
-    String type_ = type.toLowerCase();
-    LayerFactory_Impl::const_iterator it = getLayerFactoryImpl().find(type_);
+    LayerFactory_Impl::const_iterator it = getLayerFactoryImpl().find(type);
     if (it != getLayerFactoryImpl().end())
     {
...
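With the toLowerCase() folding removed, the factory performs an exact string match on the registered type name. A minimal standalone sketch (not part of the patch, and assuming an unknown type yields an empty Ptr rather than an error) of what this means for callers of the public factory API:

#include <iostream>
#include <opencv2/dnn.hpp>

using namespace cv;
using namespace cv::dnn;

int main()
{
    LayerParams lp;

    // The exact registered spelling resolves as before.
    Ptr<Layer> ok = LayerFactory::createLayerInstance("Softmax", lp);
    std::cout << "Softmax resolved: " << !ok.empty() << std::endl;

    // A differently cased spelling is no longer folded to lower case, so it
    // only resolves if that exact string was registered (see the "SoftMax"
    // alias added in init.cpp below); here the returned Ptr stays empty.
    Ptr<Layer> miss = LayerFactory::createLayerInstance("softmax", lp);
    std::cout << "softmax resolved: " << !miss.empty() << std::endl;
    return 0;
}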
@@ -95,6 +95,7 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(LRN, LRNLayer);
     CV_DNN_REGISTER_LAYER_CLASS(InnerProduct, InnerProductLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Softmax, SoftmaxLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(SoftMax, SoftmaxLayer);  // For compatibility. See https://github.com/opencv/opencv/issues/16877
     CV_DNN_REGISTER_LAYER_CLASS(MVN, MVNLayer);
     CV_DNN_REGISTER_LAYER_CLASS(ReLU, ReLULayer);
...
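The extra "SoftMax" registration keeps the previously accepted spelling working now that lookups are case-sensitive. If another spelling is needed on the user side, the public LayerFactory API can register an alias that forwards to the stock implementation; a hedged sketch (the alias name and forwarding function are hypothetical):

#include <opencv2/dnn.hpp>
#include <opencv2/dnn/all_layers.hpp>

using namespace cv;
using namespace cv::dnn;

// Hypothetical forwarder: resolve a custom spelling to the built-in Softmax layer.
static Ptr<Layer> createSoftmaxAlias(LayerParams& params)
{
    return SoftmaxLayer::create(params);
}

int main()
{
    // Mirrors what the patch does for "SoftMax", but with a user-chosen name.
    LayerFactory::registerLayer("MySoftMax", createSoftmaxAlias);

    LayerParams lp;
    Ptr<Layer> layer = LayerFactory::createLayerInstance("MySoftMax", lp);
    return layer.empty() ? 1 : 0;
}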
@@ -615,6 +615,15 @@ void ONNXImporter::populateNet(Net dstNet)
             layerParams.type = "ReLU";
             replaceLayerParam(layerParams, "alpha", "negative_slope");
         }
+        else if (layer_type == "Relu")
+        {
+            layerParams.type = "ReLU";
+        }
+        else if (layer_type == "PRelu")
+        {
+            layerParams.type = "PReLU";
+            layerParams.blobs.push_back(getBlob(node_proto, constBlobs, 1));
+        }
         else if (layer_type == "LRN")
         {
             replaceLayerParam(layerParams, "size", "local_size");
@@ -1133,10 +1142,10 @@ void ONNXImporter::populateNet(Net dstNet)
                 layerParams.set("zoom_factor_x", scales.at<float>(3));
             }
         }
-        else if (layer_type == "LogSoftmax")
+        else if (layer_type == "SoftMax" || layer_type == "LogSoftmax")
         {
             layerParams.type = "Softmax";
-            layerParams.set("log_softmax", true);
+            layerParams.set("log_softmax", layer_type == "LogSoftmax");
         }
         else
         {
...
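On the importer side, ONNX operator names are now matched with their exact spelling ("Relu", "PRelu", "SoftMax", "LogSoftmax") and translated to the canonical OpenCV layer types. A sketch of how one might inspect the resulting mapping (the model file name is an assumption):

#include <iostream>
#include <opencv2/dnn.hpp>

using namespace cv;
using namespace cv::dnn;

int main()
{
    // Hypothetical model containing Relu / PRelu / LogSoftmax nodes.
    Net net = readNetFromONNX("model.onnx");

    // Relu nodes should show up as "ReLU", PRelu as "PReLU",
    // SoftMax and LogSoftmax as "Softmax".
    for (const String& name : net.getLayerNames())
        std::cout << name << " -> " << net.getLayer(name)->type << std::endl;
    return 0;
}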
@@ -865,15 +865,10 @@ struct TorchImporter
             layerParams.set("indices_blob_id", tensorParams["indices"].first);
             curModule->modules.push_back(newModule);
         }
-        else if (nnName == "SoftMax")
-        {
-            newModule->apiType = "SoftMax";
-            curModule->modules.push_back(newModule);
-        }
-        else if (nnName == "LogSoftMax")
+        else if (nnName == "LogSoftMax" || nnName == "SoftMax")
         {
-            newModule->apiType = "SoftMax";
-            layerParams.set("log_softmax", true);
+            newModule->apiType = "Softmax";
+            layerParams.set("log_softmax", nnName == "LogSoftMax");
             curModule->modules.push_back(newModule);
         }
         else if (nnName == "SpatialCrossMapLRN")
...
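Both Torch modules now map to the single "Softmax" layer type and differ only in the "log_softmax" flag. A minimal sketch of the equivalent manual construction (the layer name is arbitrary):

#include <opencv2/dnn.hpp>

using namespace cv;
using namespace cv::dnn;

int main()
{
    LayerParams lp;
    lp.set("log_softmax", true);   // true mirrors LogSoftMax; omit or set false for plain SoftMax

    Net net;
    net.addLayerToPrev("logprobs", "Softmax", lp);
    return 0;
}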
@@ -431,7 +431,7 @@ TEST_P(SoftMax, Accuracy)
     Backend backendId = get<0>(get<1>(GetParam()));
     Target targetId = get<1>(get<1>(GetParam()));
     LayerParams lp;
-    lp.type = "SoftMax";
+    lp.type = "Softmax";
     lp.name = "testLayer";
     int sz[] = {1, inChannels, 1, 1};
...
@@ -70,7 +70,7 @@ public:
     {
         LayerParams lp;
         Net netSoftmax;
-        netSoftmax.addLayerToPrev("softmaxLayer", "SoftMax", lp);
+        netSoftmax.addLayerToPrev("softmaxLayer", "Softmax", lp);
         netSoftmax.setPreferableBackend(DNN_BACKEND_OPENCV);
         netSoftmax.setInput(out);
...