提交 9cfd219d 编写于 作者: D Dmitry Kurtaev

Fix Mobilenet v2 from TensorFlow slim

上级 9340fc0c
......@@ -630,6 +630,21 @@ public:
}
};
class SoftMaxSlimSubgraph : public Subgraph
{
public:
SoftMaxSlimSubgraph()
{
int input = addNodeToMatch("");
int shape = addNodeToMatch("Const");
int shapeOp = addNodeToMatch("Shape", input);
int reshape = addNodeToMatch("Reshape", input, shape);
int softmax = addNodeToMatch("Softmax", reshape);
addNodeToMatch("Reshape", softmax, shapeOp);
setFusedNode("Softmax", input);
}
};
void simplifySubgraphs(tensorflow::GraphDef& net)
{
std::vector<Ptr<Subgraph> > subgraphs;
......@@ -646,6 +661,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimSubgraph()));
int numNodes = net.node_size();
std::vector<int> matchedNodesIds;
......
......@@ -661,7 +661,10 @@ void TFImporter::populateNet(Net dstNet)
RemoveIdentityOps(netTxt);
if (!netTxt.ByteSize())
{
simplifySubgraphs(netBin);
sortByExecutionOrder(netBin);
}
std::set<String> layers_to_ignore;
......
......@@ -549,6 +549,7 @@ TEST_P(Test_TensorFlow_layers, slice)
TEST_P(Test_TensorFlow_layers, softmax)
{
runTensorFlowNet("keras_softmax");
runTensorFlowNet("slim_softmax");
}
TEST_P(Test_TensorFlow_layers, relu6)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册