未验证 提交 3d687c6a 编写于 作者: J Jason 提交者: GitHub

Merge pull request #5 from jiangjiajun/master

fix identity problem
...@@ -31,12 +31,22 @@ class PaddleEmitter(object): ...@@ -31,12 +31,22 @@ class PaddleEmitter(object):
self.weights = parser.weights self.weights = parser.weights
self.infer = parser.infer self.infer = parser.infer
self.inputs_sample_data = dict() self.inputs_sample_data = dict()
self.outputs = parser.outputs
self.inputs = parser.inputs
self.save_dir = save_dir self.save_dir = save_dir
self.body_code = "" self.body_code = ""
self.tab = " " * 4 self.tab = " " * 4
self.outputs = parser.outputs
self.inputs = parser.inputs
outputs = list()
for output in self.outputs:
while True:
if output in self.graph.identity_relation:
output = self.graph.identity_relation[output]
else:
break
outputs.append(output)
self.outputs = outputs
@staticmethod @staticmethod
def compute_padding_size(in_size, filter_size, stride): def compute_padding_size(in_size, filter_size, stride):
new_size = int(math.ceil(in_size * 1.0 / stride)) new_size = int(math.ceil(in_size * 1.0 / stride))
...@@ -306,10 +316,10 @@ class PaddleEmitter(object): ...@@ -306,10 +316,10 @@ class PaddleEmitter(object):
node.code.add_str("#{} {} {}".format(node.layer_name, node.ref_name, node.code.add_str("#{} {} {}".format(node.layer_name, node.ref_name,
value.shape)) value.shape))
if len(shape) == 0 or (len(shape) == 1 and shape[0] < 2): if value.size == 1:
param_attr = { param_attr = {
"shape": [1], "shape": [1],
"value": value, "value": value.flatten()[0],
"dtype": "\'{}\'".format(dtype), "dtype": "\'{}\'".format(dtype),
} }
node.code.add_layer("fill_constant", None, node.output_name, node.code.add_layer("fill_constant", None, node.output_name,
...@@ -343,9 +353,6 @@ class PaddleEmitter(object): ...@@ -343,9 +353,6 @@ class PaddleEmitter(object):
if node.data_format == NHWC: if node.data_format == NHWC:
input_h, input_w = input_shape[1:3] input_h, input_w = input_shape[1:3]
strides = node.get_attr("strides")[1:3] strides = node.get_attr("strides")[1:3]
if k_h < strides[0] or k_w < strides[1]:
raise Exception("Unexpected situation with kernel's height/width " \
"less than the corresponding stride")
if kernel.layer_name in self.weights: if kernel.layer_name in self.weights:
weight = self.weights[kernel.layer_name] weight = self.weights[kernel.layer_name]
...@@ -1035,3 +1042,11 @@ class PaddleEmitter(object): ...@@ -1035,3 +1042,11 @@ class PaddleEmitter(object):
num_sections) + num_sections[index[0]] num_sections) + num_sections[index[0]]
param = {"num_or_sections": list(num_sections), "dim": split_dim} param = {"num_or_sections": list(num_sections), "dim": split_dim}
node.code.add_layer("split", data.ref_name, node.output_name, param) node.code.add_layer("split", data.ref_name, node.output_name, param)
def emit_expanddims(self, node):
data = node.inputs[0]
dim = node.inputs[1]
dim.code.clear()
dim = self.infer.get_const_tensor_value(dim.layer)
param = {"axes":[dim]}
node.code.add_layer("unsqueeze", data.ref_name, node.output_name, param)
...@@ -76,6 +76,7 @@ class TensorflowGraph(Graph): ...@@ -76,6 +76,7 @@ class TensorflowGraph(Graph):
def __init__(self, tf_graph): def __init__(self, tf_graph):
super(TensorflowGraph, self).__init__(tf_graph) super(TensorflowGraph, self).__init__(tf_graph)
self.tf_graph = tf_graph self.tf_graph = tf_graph
self.identity_relation = dict()
def build(self, input_format): def build(self, input_format):
skip_node = set(['const']) skip_node = set(['const'])
...@@ -131,6 +132,7 @@ class TensorflowGraph(Graph): ...@@ -131,6 +132,7 @@ class TensorflowGraph(Graph):
current_node = self.node_map[name] current_node = self.node_map[name]
if current_node.layer_type in self.useless_type: if current_node.layer_type in self.useless_type:
input = current_node.inputs[0] input = current_node.inputs[0]
self.identity_relation[current_node.layer.name] = input.layer.name
for node in current_node.outputs: for node in current_node.outputs:
for k in range(0, len(node.inputs)): for k in range(0, len(node.inputs)):
if node.inputs[k] == current_node: if node.inputs[k] == current_node:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册