提交 715f40a4 编写于 作者: D Dmitry Kurtaev

Use layers consumers to predict data layout

上级 70d6b877
......@@ -18,6 +18,7 @@ Implementation of Tensorflow models parser
#include <fstream>
#include <algorithm>
#include <string>
#include <queue>
#include "tf_graph_simplifier.hpp"
#endif
......@@ -558,9 +559,7 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map<String, int>& cons
}
}
// If all inputs of specific layer have the same data layout we can say that
// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
static int predictOutputDataLayout(const tensorflow::NodeDef& layer, const std::map<String, int>& data_layouts)
static int getDataLayout(const tensorflow::NodeDef& layer)
{
if (hasLayerAttr(layer, "data_format"))
{
......@@ -572,27 +571,48 @@ static int predictOutputDataLayout(const tensorflow::NodeDef& layer, const std::
else
CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
}
return DATA_LAYOUT_UNKNOWN;
}
static inline std::string getNodeName(const std::string& tensorName)
{
return tensorName.substr(0, tensorName.rfind(':'));
}
// If all inputs of specific layer have the same data layout we can say that
// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
static int predictOutputDataLayout(const tensorflow::GraphDef& net,
const tensorflow::NodeDef& layer,
const std::map<String, int>& data_layouts)
{
int layout = getDataLayout(layer);
if (layout != DATA_LAYOUT_UNKNOWN)
return layout;
// Determine layout by layer's inputs
int layout = DATA_LAYOUT_UNKNOWN;
std::map<String, int>::const_iterator it;
for (int i = 0, n = layer.input_size(); i < n; ++i)
{
it = data_layouts.find(layer.input(i).substr(0, layer.input(i).rfind(':')));
it = data_layouts.find(getNodeName(layer.input(i)));
if (it != data_layouts.end())
{
if (it->second == DATA_LAYOUT_UNKNOWN)
return DATA_LAYOUT_UNKNOWN;
else if (it->second != layout)
if (layout != DATA_LAYOUT_UNKNOWN)
{
if (layout == DATA_LAYOUT_UNKNOWN)
layout = it->second;
else
if (it->second != layout && it->second != DATA_LAYOUT_UNKNOWN)
return DATA_LAYOUT_UNKNOWN;
}
else
layout = it->second;
}
}
return layout;
if (layout != DATA_LAYOUT_UNKNOWN)
return layout;
// Determine layout by layer's consumers recursively.
it = data_layouts.find(layer.name());
CV_Assert(it != data_layouts.end());
return it->second;
}
void TFImporter::populateNet(Net dstNet)
......@@ -610,6 +630,52 @@ void TFImporter::populateNet(Net dstNet)
int layersSize = net.node_size();
std::map<String, int> data_layouts;
// Pre-fill data layouts where they are set explicitly.
// Assuming that nodes are in topological order
for (int i = net.node_size() - 1; i >= 0; --i)
{
const tensorflow::NodeDef& layer = net.node(i);
std::string name = layer.name();
int layout = getDataLayout(layer);
std::map<String, int>::iterator it = data_layouts.find(name);
if (it != data_layouts.end())
{
if (layout != DATA_LAYOUT_UNKNOWN)
{
if (it->second == DATA_LAYOUT_UNKNOWN)
it->second = layout;
else if (it->second != layout)
{
it->second = DATA_LAYOUT_UNKNOWN;
layout = DATA_LAYOUT_UNKNOWN;
}
}
else
layout = it->second;
}
else
data_layouts[name] = layout;
// Specify input layers to have the same data layout.
for (int j = 0; j < layer.input_size(); ++j)
{
name = getNodeName(layer.input(j));
it = data_layouts.find(name);
if (it != data_layouts.end())
{
if (layout != DATA_LAYOUT_UNKNOWN)
{
if (it->second == DATA_LAYOUT_UNKNOWN)
it->second = layout;
else if (it->second != layout)
it->second = DATA_LAYOUT_UNKNOWN;
}
}
else
data_layouts[name] = layout;
}
}
// find all Const layers for params
std::map<String, int> value_id;
......@@ -628,7 +694,8 @@ void TFImporter::populateNet(Net dstNet)
if(layers_to_ignore.find(name) != layers_to_ignore.end())
continue;
data_layouts[name] = predictOutputDataLayout(layer, data_layouts);
int predictedLayout = predictOutputDataLayout(net, layer, data_layouts);
data_layouts[name] = predictedLayout;
if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative")
{
......@@ -885,6 +952,7 @@ void TFImporter::populateNet(Net dstNet)
// one input only
connect(layer_id, dstNet, inpId, id, 0);
data_layouts[name] = DATA_LAYOUT_UNKNOWN;
}
else if (type == "Flatten" || type == "Squeeze")
{
......@@ -1013,7 +1081,10 @@ void TFImporter::populateNet(Net dstNet)
{
int axisId = (type == "Concat" ? 0 : layer.input_size() - 1);
int axis = getConstBlob(layer, value_id, axisId).int_val().Get(0);
layerParams.set("axis", 0 <= axis && axis < 4 ? toNCHW(axis) : axis);
if (data_layouts[name] == DATA_LAYOUT_NHWC)
axis = toNCHW(axis);
layerParams.set("axis", axis);
int id = dstNet.addLayer(name, "Concat", layerParams);
layer_id[name] = id;
......
......@@ -142,9 +142,10 @@ TEST_P(Test_TensorFlow_layers, eltwise_add_mul)
runTensorFlowNet("eltwise_add_mul", GetParam());
}
TEST_P(Test_TensorFlow_layers, pad_and_concat)
TEST_P(Test_TensorFlow_layers, concat)
{
runTensorFlowNet("pad_and_concat", GetParam());
runTensorFlowNet("concat_axis_1", GetParam());
}
TEST_P(Test_TensorFlow_layers, batch_norm)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册