提交 142f9bbf 编写于 作者: J jiangjiajun

test code

上级 9b286963
......@@ -17,18 +17,46 @@ class Layer(object):
def __init__(self):
self.op = None
self.param_attr = dict()
self.input = None
self.inputs = dict()
self.output = None
self.str_code = None
def get_code(self):
if self.str_code is not None:
return self.str_code
layer_code = ""
if self.output is not None:
layer_code = self.output + " = "
layer_code = layer_code + "fluid.layers." + self.op + "("
for key, tensor in self.inputs.items():
layer_code = layer_code + key + "=" + tensor + ", "
for key, value in self.param_attr.items():
layer_code = layer_code + key + "=" + value + ", "
layer_code = layer_code.strip(", ")
return layer_code += ")"
class FluidCode(object):
def __init__(self):
self.codes = list()
self.layers = list()
def add_layer(self, op, input, output, param_attr=None):
def add_layer(self, op, inputs, output, param_attr=None):
layer = Layer()
layer.op = op
layer.inputs = inputs
layer.output = output
if param_attr is not None:
layer.param_attr = param_attr
self.layers.append(layer)
def add_note(self, note):
# note should be string
self.layers.append(note)
def gen_codes(self):
codes = list()
for layer in self.layers:
if isinstance(layer, Layer):
codes.append(layer.get_code())
elif isinstance(layer, str):
codes.append(layer)
......@@ -62,14 +62,16 @@ class Graph(object):
num_inputs[name] = len(node.inputs)
self.topo_sort = self.input_nodes[:]
for idx in range(len(self.topo_sort)):
while idx < len(self.topo_sort):
current_node = self.node_map[self.topo_sort[idx]]
for node in current_node.outputs:
num_inputs[node.layer_name] -= 1
if num_inputs[node.layer_name] == 0:
self.topo_sort.append(node.layer_name)
idx += 1
for i, tmp in enumerate(self.topo_sort):
print(tmp)
print(tmp, self.node_map[tmp].layer_type, self.node_map[tmp].inputs)
def get_node(self, name):
if name not in self.node_map:
......
......@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.parser import TFGraph
from x2paddle.parser.tf_parser import TFGraph
from x2paddle.core.emitter import Emitter
class TFEmitter(Emitter):
def __init__(self):
super(TFEmitter, self
......@@ -13,6 +13,18 @@
# limitations under the License.
# TODO useless node remove
from x2paddle.parser.tf_parser import TFGraph
class TFGraphOptimizer(object):
def __init__(self):
print("Not Implement")
self.useless_op = [
'NoOp']
def remove_useless_node(self, graph):
for name, node in graph.node_map.items():
if node.layer_type in self.useless_op:
# TODO identity node remove
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册