diff --git a/x2paddle/op_mapper/tf_op_mapper.py b/x2paddle/op_mapper/tf_op_mapper.py index a2e8ec0e78c20b1c8d1028a3b558e1dfc97ebe11..97a82581b4a65d12fb80cb2f8c14d660abcfa1dc 100644 --- a/x2paddle/op_mapper/tf_op_mapper.py +++ b/x2paddle/op_mapper/tf_op_mapper.py @@ -877,12 +877,12 @@ class TFOpMapper(OpMapper): num_or_sections=num_split, dim=dim) - def Slice(self, node): +def Slice(self, node): input = self.graph.get_node(node.layer.input[0]) begin = self.graph.get_node(node.layer.input[1]) size = self.graph.get_node(node.layer.input[2]) - inputs = {"input": input.name} + inputs = {"x": input.name} attrs = {} if begin.layer_type == "Const": begin = begin.value.tolist() @@ -890,7 +890,7 @@ class TFOpMapper(OpMapper): else: # shape = begin.out_shapes[0] # reshape_name = gen_name("slice", "reshape") - # program.add_layer( + # self.paddle_graph.add_layer( # kernel="fluid.layers.reshape", # inputs={"x": begin.name}, # outputs=[reshape_name], @@ -901,32 +901,20 @@ class TFOpMapper(OpMapper): if size.layer_type == "Const": size = size.value.tolist() attrs['shape'] = size - shape = size else: shape = size.out_shapes[0] -# reshape_name = gen_name("slice", "reshape") -# program.add_layer( -# kernel="fluid.layers.reshape", -# inputs={"x": size.name}, -# outputs=[reshape_name], -# shape=shape) -# inputs['shape'] = reshape_name - - for i, s in enumerate(shape): - if s < 0: - shape[i] = 32767 - program.add_layer( - kernel="fluid.layers.slice", + reshape_name = gen_name("slice", "reshape") + self.paddle_graph.add_layer( + kernel="fluid.layers.reshape", + inputs={"x": size.name}, + outputs=[reshape_name], + shape=shape) + inputs['shape'] = reshape_name + self.paddle_graph.add_layer( + kernel="fluid.layers.crop_tensor", inputs=inputs, outputs=[node.name], - axes=list(range(len(attrs['offsets']))), - starts=attrs['offsets'], - ends=[attrs['offsets'][i] + shape[i] for i in range(len(shape))]) -# program.add_layer( -# kernel="fluid.layers.crop_tensor", -# inputs=inputs, -# outputs=[node.name], -# **attrs) + **attrs) def ResizeNearestNeighbor(self, node): input = self.graph.get_node(node.layer.input[0])