From 097fe706d58a4ceae4ead43339eccd1062b8bdbe Mon Sep 17 00:00:00 2001 From: SunAhong1993 Date: Wed, 25 Nov 2020 19:06:52 +0800 Subject: [PATCH] fix the bug --- .../pytorch2paddle/pytorch_op_mapper.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py index cb21095..1187907 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py @@ -38,28 +38,34 @@ class PyTorchOpMapper(OpMapper): self.scope_name2id = dict() self.inputs_info = dict() # 转换 - self.check_op(decoder.graph) + if not self.op_checker(decoder.graph): + raise Exception("Model is not supported yet.") self.paddle_graph, _ = self.traverse(decoder.graph) self.paddle_graph.set_inputs_info(self.inputs_info) - def check_op(self, script_graph): + def op_checker(self, script_graph): def _update_op_list(graph): for node in graph.nodes(): op_list.append(node.kind()) for block in node.blocks(): _update_op_list(block) - op_list = list() _update_op_list(script_graph) op_list = list(set(op_list)) - unsupported_op_list = [] + unsupported_ops = [] for op in op_list: func_name = op.replace('::', '_') if not (hasattr(prim, func_name) or hasattr(aten, func_name)): - unsupported_op_list.append(op) - if len(unsupported_op_list) > 0: - raise Exception("The kind {} in model is not supported yet.".format( - unsupported_op_list)) + unsupported_ops.append(op) + if len(unsupported_ops) == 0: + return True + else: + if len(unsupported_ops) > 0: + print("\n========= {} OPs are not supported yet ===========".format( + len(unsupported_ops))) + for op in unsupported_ops: + print("========== {} ============".format(op)) + return False def traverse(self, script_graph, parent_layer=None): # 用于获取graph的输入 -- GitLab