diff --git a/x2paddle/core/fluid_code.py b/x2paddle/core/fluid_code.py index 3c86f42b24f352846358fe2f8b3842a198e134da..72cfd1acb5fcd90fe629b9a92e93b890449e8aad 100644 --- a/x2paddle/core/fluid_code.py +++ b/x2paddle/core/fluid_code.py @@ -60,6 +60,9 @@ class FluidCode(object): # note should be string self.layers.append(note) + def clear(self): + self.layers = list() + def gen_codes(self): codes = list() for layer in self.layers: diff --git a/x2paddle/core/graph.py b/x2paddle/core/graph.py index cdedf9e32a1346532a3c99536f62ff92457995cc..d4fec45e55b597f0a183733b528b708889474d5a 100644 --- a/x2paddle/core/graph.py +++ b/x2paddle/core/graph.py @@ -74,7 +74,12 @@ class Graph(object): def get_node(self, name): if name not in self.node_map: - raise Exception("Graph doesn't have node [%s]." % name) + if name.split(':')[0] in self.node_map: + name_prefix, idx = name.split(':') + self.node_map[name_prefix].index = int(idx) + return self.node_map[name_prefix] + else: + raise Exception("Graph doesn't have node [%s]." % name) else: return self.node_map[name] diff --git a/x2paddle/emitter/tf_emitter.py b/x2paddle/emitter/tf_emitter.py index 9597679b93a97102c75e8842386992bf789b9ee8..74077af2bc83c958f0be620a894eccecd7eca963 100644 --- a/x2paddle/emitter/tf_emitter.py +++ b/x2paddle/emitter/tf_emitter.py @@ -15,6 +15,7 @@ from x2paddle.parser.tf_parser import TFGraph from x2paddle.core.emitter import Emitter from x2paddle.core.fluid_code import FluidCode +from x2paddle.core.util import * class TFEmitter(Emitter): @@ -22,7 +23,7 @@ class TFEmitter(Emitter): super(TFEmitter, self).__init__() self.parser = parser self.graph = parser.tf_graph - self.fluid_code = FluidCode() + self.weights = dict() def run(self): print("Total nodes: {}".format(len(self.graph.topo_sort))) @@ -33,16 +34,66 @@ class TFEmitter(Emitter): emit_func = getattr(self, op) emit_func(node) + for i in range(len(self.graph.topo_sort)): + node_name = self.graph.topo_sort[i] + node = self.graph.get_node(node_name) + for layer in node.fluid_code.layers: + print(layer.get_code()) + def Placeholder(self, node): shape = node.out_shapes[0] dtype = node.dtype attr = { - 'dtype': '\{}\''.format(dtype), + 'dtype': string(dtype), 'shape': shape, - 'name': '\'{}\''.format(node.layer_name) + 'name': string(node.layer_name) } - self.fluid_code.add_layer("data", - inputs=inputs, + node.fluid_code.add_layer("data", + inputs=None, + output=node, + param_attr=attr) + + def Const(self, node): + shape = node.out_shapes[0] + dtype = node.dtype + value = node.value + initializer = "Constant(0.0)" + if len(shape) == 0: + assert value.size == 1, "Unexpected situation happend" + shape = [1] + initializer = "Constant({})".format(value) + + attr = { + 'dtype': string(dtype), + 'shape': shape, + 'name': string(node.layer_name), + 'default_initializer': initializer + } + node.fluid_code.add_layer("create_parameter", + inputs=None, + output=node, + param_attr=attr) + + def Transpose(self, node): + input = self.graph.get_node(node.layer.input[0]) + perm = self.graph.get_node(node.layer.input[1]) + perm.fluid_code.clear() + perm = perm.value.tolist() + + attr = {'perm': perm} + node.fluid_code.add_layer("transpose", + inputs=input, output=node, param_attr=attr) - print(self.fluid_code.layers[0].get_code()) + + def RealDiv(self, node): + x = self.graph.get_node(node.layer.input[0]) + y = self.graph.get_node(node.layer.input[1]) + inputs = {'x': x, 'y': y} + node.fluid_code.add_layer("elementwise_div", + inputs=inputs, + output=node, + param_attr=None) + + def Fc(self, node): + self.weight['asdf'] = np.tranpose(node.kerneln[1, 0]) diff --git a/x2paddle/parser/tf_parser.py b/x2paddle/parser/tf_parser.py index 55cf14ace55dd5d9528bf795652ce8fcdee8f84c..791af375238dd6a70db6636706d799e15ba56c3c 100644 --- a/x2paddle/parser/tf_parser.py +++ b/x2paddle/parser/tf_parser.py @@ -13,6 +13,8 @@ # limitations under the License. from x2paddle.core.graph import GraphNode, Graph +from x2paddle.core.fluid_code import FluidCode +from tensorflow.python.framework import tensor_util from tensorflow.python.platform import gfile import tensorflow as tf import copy @@ -24,7 +26,9 @@ class TFGraphNode(GraphNode): super(TFGraphNode, self).__init__(layer, layer.name) else: super(TFGraphNode, self).__init__(layer, layer_name) + self.layer_type = layer.op + self.fluid_code = FluidCode() self.dtype_map = {1: "float32", 3: "int32", 9: "int64"} @@ -44,6 +48,14 @@ class TFGraphNode(GraphNode): raise Exception("Dtype[{}] not in dtype_map".format(dtype)) return self.dtype_map[dtype] + @property + def value(self): + assert self.layer_type == "Const", "Only Const node has value." + + attr = self.layer.attr['value'] + field = getattr(attr, attr.WhichOneof('value')) + return tensor_util.MakeNdarray(field) + class TFGraph(Graph): def __init__(self, model):