提交 715f40a4 编写于 作者: D Dmitry Kurtaev

Use layers consumers to predict data layout

上级 70d6b877
...@@ -18,6 +18,7 @@ Implementation of Tensorflow models parser ...@@ -18,6 +18,7 @@ Implementation of Tensorflow models parser
#include <fstream> #include <fstream>
#include <algorithm> #include <algorithm>
#include <string> #include <string>
#include <queue>
#include "tf_graph_simplifier.hpp" #include "tf_graph_simplifier.hpp"
#endif #endif
...@@ -558,9 +559,7 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map<String, int>& cons ...@@ -558,9 +559,7 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map<String, int>& cons
} }
} }
// If all inputs of specific layer have the same data layout we can say that static int getDataLayout(const tensorflow::NodeDef& layer)
// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
static int predictOutputDataLayout(const tensorflow::NodeDef& layer, const std::map<String, int>& data_layouts)
{ {
if (hasLayerAttr(layer, "data_format")) if (hasLayerAttr(layer, "data_format"))
{ {
...@@ -572,27 +571,48 @@ static int predictOutputDataLayout(const tensorflow::NodeDef& layer, const std:: ...@@ -572,27 +571,48 @@ static int predictOutputDataLayout(const tensorflow::NodeDef& layer, const std::
else else
CV_Error(Error::StsParseError, "Unknown data_format value: " + format); CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
} }
return DATA_LAYOUT_UNKNOWN;
}
static inline std::string getNodeName(const std::string& tensorName)
{
return tensorName.substr(0, tensorName.rfind(':'));
}
// If all inputs of specific layer have the same data layout we can say that
// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
static int predictOutputDataLayout(const tensorflow::GraphDef& net,
const tensorflow::NodeDef& layer,
const std::map<String, int>& data_layouts)
{
int layout = getDataLayout(layer);
if (layout != DATA_LAYOUT_UNKNOWN)
return layout;
// Determine layout by layer's inputs // Determine layout by layer's inputs
int layout = DATA_LAYOUT_UNKNOWN;
std::map<String, int>::const_iterator it; std::map<String, int>::const_iterator it;
for (int i = 0, n = layer.input_size(); i < n; ++i) for (int i = 0, n = layer.input_size(); i < n; ++i)
{ {
it = data_layouts.find(layer.input(i).substr(0, layer.input(i).rfind(':'))); it = data_layouts.find(getNodeName(layer.input(i)));
if (it != data_layouts.end()) if (it != data_layouts.end())
{ {
if (it->second == DATA_LAYOUT_UNKNOWN) if (layout != DATA_LAYOUT_UNKNOWN)
return DATA_LAYOUT_UNKNOWN;
else if (it->second != layout)
{ {
if (layout == DATA_LAYOUT_UNKNOWN) if (it->second != layout && it->second != DATA_LAYOUT_UNKNOWN)
layout = it->second;
else
return DATA_LAYOUT_UNKNOWN; return DATA_LAYOUT_UNKNOWN;
} }
else
layout = it->second;
} }
} }
return layout;
if (layout != DATA_LAYOUT_UNKNOWN)
return layout;
// Determine layout by layer's consumers recursively.
it = data_layouts.find(layer.name());
CV_Assert(it != data_layouts.end());
return it->second;
} }
void TFImporter::populateNet(Net dstNet) void TFImporter::populateNet(Net dstNet)
...@@ -610,6 +630,52 @@ void TFImporter::populateNet(Net dstNet) ...@@ -610,6 +630,52 @@ void TFImporter::populateNet(Net dstNet)
int layersSize = net.node_size(); int layersSize = net.node_size();
std::map<String, int> data_layouts; std::map<String, int> data_layouts;
// Pre-fill data layouts where they are set explicitly.
// Assuming that nodes are in topological order
for (int i = net.node_size() - 1; i >= 0; --i)
{
const tensorflow::NodeDef& layer = net.node(i);
std::string name = layer.name();
int layout = getDataLayout(layer);
std::map<String, int>::iterator it = data_layouts.find(name);
if (it != data_layouts.end())
{
if (layout != DATA_LAYOUT_UNKNOWN)
{
if (it->second == DATA_LAYOUT_UNKNOWN)
it->second = layout;
else if (it->second != layout)
{
it->second = DATA_LAYOUT_UNKNOWN;
layout = DATA_LAYOUT_UNKNOWN;
}
}
else
layout = it->second;
}
else
data_layouts[name] = layout;
// Specify input layers to have the same data layout.
for (int j = 0; j < layer.input_size(); ++j)
{
name = getNodeName(layer.input(j));
it = data_layouts.find(name);
if (it != data_layouts.end())
{
if (layout != DATA_LAYOUT_UNKNOWN)
{
if (it->second == DATA_LAYOUT_UNKNOWN)
it->second = layout;
else if (it->second != layout)
it->second = DATA_LAYOUT_UNKNOWN;
}
}
else
data_layouts[name] = layout;
}
}
// find all Const layers for params // find all Const layers for params
std::map<String, int> value_id; std::map<String, int> value_id;
...@@ -628,7 +694,8 @@ void TFImporter::populateNet(Net dstNet) ...@@ -628,7 +694,8 @@ void TFImporter::populateNet(Net dstNet)
if(layers_to_ignore.find(name) != layers_to_ignore.end()) if(layers_to_ignore.find(name) != layers_to_ignore.end())
continue; continue;
data_layouts[name] = predictOutputDataLayout(layer, data_layouts); int predictedLayout = predictOutputDataLayout(net, layer, data_layouts);
data_layouts[name] = predictedLayout;
if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative") if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative")
{ {
...@@ -885,6 +952,7 @@ void TFImporter::populateNet(Net dstNet) ...@@ -885,6 +952,7 @@ void TFImporter::populateNet(Net dstNet)
// one input only // one input only
connect(layer_id, dstNet, inpId, id, 0); connect(layer_id, dstNet, inpId, id, 0);
data_layouts[name] = DATA_LAYOUT_UNKNOWN;
} }
else if (type == "Flatten" || type == "Squeeze") else if (type == "Flatten" || type == "Squeeze")
{ {
...@@ -1013,7 +1081,10 @@ void TFImporter::populateNet(Net dstNet) ...@@ -1013,7 +1081,10 @@ void TFImporter::populateNet(Net dstNet)
{ {
int axisId = (type == "Concat" ? 0 : layer.input_size() - 1); int axisId = (type == "Concat" ? 0 : layer.input_size() - 1);
int axis = getConstBlob(layer, value_id, axisId).int_val().Get(0); int axis = getConstBlob(layer, value_id, axisId).int_val().Get(0);
layerParams.set("axis", 0 <= axis && axis < 4 ? toNCHW(axis) : axis);
if (data_layouts[name] == DATA_LAYOUT_NHWC)
axis = toNCHW(axis);
layerParams.set("axis", axis);
int id = dstNet.addLayer(name, "Concat", layerParams); int id = dstNet.addLayer(name, "Concat", layerParams);
layer_id[name] = id; layer_id[name] = id;
......
...@@ -142,9 +142,10 @@ TEST_P(Test_TensorFlow_layers, eltwise_add_mul) ...@@ -142,9 +142,10 @@ TEST_P(Test_TensorFlow_layers, eltwise_add_mul)
runTensorFlowNet("eltwise_add_mul", GetParam()); runTensorFlowNet("eltwise_add_mul", GetParam());
} }
TEST_P(Test_TensorFlow_layers, pad_and_concat) TEST_P(Test_TensorFlow_layers, concat)
{ {
runTensorFlowNet("pad_and_concat", GetParam()); runTensorFlowNet("pad_and_concat", GetParam());
runTensorFlowNet("concat_axis_1", GetParam());
} }
TEST_P(Test_TensorFlow_layers, batch_norm) TEST_P(Test_TensorFlow_layers, batch_norm)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册