提交 bbd0ed0e 编写于 作者: S SunAhong1993

fix the bug

上级 c1142687
...@@ -224,20 +224,21 @@ class TFOpMapper(OpMapper): ...@@ -224,20 +224,21 @@ class TFOpMapper(OpMapper):
return return
self.params[node.name] = node.value self.params[node.name] = node.value
if dtype != "float32": if 0 not in shape:
self.params[node.name] = node.value.astype("float32") if dtype != "float32":
self.paddle_graph.add_layer( self.params[node.name] = node.value.astype("float32")
"self.create_parameter",
inputs={},
outputs=[node.name],
shape=shape,
attr=string(node.name))
if dtype != "float32":
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
kernel="paddle.cast", "self.create_parameter",
inputs={"x": node.name}, inputs={},
outputs=[node.name], outputs=[node.name],
dtype=string(dtype)) shape=shape,
attr=string(node.name))
if dtype != "float32":
self.paddle_graph.add_layer(
kernel="paddle.cast",
inputs={"x": node.name},
outputs=[node.name],
dtype=string(dtype))
def Transpose(self, node): def Transpose(self, node):
input = self.graph.get_node(node.layer.input[0]) input = self.graph.get_node(node.layer.input[0])
......
...@@ -193,6 +193,8 @@ class PatternMatcher(object): ...@@ -193,6 +193,8 @@ class PatternMatcher(object):
continue continue
update(new_layer_id_in, pattern_layer_id_in) update(new_layer_id_in, pattern_layer_id_in)
if pattern.edges_out.get(pattern_layer_id, 0) != 0: if pattern.edges_out.get(pattern_layer_id, 0) != 0:
if layer_id not in graph.edges_out:
return False
if len(pattern.edges_out[pattern_layer_id]) != \ if len(pattern.edges_out[pattern_layer_id]) != \
len(graph.edges_out[layer_id]): len(graph.edges_out[layer_id]):
return False return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册