diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py index cb210956b7804fd76fd9ec1bc54a8c43c0bf6dac..118790772565bc565115adeffd65af22cafb6adf 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py @@ -38,28 +38,34 @@ class PyTorchOpMapper(OpMapper): self.scope_name2id = dict() self.inputs_info = dict() # 转换 - self.check_op(decoder.graph) + if not self.op_checker(decoder.graph): + raise Exception("Model is not supported yet.") self.paddle_graph, _ = self.traverse(decoder.graph) self.paddle_graph.set_inputs_info(self.inputs_info) - def check_op(self, script_graph): + def op_checker(self, script_graph): def _update_op_list(graph): for node in graph.nodes(): op_list.append(node.kind()) for block in node.blocks(): _update_op_list(block) - op_list = list() _update_op_list(script_graph) op_list = list(set(op_list)) - unsupported_op_list = [] + unsupported_ops = [] for op in op_list: func_name = op.replace('::', '_') if not (hasattr(prim, func_name) or hasattr(aten, func_name)): - unsupported_op_list.append(op) - if len(unsupported_op_list) > 0: - raise Exception("The kind {} in model is not supported yet.".format( - unsupported_op_list)) + unsupported_ops.append(op) + if len(unsupported_ops) == 0: + return True + else: + if len(unsupported_ops) > 0: + print("\n========= {} OPs are not supported yet ===========".format( + len(unsupported_ops))) + for op in unsupported_ops: + print("========== {} ============".format(op)) + return False def traverse(self, script_graph, parent_layer=None): # 用于获取graph的输入