提交 c57ae16c 编写于 作者: S SunAhong1993

replace program with paddle_graph

上级 34d61263
__version__ = "0.7.4"
from .core.program import PaddleProgram
from .core.program import PaddleGraph
program = PaddleProgram()
program = PaddleGraph()
name_counter = dict()
......
......@@ -18,15 +18,15 @@ import paddle.fluid as fluid
from paddle.fluid.proto import framework_pb2
from collections import OrderedDict
import numpy
import time
import collections
import sys
import os
import six
import pickle
class PaddleLayer(object):
def __init__(self, kernel, inputs, outputs, **kwargs):
def __init__(self, id, kernel, inputs, outputs, **kwargs):
assert isinstance(
inputs,
dict), "parameter 'inputs' for PaddleLayer should be type of dict"
......@@ -51,22 +51,28 @@ class PaddleLayer(object):
self.inputs = inputs
self.outputs = outputs
self.attrs = kwargs
self.id = str(time.time())
self.id = id
self.blocks = list()
def add_block(self, block):
block.father_layer = self
self.blocks.append(block)
class PaddleProgram(object):
def __init__(self):
class PaddleGraph(object):
def __init__(self, father_layer=None):
self.layers = OrderedDict()
self.edges_out = dict()
self.edges_in = dict()
self.inputs = list()
self.outputs = list()
self.parameters = dict()
self.father_layer = None
self.father_layer = father_layer
def set_name(self, name):
self.name = name
def set_parameters(self, parameters):
self.parameters = parameters
def clear(self):
self.layers = OrderedDict()
......@@ -76,15 +82,22 @@ class PaddleProgram(object):
self.outputs = list()
self.parameters = dict()
def clear_edges(self):
self.edges_out = dict()
self.edges_in = dict()
def add_layer(self, kernel, inputs, outputs, **kwargs):
layer = PaddleLayer(kernel, inputs, outputs, **kwargs)
layer_id = str(len(self.layers))
if self.father_layer is not None:
layer_id = "{}.{}.{}".format(layer_id, len(self.father_layer.blocks()), self.father_layer.id)
layer_id = "{}.{}.{}".format(self.father_layer.id,
len(self.father_layer.blocks),
layer_id)
layer = PaddleLayer(layer_id, kernel, inputs, outputs, **kwargs)
self.layers[layer_id] = layer
return layer_id
def build(self):
def build(self, inputs=None, outputs=None):
self.clear_edges()
outputs_from_nodes = dict()
for layer_id, layer in self.layers.items():
for input_key, input_var in layer.inputs.items():
......@@ -92,9 +105,16 @@ class PaddleProgram(object):
if not isinstance(vs, list):
vs = [vs]
for v in vs:
assert v in outputs_from_nodes, "Couldn't find {} in previous layers, the layers should be make by topological sort".format(
assert v in outputs_from_nodes or (
inputs is not None and v in list(inputs.values())
) or (
outputs is not None and v in outputs
), "Couldn't find {} in previous layers, the layers should be make by topological sort".format(
v)
if v in outputs_from_nodes:
in_layer_id = outputs_from_nodes[v]
else:
in_layer_id = -1
if in_layer_id not in self.edges_out:
self.edges_out[in_layer_id] = list()
self.edges_out[in_layer_id].append(layer_id)
......@@ -105,6 +125,23 @@ class PaddleProgram(object):
for output in layer.outputs:
outputs_from_nodes[output] = layer_id
if len(layer.blocks) > 0:
for block in layer.blocks:
block.build(layer.inputs, layer.outputs)
def get_global_layers(self):
# 该全局layers的信息是按住奥拓扑排序组成的
def update(layers):
global_layers = dict()
for layer_id, layer in layers.items():
global_layers[layer_id] = layer
for block in layer.blocks:
block_global_layers = update(block.layers)
global_layers.update(block_global_layers)
return global_layers
return update(self.layers)
def gen_code(self, code_dir):
def write_code(f, code_list, indent=0):
indent_blank = " " * indent
......@@ -227,3 +264,269 @@ class PaddleProgram(object):
fp.write(tensor_desc.SerializeToString())
param.tofile(fp)
fp.close()
def convert_prim(self, layer, indent=1):
def gen_lines(code_list, indent=0):
indent_blank = " " * indent
lines = []
for code_line in code_list:
if code_line.strip() == "":
lines.append('\n')
else:
lines.append(indent_blank + code_line + '\n')
return lines
if layer.kernel == "prim.if":
line = "if {} :".format(list(layer.inputs.values())[0])
self.forward_lines.extend(gen_lines([line], indent=indent))
block = layer.blocks[0]
b_init_lines, b_forward_lines = block.gen_dygraph_code(
indent=indent + 1)
self.init_lines.extend(b_init_lines)
self.forward_lines.extend(b_forward_lines)
block = layer.blocks[1]
if len(block.layers) > 0:
line = "else:"
self.forward_lines.extend(gen_lines([line], indent=indent))
b_init_lines, b_forward_lines = block.gen_dygraph_code(
indent=indent + 1)
self.init_lines.extend(b_init_lines)
self.forward_lines.extend(b_forward_lines)
return
elif layer.kernel == "prim.loop":
loop_range = list(layer.inputs.values())[0]
if list(layer.inputs.values())[0] is None:
loop_range = str(layer.attrs[list(layer.inputs.keys())[0]])
line = "for {} in range({}):".format(layer.outputs[1], loop_range)
self.forward_lines.extend(gen_lines([line], indent=indent))
block = layer.blocks[0]
b_init_lines, b_forward_lines = block.gen_dygraph_code(
indent=indent + 1)
self.init_lines.extend(b_init_lines)
self.forward_lines.extend(b_forward_lines)
return
elif layer.kernel == "prim.equal":
line = "{} = {}".format(layer.outputs[0],
list(layer.inputs.values())[0])
elif layer.kernel == "prim.constant":
line = "{} = {}".format(layer.outputs[0], layer.attrs["value"])
elif layer.kernel == "prim.list":
inputs_list = list(layer.inputs.values())
for i, input in enumerate(inputs_list):
if input is None:
inputs_list[i] = str(layer.attrs[list(layer.inputs.keys())[
i]])
inputs_str = ', '.join(inputs_list)
line = "{} = [{}]".format(layer.outputs[0], inputs_str)
elif layer.kernel == "prim.exception":
exception = list(layer.inputs.values())[0]
if list(layer.inputs.values())[0] is None:
exception = str(layer.attrs[list(layer.inputs.keys())[0]])
line = "raise RaiseException({})".format(exception)
elif layer.kernel == "prim.min":
line = "{} = min({})".format(layer.outputs[0],
list(layer.inputs.values())[0])
elif layer.kernel == "prim.add":
line = "{} = {} + {} * {}".format(layer.outputs[0],
list(layer.inputs.values())[0],
layer.attrs["alpha"],
list(layer.inputs.values())[1])
elif layer.kernel == "prim.append":
line = "{} = {}.append({})".format(layer.outputs[0],
list(layer.inputs.values())[0],
list(layer.inputs.values())[1])
elif layer.kernel == "prim.shape":
line = "{} = {}.shape".format(layer.outputs[0],
list(layer.inputs.values())[0])
elif layer.kernel == "prim.len":
line = "{} = len({})".format(layer.outputs[0],
list(layer.inputs.values())[0])
elif layer.kernel == "prim.eq":
line = "{} = {} == {}".format(layer.outputs[0],
list(layer.inputs.values())[0],
list(layer.inputs.values())[1])
elif layer.kernel == "prim.assert":
if layer.attrs["type"] == "eq":
if isinstance(layer.attrs["value"], list):
s = ""
for v in layer.attrs["value"]:
s += "{} == {} or ".format(layer.attrs["key"], v)
if len(s) > 0:
s = s[:-4]
line = "assert {}, \'The {} must be {}!\'".format(
s, layer.attrs["key"], layer.attrs["value"])
else:
line = "assert {} == {}, \'The {} must be {}!\'".format(
layer.attrs["key"], layer.attrs["value"],
layer.attrs["key"], layer.attrs["value"])
else:
raise Exception("Not implement yet!")
elif layer.kernel == "prim.getitem":
item0 = list(layer.inputs.values())[0]
if list(layer.inputs.values())[0] is None:
item0 = str(layer.attrs[list(layer.inputs.keys())[0]])
item1 = list(layer.inputs.values())[1]
if list(layer.inputs.values())[1] is None:
item1 = str(layer.attrs[list(layer.inputs.keys())[1]])
line = "{} = {}[{}]".format(layer.outputs[0], item0, item1)
elif layer.kernel == "prim.le":
item0 = list(layer.inputs.values())[0]
if list(layer.inputs.values())[0] is None:
item0 = str(layer.attrs[list(layer.inputs.keys())[0]])
item1 = list(layer.inputs.values())[1]
if list(layer.inputs.values())[1] is None:
item1 = str(layer.attrs[list(layer.inputs.keys())[1]])
line = "{} = {} < {}".format(layer.outputs[0], item0, item1)
elif layer.kernel == "prim.slice":
attrs_str = ""
for k, v in layer.attrs.items():
attrs_str += "{}:".format(v)
attrs_str = attrs_str[:-1]
line = "{} = {}[{}]".format(layer.outputs[0],
list(layer.inputs.values())[0],
attrs_str)
self.forward_lines.extend(gen_lines([line], indent=indent))
return
def get_dygraph_inputs(self, layers):
for layer_id, layer in layers.items():
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0:
continue
if layer.kernel == "fluid.dygraph.base.to_variable":
value = layer.attrs["value"]
if not value.startswith("params["):
self.inputs.append(value)
if len(layer.blocks) > 0:
for block in layer.blocks:
block.get_dygraph_inputs(block.layers)
self.inputs.extend(block.inputs)
def get_dygraph_outputs(self, layers):
for layer_id, layer in layers.items():
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0:
continue
if self.edges_out.get(layer_id, 0) == 0:
for output_name in layer.outputs:
if output_name.endswith(
"_assert") or not output_name.startswith("x"):
continue
self.outputs.append(output_name)
def gen_dygraph_code(self, code_dir=None, indent=2):
def gen_lines(code_list, indent=0):
indent_blank = " " * indent
lines = []
for code_line in code_list:
if code_line.strip() == "":
lines.append('\n')
else:
lines.append(indent_blank + code_line + '\n')
return lines
self.init_lines = []
# forward_func
self.forward_lines = []
# def gen_head
if indent == 2 and code_dir is not None:
start_lines = gen_lines(
[
"from paddle.fluid.initializer import Constant",
"from paddle.fluid.param_attr import ParamAttr",
"import paddle.fluid as fluid",
"",
"class {}(fluid.dygraph.Layer):".format(self.name),
],
indent=0)
self.get_dygraph_inputs(self.layers)
input_data_name = ', '.join(self.inputs)
self.init_lines.extend(
gen_lines(
["def __init__(self, params):"], indent=1))
self.init_lines.extend(
gen_lines(
["super({}, self).__init__()".format(self.name)], indent=2))
self.forward_lines.extend(
gen_lines(
["def forward(self, {}):".format(input_data_name)],
indent=1))
for layer_id, layer in self.layers.items():
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0:
continue
if "dygraph" in layer.kernel:
line = "{}".format(
layer.outputs[0]
) if layer.kernel == "fluid.dygraph.base.to_variable" and not layer.attrs[
"value"].startswith("params[") else "self.{}".format(
layer.outputs[0])
line += " = {}(".format(layer.kernel)
for k, v in layer.attrs.items():
line += "{}={}, ".format(k, v)
line = line.strip(", ")
line += ")"
if layer.kernel == "fluid.dygraph.base.to_variable" and not layer.attrs[
"value"].startswith("params["):
self.forward_lines.extend(gen_lines([line], indent=indent))
continue
else:
self.init_lines.extend(gen_lines([line], indent=2))
if len(layer.outputs) == 1:
line = layer.outputs[0]
elif len(layer.outputs) == 2:
line = layer.outputs[1]
else:
line = ','.join(layer.outputs[1:])
if layer.kernel == "fluid.dygraph.base.to_variable" and layer.attrs[
"value"].startswith("params["):
line += " = self.{}".format(layer.outputs[0])
else:
line += " = self.{}(".format(layer.outputs[0])
for k, v in layer.inputs.items():
line += "{}, ".format(v)
line = line.strip(", ")
line += ")"
self.forward_lines.extend(gen_lines([line], indent=indent))
elif "prim" in layer.kernel:
self.convert_prim(layer, indent=indent)
else:
if len(layer.outputs) == 1:
line = layer.outputs[0]
else:
line = ','.join(layer.outputs)
line += " = {}(".format(layer.kernel)
for k, v in layer.inputs.items():
line += "{}={}, ".format(k, v)
for k, v in layer.attrs.items():
line += "{}={}, ".format(k, v)
line = line.strip(", ")
line += ")"
self.forward_lines.extend(gen_lines([line], indent=indent))
if indent == 2:
f = open(os.path.join(code_dir, 'code.py'), 'w')
for line in start_lines:
f.write(line)
init_writen_line = []
for line in self.init_lines:
if line in init_writen_line:
continue
f.write(line)
init_writen_line.append(line)
f.write("\n")
self.get_dygraph_outputs(self.layers)
return_line = "return {}".format(", ".join(self.outputs))
self.forward_lines.extend(gen_lines([return_line], indent=2))
for line in self.forward_lines:
f.write(line)
f.close()
else:
return self.init_lines, self.forward_lines
def dump_dygraph_parameter(self, code_dir):
params_output = open(os.path.join(code_dir, 'model.pdparams'), 'wb')
pickle.dump(self.parameters, params_output)
params_output.close()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册