提交 3478240d 编写于 作者: A Alexander Alekhin

Merge pull request #13656 from dkurt:dnn_tf_atrous_faster_rcnn

......@@ -48,10 +48,42 @@ def createFasterRCNNGraph(modelPath, configPath, outputPath):
removeIdentity(graph_def)
nodesToKeep = []
def to_remove(name, op):
if name in nodesToKeep:
return False
return op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \
(name.startswith('CropAndResize') and op != 'CropAndResize')
# Fuse atrous convolutions (with dilations).
nodesMap = {node.name: node for node in graph_def.node}
for node in reversed(graph_def.node):
if node.op == 'BatchToSpaceND':
del node.input[2]
conv = nodesMap[node.input[0]]
spaceToBatchND = nodesMap[conv.input[0]]
# Extract paddings
stridedSlice = nodesMap[spaceToBatchND.input[2]]
assert(stridedSlice.op == 'StridedSlice')
pack = nodesMap[stridedSlice.input[0]]
assert(pack.op == 'Pack')
padNodeH = nodesMap[nodesMap[pack.input[0]].input[0]]
padNodeW = nodesMap[nodesMap[pack.input[1]].input[0]]
padH = int(padNodeH.attr['value']['tensor'][0]['int_val'][0])
padW = int(padNodeW.attr['value']['tensor'][0]['int_val'][0])
paddingsNode = NodeDef()
paddingsNode.name = conv.name + '/paddings'
paddingsNode.op = 'Const'
paddingsNode.addAttr('value', [padH, padH, padW, padW])
graph_def.node.insert(graph_def.node.index(spaceToBatchND), paddingsNode)
nodesToKeep.append(paddingsNode.name)
spaceToBatchND.input[2] = paddingsNode.name
removeUnusedNodesAndAttrs(to_remove, graph_def)
......@@ -225,6 +257,26 @@ def createFasterRCNNGraph(modelPath, configPath, outputPath):
detectionOut.addAttr('variance_encoded_in_target', True)
graph_def.node.extend([detectionOut])
def getUnconnectedNodes():
unconnected = [node.name for node in graph_def.node]
for node in graph_def.node:
for inp in node.input:
if inp in unconnected:
unconnected.remove(inp)
return unconnected
while True:
unconnectedNodes = getUnconnectedNodes()
unconnectedNodes.remove(detectionOut.name)
if not unconnectedNodes:
break
for name in unconnectedNodes:
for i in range(len(graph_def.node)):
if graph_def.node[i].name == name:
del graph_def.node[i]
break
# Save as text.
graph_def.save(outputPath)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册