From a829051a746ea766a20e54c64b76659ffb782701 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Fri, 19 Jul 2019 14:18:26 +0800 Subject: [PATCH] structure modify --- x2paddle/core/graph.py | 1 - x2paddle/core/op_mapper.py | 15 +++++++++++++++ x2paddle/op_mapper/caffe_op_mapper.py | 5 +++++ x2paddle/op_mapper/tf_op_mapper.py | 5 +++++ 4 files changed, 25 insertions(+), 1 deletion(-) diff --git a/x2paddle/core/graph.py b/x2paddle/core/graph.py index 86aefd2..10eaede 100644 --- a/x2paddle/core/graph.py +++ b/x2paddle/core/graph.py @@ -61,7 +61,6 @@ class Graph(object): num_inputs = dict() for name, node in self.node_map.items(): num_inputs[name] = len(node.inputs) - print(len(self.node_map)) self.topo_sort = self.input_nodes[:] idx = 0 diff --git a/x2paddle/core/op_mapper.py b/x2paddle/core/op_mapper.py index 5b8049a..a4dc24b 100644 --- a/x2paddle/core/op_mapper.py +++ b/x2paddle/core/op_mapper.py @@ -23,6 +23,21 @@ class OpMapper(object): self.net_code = list() self.weights = dict() + def op_checker(self): + unsupported_ops = set() + for node_name in self.graph.topo_sort: + node = self.graph.get_node(node_name) + op = node.layer_type + if not hasattr(self, op): + unsupported_ops.add(op) + if len(unsupported_ops) == 0: + return True + else: + print("There are {} ops not supported yet, list as below") + for op in unsupported_ops: + print(op) + return False + def add_codes(self, codes, indent=0): if isinstance(codes, list): for code in codes: diff --git a/x2paddle/op_mapper/caffe_op_mapper.py b/x2paddle/op_mapper/caffe_op_mapper.py index 9535014..f6f96d2 100644 --- a/x2paddle/op_mapper/caffe_op_mapper.py +++ b/x2paddle/op_mapper/caffe_op_mapper.py @@ -31,6 +31,11 @@ class CaffeOpMapper(OpMapper): def run(self): print("Total nodes: {}".format(len(self.graph.topo_sort))) + + # check if ops in model are all supported + if not self.op_checker(): + raise Exception("Model are not supported yet.") + for node_name in self.graph.topo_sort: node = self.graph.get_node(node_name) op = node.layer_type diff --git a/x2paddle/op_mapper/tf_op_mapper.py b/x2paddle/op_mapper/tf_op_mapper.py index 5d79517..54bec04 100644 --- a/x2paddle/op_mapper/tf_op_mapper.py +++ b/x2paddle/op_mapper/tf_op_mapper.py @@ -27,6 +27,11 @@ class TFOpMapper(OpMapper): def run(self): print("Total nodes: {}".format(len(self.graph.topo_sort))) + + # check if ops in model are all supported + if not self.op_checker(): + raise Exception("Model are not supported yet.") + for node_name in self.graph.topo_sort: node = self.graph.get_node(node_name) op = node.layer_type -- GitLab