From bbd0ed0e6f50fb5dccf188d70152409a1576e4b8 Mon Sep 17 00:00:00 2001 From: SunAhong1993 Date: Wed, 11 Nov 2020 14:27:33 +0800 Subject: [PATCH] fix the bug --- .../dygraph/tf2paddle/tf_op_mapper.py | 27 ++++++++++--------- x2paddle/optimizer/pattern_matcher.py | 2 ++ 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py b/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py index 1518e6f..773b4e2 100644 --- a/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py +++ b/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py @@ -224,20 +224,21 @@ class TFOpMapper(OpMapper): return self.params[node.name] = node.value - if dtype != "float32": - self.params[node.name] = node.value.astype("float32") - self.paddle_graph.add_layer( - "self.create_parameter", - inputs={}, - outputs=[node.name], - shape=shape, - attr=string(node.name)) - if dtype != "float32": + if 0 not in shape: + if dtype != "float32": + self.params[node.name] = node.value.astype("float32") self.paddle_graph.add_layer( - kernel="paddle.cast", - inputs={"x": node.name}, - outputs=[node.name], - dtype=string(dtype)) + "self.create_parameter", + inputs={}, + outputs=[node.name], + shape=shape, + attr=string(node.name)) + if dtype != "float32": + self.paddle_graph.add_layer( + kernel="paddle.cast", + inputs={"x": node.name}, + outputs=[node.name], + dtype=string(dtype)) def Transpose(self, node): input = self.graph.get_node(node.layer.input[0]) diff --git a/x2paddle/optimizer/pattern_matcher.py b/x2paddle/optimizer/pattern_matcher.py index cc86025..4487f51 100644 --- a/x2paddle/optimizer/pattern_matcher.py +++ b/x2paddle/optimizer/pattern_matcher.py @@ -193,6 +193,8 @@ class PatternMatcher(object): continue update(new_layer_id_in, pattern_layer_id_in) if pattern.edges_out.get(pattern_layer_id, 0) != 0: + if layer_id not in graph.edges_out: + return False if len(pattern.edges_out[pattern_layer_id]) != \ len(graph.edges_out[layer_id]): return False -- GitLab