From 32289c7471c09ccab187b340b4f20bc1a01e368e Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Mon, 26 Aug 2019 16:41:02 +0800 Subject: [PATCH] add resize op --- x2paddle/op_mapper/tf_op_mapper.py | 60 +++++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/x2paddle/op_mapper/tf_op_mapper.py b/x2paddle/op_mapper/tf_op_mapper.py index ba33d6d..87dfccf 100644 --- a/x2paddle/op_mapper/tf_op_mapper.py +++ b/x2paddle/op_mapper/tf_op_mapper.py @@ -17,6 +17,7 @@ from x2paddle.core.op_mapper import OpMapper from x2paddle.core.util import * import inspect import numpy +import sys # compute padding size for SAME mode @@ -83,18 +84,31 @@ class TFOpMapper(OpMapper): del self.graph.input_nodes[idx] print("Total nodes: {}".format(len(self.graph.topo_sort))) + unsupported_ops = set() for node_name in self.graph.topo_sort: node = self.graph.get_node(node_name) op = node.layer_type if op in self.directly_map_ops: + if len(unsupported_ops) > 0: + continue self.directly_map(node) elif op in self.elementwise_ops: + if len(unsupported_ops) > 0: + continue self.elementwise_map(node) elif hasattr(self, op): + if len(unsupported_ops) > 0: + continue func = getattr(self, op) func(node) else: - raise Exception("OP: [{}] not support yet".format(op)) + unsupported_ops.add(op) + if len(unsupported_ops) > 0: + print("=========={} Ops are not supported yet======".format( + len(unsupported_ops))) + for op in unsupported_ops: + print("========== {} ==========".format(op)) + sys.exit(-1) def directly_map(self, node): assert node.layer_type in self.directly_map_ops @@ -773,7 +787,15 @@ class TFOpMapper(OpMapper): begin = [begin[i] for i in [0, 3, 1, 2]] end = [end[i] for i in [0, 3, 1, 2]] - attr = {"axes": range(len(strides)), "starts": begin, "ends": end} + for i in range(len(end)): + if end[i] == 0: + end[i] = 999999 + + attr = { + "axes": [i for i in range(len(strides))], + "starts": begin, + "ends": end + } node.fluid_code.add_layer("slice", inputs=input, output=node, @@ -955,3 +977,37 @@ class TFOpMapper(OpMapper): inputs=input, output=node, param_attr=attr) + + def ResizeNearestNeighbor(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + resize_shape = self.graph.get_node(node.layer.input[1], copy=True) + self.omit_nodes.append(resize_shape.layer_name) + if resize_shape.layer_type == "Const": + resize_shape = resize_shape.value.tolist() + else: + resize_shape = self.decoder.infer_shape_tensor(resize_shape) + align_corners = node.get_attr("align_corners") + attr = {"align_corners": align_corners, "out_shape": resize_shape} + node.fluid_code.add_layer("resize_nearest", + inputs=input, + output=node, + param_attr=attr) + + def ResizeBilinear(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + resize_shape = self.graph.get_node(node.layer.input[1], copy=True) + self.omit_nodes.append(resize_shape.layer_name) + if resize_shape.layer_type == "Const": + resize_shape = resize_shape.value.tolist() + else: + resize_shape = self.decoder.infer_shape_tensor(resize_shape) + align_corners = node.get_attr("align_corners") + attr = { + "align_corners": align_corners, + "out_shape": resize_shape, + "align_mode": 1 + } + node.fluid_code.add_layer("resize_bilinear", + inputs=input, + output=node, + param_attr=attr) -- GitLab