提交 bbd0ed0e 编写于 作者: S SunAhong1993

fix the bug

上级 c1142687
......@@ -224,20 +224,21 @@ class TFOpMapper(OpMapper):
return
self.params[node.name] = node.value
if dtype != "float32":
self.params[node.name] = node.value.astype("float32")
self.paddle_graph.add_layer(
"self.create_parameter",
inputs={},
outputs=[node.name],
shape=shape,
attr=string(node.name))
if dtype != "float32":
if 0 not in shape:
if dtype != "float32":
self.params[node.name] = node.value.astype("float32")
self.paddle_graph.add_layer(
kernel="paddle.cast",
inputs={"x": node.name},
outputs=[node.name],
dtype=string(dtype))
"self.create_parameter",
inputs={},
outputs=[node.name],
shape=shape,
attr=string(node.name))
if dtype != "float32":
self.paddle_graph.add_layer(
kernel="paddle.cast",
inputs={"x": node.name},
outputs=[node.name],
dtype=string(dtype))
def Transpose(self, node):
input = self.graph.get_node(node.layer.input[0])
......
......@@ -193,6 +193,8 @@ class PatternMatcher(object):
continue
update(new_layer_id_in, pattern_layer_id_in)
if pattern.edges_out.get(pattern_layer_id, 0) != 0:
if layer_id not in graph.edges_out:
return False
if len(pattern.edges_out[pattern_layer_id]) != \
len(graph.edges_out[layer_id]):
return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册