diff --git a/x2paddle/core/graph.py b/x2paddle/core/graph.py index 97a10ee180d453ce7d8a39d7e742e17d9b922e26..b0b7a194189b7e818873ccd486046915c05622ea 100644 --- a/x2paddle/core/graph.py +++ b/x2paddle/core/graph.py @@ -110,7 +110,6 @@ class Graph(object): del self.node_map[input].inputs[idx] del self.node_map[node_name] - print("remove topo", node_name) idx = self.topo_sort.index(node_name) del self.topo_sort[idx] diff --git a/x2paddle/decoder/tf_decoder.py b/x2paddle/decoder/tf_decoder.py index f0c59dfbf522ba88f4dd7df7b398be7ed4c79212..7956508e1714092bcb09bc80301c7868f0433fb1 100644 --- a/x2paddle/decoder/tf_decoder.py +++ b/x2paddle/decoder/tf_decoder.py @@ -25,16 +25,18 @@ import sys class TFGraphNode(GraphNode): def __init__(self, layer, layer_name=None): if layer_name is None: - super(TFGraphNode, self).__init__(layer, - layer.name.replace('/', '_')) + super(TFGraphNode, + self).__init__(layer, + layer.name.replace('/', '_').replace('-', '_')) else: - super(TFGraphNode, self).__init__(layer, - layer_name.replace('/', '_')) + super(TFGraphNode, + self).__init__(layer, + layer_name.replace('/', '_').replace('-', '_')) self.layer_type = layer.op self.fluid_code = FluidCode() - self.dtype_map = {1: "float32", 3: "int32", 9: "int64"} + self.dtype_map = {1: "float32", 3: "int32", 4: "int8", 9: "int64"} @property def out_shapes(self): @@ -89,11 +91,12 @@ class TFGraph(Graph): def build(self): for layer in self.model.node: - self.node_map[layer.name.replace('/', '_')] = TFGraphNode(layer) + self.node_map[layer.name.replace('/', '_').replace( + '-', '_')] = TFGraphNode(layer) for layer_name, node in self.node_map.items(): for in_node in node.layer.input: - in_node = in_node.replace('/', '_') + in_node = in_node.replace('/', '_').replace('-', '_') if in_node not in self.node_map: if in_node.strip().split(':')[0] in self.node_map: self.connect(in_node.strip().split(':')[0], layer_name) @@ -112,7 +115,7 @@ class TFGraph(Graph): def get_node(self, node_name, copy=False): items = node_name.strip().split(':') - items[0] = items[0].replace('/', '_') + items[0] = items[0].replace('/', '_').replace('-', '_') if items[0] in self.identity_map: items[0] = self.identity_map[items[0]] new_node_name = ":".join(items) @@ -163,11 +166,12 @@ def check_input_shape(graph_def): continue graph_node = TFGraphNode(layer) dtype = graph_node.dtype - # print("shape:", graph_node.out_shapes) if not graph_node.get_attr("shape"): - sys.stderr.write("Unknown shape for input tensor[{}]\n".format( - layer.name)) - shape = input("Please define shape of input here: ") + sys.stderr.write( + "\nUnknown shape for input tensor[tensor name: \"{}\"]\n". + format(layer.name)) + shape = input( + "Please define shape of input here(e.g. None,224,224,3): ") shape = [ None if dim == "None" else int(dim) for dim in shape.strip().split(',') @@ -187,6 +191,7 @@ class TFDecoder(object): graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) input_map = check_input_shape(graph_def) + self._fix_output_shape(graph_def) sess.graph.as_default() tf.import_graph_def(graph_def, name='', input_map=input_map) @@ -194,3 +199,9 @@ class TFDecoder(object): self.tf_graph = TFGraph(sess.graph._as_graph_def(add_shapes=True)[0]) self.tf_graph.build() + + def _fix_output_shape(self, graph): + for i in range(len(graph.node)): + node = graph.node[i] + if node.op == "swish_f32": + graph.node[i].attr['_disable_call_shape_inference'].b = False diff --git a/x2paddle/op_mapper/tf_op_mapper.py b/x2paddle/op_mapper/tf_op_mapper.py index a0093776d884d1f73530098a3c136deeeb73f7d6..23e4cf9026ec4aafeb2253c90555495490e6970d 100644 --- a/x2paddle/op_mapper/tf_op_mapper.py +++ b/x2paddle/op_mapper/tf_op_mapper.py @@ -54,7 +54,8 @@ class TFOpMapper(OpMapper): attr = { 'dtype': string(dtype), 'shape': shape, - 'name': string(node.layer_name) + 'name': string(node.layer_name), + 'append_batch_size': False } node.fluid_code.add_layer("data", inputs=None, @@ -350,6 +351,7 @@ class TFOpMapper(OpMapper): param = self.graph.get_node(node.layer.input[1], copy=True) if param.layer_type == "Const": attr = {"shape": param.value.tolist()} + self.omit_nodes.append(param.layer_name) else: # Here is a trick method to solove tensor parameter in tensorflow assert len(param.out_shapes[0] @@ -425,3 +427,206 @@ class TFOpMapper(OpMapper): inputs=input, output=node, param_attr=None) + + def Sigmoid(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + node.fluid_code.add_layer("sigmoid", + inputs=input, + output=node, + param_attr=None) + + def Maximum(self, node): + x = self.graph.get_node(node.layer.input[0], copy=True) + y = self.graph.get_node(node.layer.input[1], copy=True) + inputs = {"x": x, "y": y} + node.fluid_code.add_layer("elementwise_max", + inputs=inputs, + output=node, + param_attr=None) + + def SplitV(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + num_sections = self.graph.get_node(node.layer.input[1], copy=True) + dim = self.graph.get_node(node.layer.input[2], copy=True) + assert num_sections.layer_type == "Const" + assert dim.layer_type == "Const" + self.omit_nodes.append(num_sections.layer_name) + self.omit_nodes.append(dim.layer_name) + attr = { + "num_or_sections": num_sections.value.tolist(), + "dim": dim.value + } + node.fluid_code.add_layer("split", + inputs=input, + output=node, + param_attr=attr) + + def Exp(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + node.fluid_code.add_layer("exp", + inputs=input, + output=node, + param_attr=None) + + def ConcatV2(self, node): + inputs = [ + self.graph.get_node(name, copy=True) + for name in node.layer.input[:-1] + ] + axis = self.graph.get_node(node.layer.input[-1], copy=True) + assert axis.layer_type == "Const" + self.omit_nodes.append(axis.layer_name) + attr = {"axis": axis.value} + node.fluid_code.add_layer("concat", + inputs=inputs, + output=node, + param_attr=attr) + + def Tile(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + expand_times = self.graph.get_node(node.layer.input[1], copy=True) + assert expand_times.layer_type == "Const" + self.omit_nodes.append(expand_times.layer_name) + attr = {"expand_times": expand_times.value.tolist()} + node.fluid_code.add_layer("expand", + inputs=input, + output=node, + param_attr=attr) + + def Pack(self, node): + inputs = [ + self.graph.get_node(name, copy=True) for name in node.layer.input + ] + node.fluid_code.add_layer("stack", + inputs=inputs, + output=node, + param_attr=None) + + def Pad(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + paddings = self.graph.get_node(Node.layer.input[1], copy=True) + assert paddings.layer_type == "Const", "Padding should be Const" + self.omit_nodes.append(paddings.layer_name) + attr = {"paddings": paddings.value.tolist()} + node.fluid_code.add_layer("pad", + inputs=input, + output=node, + param_attr=attr) + +# def ResizeNearestNeighbor(self, node): +# pass + + def Range(self, node): + start = self.graph.get_node(node.layer.input[0], copy=True) + limit = self.graph.get_node(node.layer.input[1], copy=True) + delta = self.graph.get_node(node.layer.input[2], copy=True) + if start.layer_type == "Const": + self.omit_nodes.append(start.layer_name) + start = start.value + if limit.layer_type == "Const": + self.omit_nodes.append(limit.layer_name) + limit = limit.value + if delta.layer_type == "Const": + self.omit_nodes.append(delta.layer_name) + delta = delta.value + inputs = {"start": start, "end": limit, "step": delta} + attr = {"dtype": string(node.dtype)} + node.fluid_code.append("range", + inputs=inputs, + output=node, + param_attr=None) + + +# def Fill(self, node): +# shape = self.graph.get_node(node.layer + + def Mul(self, node): + x = self.graph.get_node(node.layer.input[0], copy=True) + y = self.graph.get_node(node.layer.input[1], copy=True) + inputs = {"x": x, "y": y} + node.fluid_code.add_layer("elementwise_mul", + inputs=inputs, + output=node, + param_attr=None) + + def Sub(self, node): + x = self.graph.get_node(node.layer.input[0], copy=True) + y = self.graph.get_node(node.layer.input[1], copy=True) + inputs = {"x": x, "y": y} + node.fluid_code.add_layer("elementwise_sub", + inputs=inputs, + output=node, + param_attr=None) + + def Rsqrt(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + node.fluid_code.add_layer("rsqrt", + inputs=input, + output=node, + param_attr=None) + + def swish_f32(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + node.fluid_code.add_layer("sigmoid", + inputs=input, + output=node, + param_attr=None) + inputs = {"x": input, "y": node} + node.fluid_code.add_layer("elementwise_mul", + inputs=inputs, + output=node, + param_attr=None) + + def Mean(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + reduce_idx = self.graph.get_node(node.layer.input[1], copy=True) + assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]" + keep_dims = node.get_attr("keep_dims") + attr = {"dim": reduce_idx.value.tolist(), "keep_dim": keep_dims} + node.fluid_code.add_layer("reduce_mean", + inputs=input, + output=node, + param_attr=attr) + + def MatMul(self, node): + x = self.graph.get_node(node.layer.input[0], copy=True) + y = self.graph.get_node(node.layer.input[1], copy=True) + transpose_a = node.get_attr('transpose_a') + transpose_b = node.get_attr('transpose_b') + inputs = {"x": x, "y": y} + attr = {"transpose_x": transpose_a, "transpose_y": transpose_b} + node.fluid_code.add_layer("matmul", + inputs=inputs, + output=node, + param_attr=attr) + + def ArgMax(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + axis = self.graph.get_node(node.layer.input[1], copy=True) + assert axis.layer_type == "Const", "ArgMax only support Const parameter" + self.omit_nodes.append(axis.layer_name) + attr = {"axis": axis.value} + node.fluid_code.add_layer("argmax", + inputs=input, + output=node, + param_attr=attr) + + def StridedSlice(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + begin = self.graph.get_node(node.layer.input[1], copy=True) + end = self.graph.get_node(node.layer.input[2], copy=True) + strides = self.graph.get_node(node.layer.input[3], copy=True) + assert begin.layer_type == "Const" + assert end.layer_type == "Const" + assert strides.layer_type == "Const" + self.omit_nodes.append(begin.layer_name) + self.omit_nodes.append(end.layer_name) + self.omit_nodes.append(strides.layer_name) + strides = strides.value.tolist() + assert len(set(strides)) == 1 and strides[0] == 1 + + attr = {"starts": begin.value.tolist(), "ends": end.value.tolist()} + node.fluid_code.add_layer("slice", + inputs=input, + output=node, + param_attr=attr)