提交 9b286963 编写于 作者: J jiangjiajun

test codes

上级 4c95a4cb
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Emitter(object):
def __init__(self):
print("Nothing done")
def save_inference_model(self):
print("Not Implement")
def save_python_code(self):
print("Not Implement")
...@@ -21,10 +21,8 @@ class GraphNode(object): ...@@ -21,10 +21,8 @@ class GraphNode(object):
self.outputs = list() self.outputs = list()
self.layer = layer self.layer = layer
if layer_name is not None: assert layer_name is not None, "layer_name for GraphNode should not be None"
self.layer_name = layer_name self.layer_name = layer_name
else:
self.layer_name = layer.name
def __hash__(self): def __hash__(self):
return hash(self.layer.name) return hash(self.layer.name)
...@@ -70,6 +68,8 @@ class Graph(object): ...@@ -70,6 +68,8 @@ class Graph(object):
num_inputs[node.layer_name] -= 1 num_inputs[node.layer_name] -= 1
if num_inputs[node.layer_name] == 0: if num_inputs[node.layer_name] == 0:
self.topo_sort.append(node.layer_name) self.topo_sort.append(node.layer_name)
for i, tmp in enumerate(self.topo_sort):
print(tmp)
def get_node(self, name): def get_node(self, name):
if name not in self.node_map: if name not in self.node_map:
......
...@@ -11,3 +11,7 @@ ...@@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from x2paddle.parser import TFGraph
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO useless node remove
# TODO identity node remove
# TODO subgraph optimize
# TODO compute optimize
...@@ -19,6 +19,9 @@ import copy ...@@ -19,6 +19,9 @@ import copy
class TFGraphNode(GraphNode): class TFGraphNode(GraphNode):
def __init__(self, layer, layer_name=None): def __init__(self, layer, layer_name=None):
if layer_name is None:
super(TFGraphNode, self).__init__(layer, layer.name)
else:
super(TFGraphNode, self).__init__(layer, layer_name) super(TFGraphNode, self).__init__(layer, layer_name)
self.layer_type = layer.op self.layer_type = layer.op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册