提交 0e572fc5 编写于 作者: J jiangjiajun

add code

上级 51fcdabb
...@@ -60,6 +60,9 @@ class FluidCode(object): ...@@ -60,6 +60,9 @@ class FluidCode(object):
# note should be string # note should be string
self.layers.append(note) self.layers.append(note)
def clear(self):
self.layers = list()
def gen_codes(self): def gen_codes(self):
codes = list() codes = list()
for layer in self.layers: for layer in self.layers:
......
...@@ -74,7 +74,12 @@ class Graph(object): ...@@ -74,7 +74,12 @@ class Graph(object):
def get_node(self, name): def get_node(self, name):
if name not in self.node_map: if name not in self.node_map:
raise Exception("Graph doesn't have node [%s]." % name) if name.split(':')[0] in self.node_map:
name_prefix, idx = name.split(':')
self.node_map[name_prefix].index = int(idx)
return self.node_map[name_prefix]
else:
raise Exception("Graph doesn't have node [%s]." % name)
else: else:
return self.node_map[name] return self.node_map[name]
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from x2paddle.parser.tf_parser import TFGraph from x2paddle.parser.tf_parser import TFGraph
from x2paddle.core.emitter import Emitter from x2paddle.core.emitter import Emitter
from x2paddle.core.fluid_code import FluidCode from x2paddle.core.fluid_code import FluidCode
from x2paddle.core.util import *
class TFEmitter(Emitter): class TFEmitter(Emitter):
...@@ -22,7 +23,7 @@ class TFEmitter(Emitter): ...@@ -22,7 +23,7 @@ class TFEmitter(Emitter):
super(TFEmitter, self).__init__() super(TFEmitter, self).__init__()
self.parser = parser self.parser = parser
self.graph = parser.tf_graph self.graph = parser.tf_graph
self.fluid_code = FluidCode() self.weights = dict()
def run(self): def run(self):
print("Total nodes: {}".format(len(self.graph.topo_sort))) print("Total nodes: {}".format(len(self.graph.topo_sort)))
...@@ -33,16 +34,66 @@ class TFEmitter(Emitter): ...@@ -33,16 +34,66 @@ class TFEmitter(Emitter):
emit_func = getattr(self, op) emit_func = getattr(self, op)
emit_func(node) emit_func(node)
for i in range(len(self.graph.topo_sort)):
node_name = self.graph.topo_sort[i]
node = self.graph.get_node(node_name)
for layer in node.fluid_code.layers:
print(layer.get_code())
def Placeholder(self, node): def Placeholder(self, node):
shape = node.out_shapes[0] shape = node.out_shapes[0]
dtype = node.dtype dtype = node.dtype
attr = { attr = {
'dtype': '\{}\''.format(dtype), 'dtype': string(dtype),
'shape': shape, 'shape': shape,
'name': '\'{}\''.format(node.layer_name) 'name': string(node.layer_name)
} }
self.fluid_code.add_layer("data", node.fluid_code.add_layer("data",
inputs=inputs, inputs=None,
output=node,
param_attr=attr)
def Const(self, node):
shape = node.out_shapes[0]
dtype = node.dtype
value = node.value
initializer = "Constant(0.0)"
if len(shape) == 0:
assert value.size == 1, "Unexpected situation happend"
shape = [1]
initializer = "Constant({})".format(value)
attr = {
'dtype': string(dtype),
'shape': shape,
'name': string(node.layer_name),
'default_initializer': initializer
}
node.fluid_code.add_layer("create_parameter",
inputs=None,
output=node,
param_attr=attr)
def Transpose(self, node):
input = self.graph.get_node(node.layer.input[0])
perm = self.graph.get_node(node.layer.input[1])
perm.fluid_code.clear()
perm = perm.value.tolist()
attr = {'perm': perm}
node.fluid_code.add_layer("transpose",
inputs=input,
output=node, output=node,
param_attr=attr) param_attr=attr)
print(self.fluid_code.layers[0].get_code())
def RealDiv(self, node):
x = self.graph.get_node(node.layer.input[0])
y = self.graph.get_node(node.layer.input[1])
inputs = {'x': x, 'y': y}
node.fluid_code.add_layer("elementwise_div",
inputs=inputs,
output=node,
param_attr=None)
def Fc(self, node):
self.weight['asdf'] = np.tranpose(node.kerneln[1, 0])
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
from x2paddle.core.graph import GraphNode, Graph from x2paddle.core.graph import GraphNode, Graph
from x2paddle.core.fluid_code import FluidCode
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
import tensorflow as tf import tensorflow as tf
import copy import copy
...@@ -24,7 +26,9 @@ class TFGraphNode(GraphNode): ...@@ -24,7 +26,9 @@ class TFGraphNode(GraphNode):
super(TFGraphNode, self).__init__(layer, layer.name) super(TFGraphNode, self).__init__(layer, layer.name)
else: else:
super(TFGraphNode, self).__init__(layer, layer_name) super(TFGraphNode, self).__init__(layer, layer_name)
self.layer_type = layer.op self.layer_type = layer.op
self.fluid_code = FluidCode()
self.dtype_map = {1: "float32", 3: "int32", 9: "int64"} self.dtype_map = {1: "float32", 3: "int32", 9: "int64"}
...@@ -44,6 +48,14 @@ class TFGraphNode(GraphNode): ...@@ -44,6 +48,14 @@ class TFGraphNode(GraphNode):
raise Exception("Dtype[{}] not in dtype_map".format(dtype)) raise Exception("Dtype[{}] not in dtype_map".format(dtype))
return self.dtype_map[dtype] return self.dtype_map[dtype]
@property
def value(self):
assert self.layer_type == "Const", "Only Const node has value."
attr = self.layer.attr['value']
field = getattr(attr, attr.WhichOneof('value'))
return tensor_util.MakeNdarray(field)
class TFGraph(Graph): class TFGraph(Graph):
def __init__(self, model): def __init__(self, model):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册