提交 2506eee8 编写于 作者: J jiangjiajun

fix identity problem

上级 c5b22aca
......@@ -31,12 +31,22 @@ class PaddleEmitter(object):
self.weights = parser.weights
self.infer = parser.infer
self.inputs_sample_data = dict()
self.outputs = parser.outputs
self.inputs = parser.inputs
self.save_dir = save_dir
self.body_code = ""
self.tab = " " * 4
self.outputs = parser.outputs
self.inputs = parser.inputs
outputs = list()
for output in self.outputs:
while True:
if output in self.graph.identity_relation:
output = self.graph.identity_relation[output]
else:
break
outputs.append(output)
self.outputs = outputs
@staticmethod
def compute_padding_size(in_size, filter_size, stride):
new_size = int(math.ceil(in_size * 1.0 / stride))
......@@ -306,10 +316,10 @@ class PaddleEmitter(object):
node.code.add_str("#{} {} {}".format(node.layer_name, node.ref_name,
value.shape))
if len(shape) == 0 or (len(shape) == 1 and shape[0] < 2):
if value.size == 1:
param_attr = {
"shape": [1],
"value": value,
"value": value.flatten()[0],
"dtype": "\'{}\'".format(dtype),
}
node.code.add_layer("fill_constant", None, node.output_name,
......@@ -343,9 +353,6 @@ class PaddleEmitter(object):
if node.data_format == NHWC:
input_h, input_w = input_shape[1:3]
strides = node.get_attr("strides")[1:3]
if k_h < strides[0] or k_w < strides[1]:
raise Exception("Unexpected situation with kernel's height/width " \
"less than the corresponding stride")
if kernel.layer_name in self.weights:
weight = self.weights[kernel.layer_name]
......@@ -1035,3 +1042,11 @@ class PaddleEmitter(object):
num_sections) + num_sections[index[0]]
param = {"num_or_sections": list(num_sections), "dim": split_dim}
node.code.add_layer("split", data.ref_name, node.output_name, param)
def emit_expanddims(self, node):
data = node.inputs[0]
dim = node.inputs[1]
dim.code.clear()
dim = self.infer.get_const_tensor_value(dim.layer)
param = {"axes":[dim]}
node.code.add_layer("unsqueeze", data.ref_name, node.output_name, param)
......@@ -76,6 +76,7 @@ class TensorflowGraph(Graph):
def __init__(self, tf_graph):
super(TensorflowGraph, self).__init__(tf_graph)
self.tf_graph = tf_graph
self.identity_relation = dict()
def build(self, input_format):
skip_node = set(['const'])
......@@ -131,6 +132,7 @@ class TensorflowGraph(Graph):
current_node = self.node_map[name]
if current_node.layer_type in self.useless_type:
input = current_node.inputs[0]
self.identity_relation[current_node.layer.name] = input.layer.name
for node in current_node.outputs:
for k in range(0, len(node.inputs)):
if node.inputs[k] == current_node:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册