From 6927786fe0b4d27fe07e402bde3bc73f11196112 Mon Sep 17 00:00:00 2001 From: "jiangjiajun@baidu.com" Date: Wed, 18 Mar 2020 11:45:21 +0800 Subject: [PATCH] add efficient support: --- x2paddle/convert.py | 5 ++++- x2paddle/decoder/tf_decoder.py | 16 +++++++++++++--- x2paddle/op_mapper/tf_op_mapper.py | 8 +++++++- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/x2paddle/convert.py b/x2paddle/convert.py index 901fbd7..c0c8fb9 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -211,7 +211,10 @@ def main(): try: import paddle v0, v1, v2 = paddle.__version__.split('.') - if int(v0) != 1 or int(v1) < 6: + print("paddle.__version__ = {}".format(paddle.__version__)) + if v0 == '0' and v1 == '0' and v2 == '0': + print("[WARNING] You are use develop version of paddlepaddle") + elif int(v0) != 1 or int(v1) < 6: print("[ERROR] paddlepaddle>=1.6.0 is required") return except: diff --git a/x2paddle/decoder/tf_decoder.py b/x2paddle/decoder/tf_decoder.py index cf3b5ac..97c5d48 100644 --- a/x2paddle/decoder/tf_decoder.py +++ b/x2paddle/decoder/tf_decoder.py @@ -48,7 +48,10 @@ class TFGraphNode(GraphNode): @property def out_shapes(self): - values = self.layer.attr["_output_shapes"].list.shape + if self.layer_type == "OneShotIterator": + values = self.layer.attr["output_shapes"].list.shape + else: + values = self.layer.attr["_output_shapes"].list.shape out_shapes = list() for value in values: shape = [dim.size for dim in value.dim] @@ -62,6 +65,8 @@ class TFGraphNode(GraphNode): dtype = self.layer.attr[k].type if dtype > 0: break + if dtype == 0: + dtype = self.layer.attr['output_types'].list.type[0] if dtype not in self.dtype_map: raise Exception("Dtype[{}] not in dtype_map".format(dtype)) return self.dtype_map[dtype] @@ -226,7 +231,7 @@ class TFGraph(Graph): def _remove_identity_node(self): identity_ops = [ 'Identity', 'StopGradient', 'Switch', 'Merge', - 'PlaceholderWithDefault' + 'PlaceholderWithDefault', 'IteratorGetNext' ] identity_node = list() for node_name, node in self.node_map.items(): @@ -317,7 +322,7 @@ class TFDecoder(object): graph_def = cp.deepcopy(graph_def) input_map = dict() for layer in graph_def.node: - if layer.op != "Placeholder": + if layer.op != "Placeholder" and layer.op != "OneShotIterator": continue graph_node = TFGraphNode(layer) dtype = graph_node.layer.attr['dtype'].type @@ -335,6 +340,11 @@ class TFDecoder(object): if shape.count(-1) > 1: need_define_shape = 2 + if need_define_shape == 1: + shape = graph_node.out_shapes[0] + if len(shape) > 0 and shape.count(-1) < 2: + need_define_shape = 0 + if need_define_shape > 0: shape = None if graph_node.get_attr("shape"): diff --git a/x2paddle/op_mapper/tf_op_mapper.py b/x2paddle/op_mapper/tf_op_mapper.py index 0512c6e..45fd2e8 100644 --- a/x2paddle/op_mapper/tf_op_mapper.py +++ b/x2paddle/op_mapper/tf_op_mapper.py @@ -85,7 +85,7 @@ class TFOpMapper(OpMapper): not_placeholder = list() for name in self.graph.input_nodes: - if self.graph.get_node(name).layer_type != "Placeholder": + if self.graph.get_node(name).layer_type != "Placeholder" and self.graph.get_node(name).layer_type != "OneShotIterator": not_placeholder.append(name) for name in not_placeholder: idx = self.graph.input_nodes.index(name) @@ -287,6 +287,9 @@ class TFOpMapper(OpMapper): output=node, param_attr=attr) + def OneShotIterator(self, node): + return self.Placeholder(node) + def Const(self, node): shape = node.out_shapes[0] dtype = node.dtype @@ -492,6 +495,9 @@ class TFOpMapper(OpMapper): output=node, param_attr=attr) + def FusedBatchNormV3(self, node): + return self.FusedBatchNorm(node) + def DepthwiseConv2dNative(self, node): input = self.graph.get_node(node.layer.input[0], copy=True) kernel = self.graph.get_node(node.layer.input[1], copy=True) -- GitLab