提交 7a2b3ed4 编写于 作者: A Anastasia Murzova

Corrected DNN elementwise multiplication

上级 551d4a8e
...@@ -12,6 +12,7 @@ Implementation of Tensorflow models parser ...@@ -12,6 +12,7 @@ Implementation of Tensorflow models parser
#include "../precomp.hpp" #include "../precomp.hpp"
#include <opencv2/core/utils/logger.defines.hpp> #include <opencv2/core/utils/logger.defines.hpp>
#include <opencv2/dnn/shape_utils.hpp>
#undef CV_LOG_STRIP_LEVEL #undef CV_LOG_STRIP_LEVEL
#define CV_LOG_STRIP_LEVEL CV_LOG_LEVEL_DEBUG + 1 #define CV_LOG_STRIP_LEVEL CV_LOG_LEVEL_DEBUG + 1
#include <opencv2/core/utils/logger.hpp> #include <opencv2/core/utils/logger.hpp>
...@@ -1825,6 +1826,7 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer_) ...@@ -1825,6 +1826,7 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer_)
{ {
// Check if all the inputs have the same shape. // Check if all the inputs have the same shape.
bool equalInpShapes = true; bool equalInpShapes = true;
bool isShapeOnes = false;
MatShape outShape0; MatShape outShape0;
for (int ii = 0; ii < num_inputs && !netInputShapes.empty(); ii++) for (int ii = 0; ii < num_inputs && !netInputShapes.empty(); ii++)
{ {
...@@ -1845,12 +1847,14 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer_) ...@@ -1845,12 +1847,14 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer_)
else if (outShape != outShape0) else if (outShape != outShape0)
{ {
equalInpShapes = false; equalInpShapes = false;
isShapeOnes = isAllOnes(outShape, 2, outShape.size()) ||
isAllOnes(outShape0, 2, outShape0.size());
break; break;
} }
} }
int id; int id;
if (equalInpShapes || netInputShapes.empty()) if (equalInpShapes || netInputShapes.empty() || (!equalInpShapes && isShapeOnes))
{ {
layerParams.set("operation", type == "RealDiv" ? "div" : "prod"); layerParams.set("operation", type == "RealDiv" ? "div" : "prod");
id = dstNet.addLayer(name, "Eltwise", layerParams); id = dstNet.addLayer(name, "Eltwise", layerParams);
......
...@@ -210,6 +210,12 @@ TEST_P(Test_TensorFlow_layers, eltwise_add_vec) ...@@ -210,6 +210,12 @@ TEST_P(Test_TensorFlow_layers, eltwise_add_vec)
runTensorFlowNet("eltwise_add_vec"); runTensorFlowNet("eltwise_add_vec");
} }
TEST_P(Test_TensorFlow_layers, eltwise_mul_vec)
{
runTensorFlowNet("eltwise_mul_vec");
}
TEST_P(Test_TensorFlow_layers, channel_broadcast) TEST_P(Test_TensorFlow_layers, channel_broadcast)
{ {
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019) if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册