From aa98edf9bb9f37115ff6f6c3218846fc99d17d90 Mon Sep 17 00:00:00 2001 From: mamingjie-China Date: Tue, 4 Aug 2020 14:46:41 +0800 Subject: [PATCH] update --- x2paddle/op_mapper/tf_op_mapper_nhwc.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/x2paddle/op_mapper/tf_op_mapper_nhwc.py b/x2paddle/op_mapper/tf_op_mapper_nhwc.py index 8ea2a45..5e1a54b 100644 --- a/x2paddle/op_mapper/tf_op_mapper_nhwc.py +++ b/x2paddle/op_mapper/tf_op_mapper_nhwc.py @@ -199,17 +199,21 @@ class TFOpMapperNHWC(OpMapper): def Fill(self, node): dims = self.graph.get_node(node.layer.input[0]) input_value = self.graph.get_node(node.layer.input[1]) + inputs = dict() + attr = dict() assert input_value.layer_type == "Const", "Value of fill OP should be Const" + if dims.layer_type == "Const": + attr["shape"] = dims.value.tolist() + else: + inputs["shape"] = dims.name + attr["dtype"] = string(input_value.dtype) + attr["value"] = input_value.value - input_value = input_value.value - input_dtype = string(input_value.dtype) program.add_layer( "fluid.layers.fill_constant", - inputs={}, + inputs=inputs, outputs=[node.name], - shape=dims, - dtype=string(input_dtype), - value=input_value) + **attr) def DepthToSpace(self, node): input = self.graph.get_node(node.layer.input[0]) @@ -251,8 +255,8 @@ class TFOpMapperNHWC(OpMapper): shape=[0, c, h, w]) program.add_layer( - kernel="fluid.layers.pixed_shuffle", - inputs={"input": reshape_name}, + kernel="fluid.layers.pixel_shuffle", + inputs={"x": reshape_name}, outputs=[node.name], upscale_factor=block_size) @@ -309,9 +313,9 @@ class TFOpMapperNHWC(OpMapper): data_format = node.get_attr("data_format").decode() pad_mode = node.get_attr("padding").decode() if data_format == "NHWC": - n, c, h, w = input.out_shapes[0] - else: n, h, w, c = input.out_shapes[0] + else: + n, c, h, w = input.out_shapes[0] if kernel.layer_type == 'Const': kernel_value = kernel.value -- GitLab