提交 3cf5a607 编写于 作者: J jiangjiajun

{add codes

上级 0e572fc5
...@@ -32,8 +32,35 @@ class Layer(object): ...@@ -32,8 +32,35 @@ class Layer(object):
layer_code = layer_code + "fluid.layers." + self.op + "(" layer_code = layer_code + "fluid.layers." + self.op + "("
for key, tensor in self.inputs.items(): if isinstance(self.inputs, list):
layer_code = layer_code + key + "={}, ".format(tensor) in_list = "["
for input in self.inputs:
assert isinstance(
input, GraphNode), "Type of input should be GraphNode"
if hasattr(input, "index"):
in_list += (input.layer_name + "[{}]".format(input.index) +
", ")
else:
in_list += (input.layer_name + ", ")
inlist = in_list.strip(", ") + "], "
elif isinstance(self.inputs, dict):
for key, input in self.inputs.items():
assert isinstance(
input, GraphNode), "Type of input should be GraphNode"
if hasattr(input, "index"):
layer_code = layer_code + key + "={}, ".format(
input.layer_name + "[{}]".format(input.index))
else:
layer_code = layer_code + key + "={}, ".format(
input.layer_name)
elif isinstance(self.inputs, GraphNode):
if hasattr(self.inputs, "index"):
layer_code += (self.inputs.layer_name +
"[{}]".format(self.inputs.index) + ", ")
else:
layer_code += (self.inputs.layer_name + ", ")
else:
raise Exception("Unknown type of inputs.")
for key, value in self.param_attr.items(): for key, value in self.param_attr.items():
layer_code = layer_code + key + "={}, ".format(value) layer_code = layer_code + key + "={}, ".format(value)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import collections import collections
from copy import deepcopy
class GraphNode(object): class GraphNode(object):
...@@ -72,16 +73,24 @@ class Graph(object): ...@@ -72,16 +73,24 @@ class Graph(object):
self.topo_sort.append(node) self.topo_sort.append(node)
idx += 1 idx += 1
def get_node(self, name): def get_node(self, name, copy=False):
if name not in self.node_map: if name not in self.node_map:
if name.split(':')[0] in self.node_map: if name.split(':')[0] in self.node_map:
name_prefix, idx = name.split(':') name_prefix, idx = name.split(':')
self.node_map[name_prefix].index = int(idx) if copy:
return self.node_map[name_prefix] node = deepcopy(self.node_map[name_prefix])
else:
node = self.node_map[name_prefix]
node.index = int(idx)
return node
else: else:
raise Exception("Graph doesn't have node [%s]." % name) raise Exception("Graph doesn't have node [%s]." % name)
else: else:
return self.node_map[name] if copy:
node = deepcopy(self.node_map[name])
else:
node = self.node_map[name]
return node
def connect(self, src, dst): def connect(self, src, dst):
if dst not in self.node_map: if dst not in self.node_map:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册