From 144fa8326060296f1732b1edfffb6d91d6e049e1 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Mon, 4 Nov 2019 19:33:05 +0800 Subject: [PATCH] -aadd mirrorpad --- x2paddle/op_mapper/tf_op_mapper.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/x2paddle/op_mapper/tf_op_mapper.py b/x2paddle/op_mapper/tf_op_mapper.py index b36635c..88fb339 100644 --- a/x2paddle/op_mapper/tf_op_mapper.py +++ b/x2paddle/op_mapper/tf_op_mapper.py @@ -762,6 +762,29 @@ class TFOpMapper(OpMapper): output=node, param_attr=attr) + def MirrorPad(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + paddings = self.graph.get_node(node.layer.input[1], copy=True) + assert paddings.layer_type == "Const", "Padding should be Const" + self.add_omit_nodes(paddings.layer_name, node.layer_name) + paddings = paddings.value.flatten().tolist() + mode = node.get_attr("mode").decode() + assert mode == "REFLECT", "Only support 'REFLECT` mode in MirrorPad" + if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4: + paddings = [paddings[i] for i in [0, 1, 6, 7, 2, 3, 4, 5]] + + pad_op = "pad" + if len(input.out_shapes[0]) == 4: + if paddings[0] + paddings[1] + paddings[2] + paddings[3] == 0: + paddings = paddings[4:] + pad_op = "pad2d" + attr = {"paddings": paddings, "mode": string("reflect")} + node.fluid_code.add_layer(pad_op, + inputs=input, + output=node, + param_attr=attr) + + def Range(self, node): start = self.graph.get_node(node.layer.input[0], copy=True) limit = self.graph.get_node(node.layer.input[1], copy=True) -- GitLab