提交 097fe706 编写于 作者: S SunAhong1993

fix the bug

上级 d2674430
...@@ -38,28 +38,34 @@ class PyTorchOpMapper(OpMapper): ...@@ -38,28 +38,34 @@ class PyTorchOpMapper(OpMapper):
self.scope_name2id = dict() self.scope_name2id = dict()
self.inputs_info = dict() self.inputs_info = dict()
# 转换 # 转换
self.check_op(decoder.graph) if not self.op_checker(decoder.graph):
raise Exception("Model is not supported yet.")
self.paddle_graph, _ = self.traverse(decoder.graph) self.paddle_graph, _ = self.traverse(decoder.graph)
self.paddle_graph.set_inputs_info(self.inputs_info) self.paddle_graph.set_inputs_info(self.inputs_info)
def check_op(self, script_graph): def op_checker(self, script_graph):
def _update_op_list(graph): def _update_op_list(graph):
for node in graph.nodes(): for node in graph.nodes():
op_list.append(node.kind()) op_list.append(node.kind())
for block in node.blocks(): for block in node.blocks():
_update_op_list(block) _update_op_list(block)
op_list = list() op_list = list()
_update_op_list(script_graph) _update_op_list(script_graph)
op_list = list(set(op_list)) op_list = list(set(op_list))
unsupported_op_list = [] unsupported_ops = []
for op in op_list: for op in op_list:
func_name = op.replace('::', '_') func_name = op.replace('::', '_')
if not (hasattr(prim, func_name) or hasattr(aten, func_name)): if not (hasattr(prim, func_name) or hasattr(aten, func_name)):
unsupported_op_list.append(op) unsupported_ops.append(op)
if len(unsupported_op_list) > 0: if len(unsupported_ops) == 0:
raise Exception("The kind {} in model is not supported yet.".format( return True
unsupported_op_list)) else:
if len(unsupported_ops) > 0:
print("\n========= {} OPs are not supported yet ===========".format(
len(unsupported_ops)))
for op in unsupported_ops:
print("========== {} ============".format(op))
return False
def traverse(self, script_graph, parent_layer=None): def traverse(self, script_graph, parent_layer=None):
# 用于获取graph的输入 # 用于获取graph的输入
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册