提交 5bc026f0 编写于 作者: S SunAhong1993

update the program.py

上级 d8bb8920
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def convert_prim(layer, indent=1, init_func=[], forward_func=[]):
def gen_codes(code_list, indent=0):
indent_blank = " " * indent
codes = []
for code_line in code_list:
if code_line.strip() == "":
codes.append('\n')
else:
codes.append(indent_blank + code_line + '\n')
return codes
if layer.kernel == "prim.if":
line = "if {} :".format(list(layer.inputs.values())[0])
forward_func.extend(gen_codes([line], indent=indent))
block = layer.blocks[0]
b_init_lines, b_forward_lines = block.gen_dygraph_code(
indent=indent + 1)
init_func.extend(b_init_lines)
forward_func.extend(b_forward_lines)
block = layer.blocks[1]
if len(block.layers) > 0:
line = "else:"
forward_func.extend(gen_codes([line], indent=indent))
b_init_lines, b_forward_lines = block.gen_dygraph_code(
indent=indent + 1)
init_func.extend(b_init_lines)
forward_func.extend(b_forward_lines)
return
elif layer.kernel == "prim.loop":
loop_range = list(layer.inputs.values())[0]
if list(layer.inputs.values())[0] is None:
loop_range = str(layer.attrs[list(layer.inputs.keys())[0]])
line = "for {} in range({}):".format(layer.outputs[1], loop_range)
forward_func.extend(gen_codes([line], indent=indent))
block = layer.blocks[0]
b_init_lines, b_forward_lines = block.gen_dygraph_code(
indent=indent + 1)
init_func.extend(b_init_lines)
forward_func.extend(b_forward_lines)
return
elif layer.kernel == "prim.equal":
line = "{} = {}".format(layer.outputs[0],
list(layer.inputs.values())[0])
elif layer.kernel == "prim.constant":
line = "{} = {}".format(layer.outputs[0], layer.attrs["value"])
elif layer.kernel == "prim.list":
inputs_list = list(layer.inputs.values())
for i, input in enumerate(inputs_list):
if input is None:
inputs_list[i] = str(layer.attrs[list(layer.inputs.keys())[
i]])
inputs_str = ', '.join(inputs_list)
line = "{} = [{}]".format(layer.outputs[0], inputs_str)
elif layer.kernel == "prim.exception":
exception = list(layer.inputs.values())[0]
if list(layer.inputs.values())[0] is None:
exception = str(layer.attrs[list(layer.inputs.keys())[0]])
line = "raise RaiseException({})".format(exception)
elif layer.kernel == "prim.min":
line = "{} = min({})".format(layer.outputs[0],
list(layer.inputs.values())[0])
elif layer.kernel == "prim.add":
line = "{} = {} + {} * {}".format(layer.outputs[0],
list(layer.inputs.values())[0],
layer.attrs["alpha"],
list(layer.inputs.values())[1])
elif layer.kernel == "prim.append":
line = "{} = {}.append({})".format(layer.outputs[0],
list(layer.inputs.values())[0],
list(layer.inputs.values())[1])
elif layer.kernel == "prim.shape":
line = "{} = {}.shape".format(layer.outputs[0],
list(layer.inputs.values())[0])
elif layer.kernel == "prim.len":
line = "{} = len({})".format(layer.outputs[0],
list(layer.inputs.values())[0])
elif layer.kernel == "prim.eq":
line = "{} = {} == {}".format(layer.outputs[0],
list(layer.inputs.values())[0],
list(layer.inputs.values())[1])
elif layer.kernel == "prim.assert":
if layer.attrs["type"] == "eq":
if isinstance(layer.attrs["value"], list):
s = ""
for v in layer.attrs["value"]:
s += "{} == {} or ".format(layer.attrs["key"], v)
if len(s) > 0:
s = s[:-4]
line = "assert {}, \'The {} must be {}!\'".format(
s, layer.attrs["key"], layer.attrs["value"])
else:
line = "assert {} == {}, \'The {} must be {}!\'".format(
layer.attrs["key"], layer.attrs["value"],
layer.attrs["key"], layer.attrs["value"])
else:
raise Exception("Not implement yet!")
elif layer.kernel == "prim.getitem":
item0 = list(layer.inputs.values())[0]
if list(layer.inputs.values())[0] is None:
item0 = str(layer.attrs[list(layer.inputs.keys())[0]])
item1 = list(layer.inputs.values())[1]
if list(layer.inputs.values())[1] is None:
item1 = str(layer.attrs[list(layer.inputs.keys())[1]])
line = "{} = {}[{}]".format(layer.outputs[0], item0, item1)
elif layer.kernel == "prim.le":
item0 = list(layer.inputs.values())[0]
if list(layer.inputs.values())[0] is None:
item0 = str(layer.attrs[list(layer.inputs.keys())[0]])
item1 = list(layer.inputs.values())[1]
if list(layer.inputs.values())[1] is None:
item1 = str(layer.attrs[list(layer.inputs.keys())[1]])
line = "{} = {} < {}".format(layer.outputs[0], item0, item1)
elif layer.kernel == "prim.slice":
attrs_str = ""
for k, v in layer.attrs.items():
attrs_str += "{}:".format(v)
attrs_str = attrs_str[:-1]
line = "{} = {}[{}]".format(layer.outputs[0],
list(layer.inputs.values())[0],
attrs_str)
forward_func.extend(gen_codes([line], indent=indent))
...@@ -59,7 +59,7 @@ class PaddleLayer(object): ...@@ -59,7 +59,7 @@ class PaddleLayer(object):
class PaddleGraph(object): class PaddleGraph(object):
def __init__(self, father_layer=None): def __init__(self, father_layer=None, graph_type="dygraph"):
self.layers = OrderedDict() self.layers = OrderedDict()
self.edges_out = dict() self.edges_out = dict()
self.edges_in = dict() self.edges_in = dict()
...@@ -67,6 +67,7 @@ class PaddleGraph(object): ...@@ -67,6 +67,7 @@ class PaddleGraph(object):
self.outputs = list() self.outputs = list()
self.parameters = dict() self.parameters = dict()
self.father_layer = father_layer self.father_layer = father_layer
self.graph_type = graph_type
def set_name(self, name): def set_name(self, name):
self.name = name self.name = name
...@@ -129,6 +130,10 @@ class PaddleGraph(object): ...@@ -129,6 +130,10 @@ class PaddleGraph(object):
for block in layer.blocks: for block in layer.blocks:
block.build(layer.inputs, layer.outputs) block.build(layer.inputs, layer.outputs)
if self.graph_type == "dygraph":
self.get_dygraph_inputs()
self.get_dygraph_outputs()
def get_global_layers(self): def get_global_layers(self):
# 该全局layers的信息是按住奥拓扑排序组成的 # 该全局layers的信息是按住奥拓扑排序组成的
def update(layers): def update(layers):
...@@ -265,130 +270,8 @@ class PaddleGraph(object): ...@@ -265,130 +270,8 @@ class PaddleGraph(object):
param.tofile(fp) param.tofile(fp)
fp.close() fp.close()
def convert_prim(self, layer, indent=1): def get_dygraph_inputs(self):
def gen_lines(code_list, indent=0): def update(layers):
indent_blank = " " * indent
lines = []
for code_line in code_list:
if code_line.strip() == "":
lines.append('\n')
else:
lines.append(indent_blank + code_line + '\n')
return lines
if layer.kernel == "prim.if":
line = "if {} :".format(list(layer.inputs.values())[0])
self.forward_lines.extend(gen_lines([line], indent=indent))
block = layer.blocks[0]
b_init_lines, b_forward_lines = block.gen_dygraph_code(
indent=indent + 1)
self.init_lines.extend(b_init_lines)
self.forward_lines.extend(b_forward_lines)
block = layer.blocks[1]
if len(block.layers) > 0:
line = "else:"
self.forward_lines.extend(gen_lines([line], indent=indent))
b_init_lines, b_forward_lines = block.gen_dygraph_code(
indent=indent + 1)
self.init_lines.extend(b_init_lines)
self.forward_lines.extend(b_forward_lines)
return
elif layer.kernel == "prim.loop":
loop_range = list(layer.inputs.values())[0]
if list(layer.inputs.values())[0] is None:
loop_range = str(layer.attrs[list(layer.inputs.keys())[0]])
line = "for {} in range({}):".format(layer.outputs[1], loop_range)
self.forward_lines.extend(gen_lines([line], indent=indent))
block = layer.blocks[0]
b_init_lines, b_forward_lines = block.gen_dygraph_code(
indent=indent + 1)
self.init_lines.extend(b_init_lines)
self.forward_lines.extend(b_forward_lines)
return
elif layer.kernel == "prim.equal":
line = "{} = {}".format(layer.outputs[0],
list(layer.inputs.values())[0])
elif layer.kernel == "prim.constant":
line = "{} = {}".format(layer.outputs[0], layer.attrs["value"])
elif layer.kernel == "prim.list":
inputs_list = list(layer.inputs.values())
for i, input in enumerate(inputs_list):
if input is None:
inputs_list[i] = str(layer.attrs[list(layer.inputs.keys())[
i]])
inputs_str = ', '.join(inputs_list)
line = "{} = [{}]".format(layer.outputs[0], inputs_str)
elif layer.kernel == "prim.exception":
exception = list(layer.inputs.values())[0]
if list(layer.inputs.values())[0] is None:
exception = str(layer.attrs[list(layer.inputs.keys())[0]])
line = "raise RaiseException({})".format(exception)
elif layer.kernel == "prim.min":
line = "{} = min({})".format(layer.outputs[0],
list(layer.inputs.values())[0])
elif layer.kernel == "prim.add":
line = "{} = {} + {} * {}".format(layer.outputs[0],
list(layer.inputs.values())[0],
layer.attrs["alpha"],
list(layer.inputs.values())[1])
elif layer.kernel == "prim.append":
line = "{} = {}.append({})".format(layer.outputs[0],
list(layer.inputs.values())[0],
list(layer.inputs.values())[1])
elif layer.kernel == "prim.shape":
line = "{} = {}.shape".format(layer.outputs[0],
list(layer.inputs.values())[0])
elif layer.kernel == "prim.len":
line = "{} = len({})".format(layer.outputs[0],
list(layer.inputs.values())[0])
elif layer.kernel == "prim.eq":
line = "{} = {} == {}".format(layer.outputs[0],
list(layer.inputs.values())[0],
list(layer.inputs.values())[1])
elif layer.kernel == "prim.assert":
if layer.attrs["type"] == "eq":
if isinstance(layer.attrs["value"], list):
s = ""
for v in layer.attrs["value"]:
s += "{} == {} or ".format(layer.attrs["key"], v)
if len(s) > 0:
s = s[:-4]
line = "assert {}, \'The {} must be {}!\'".format(
s, layer.attrs["key"], layer.attrs["value"])
else:
line = "assert {} == {}, \'The {} must be {}!\'".format(
layer.attrs["key"], layer.attrs["value"],
layer.attrs["key"], layer.attrs["value"])
else:
raise Exception("Not implement yet!")
elif layer.kernel == "prim.getitem":
item0 = list(layer.inputs.values())[0]
if list(layer.inputs.values())[0] is None:
item0 = str(layer.attrs[list(layer.inputs.keys())[0]])
item1 = list(layer.inputs.values())[1]
if list(layer.inputs.values())[1] is None:
item1 = str(layer.attrs[list(layer.inputs.keys())[1]])
line = "{} = {}[{}]".format(layer.outputs[0], item0, item1)
elif layer.kernel == "prim.le":
item0 = list(layer.inputs.values())[0]
if list(layer.inputs.values())[0] is None:
item0 = str(layer.attrs[list(layer.inputs.keys())[0]])
item1 = list(layer.inputs.values())[1]
if list(layer.inputs.values())[1] is None:
item1 = str(layer.attrs[list(layer.inputs.keys())[1]])
line = "{} = {} < {}".format(layer.outputs[0], item0, item1)
elif layer.kernel == "prim.slice":
attrs_str = ""
for k, v in layer.attrs.items():
attrs_str += "{}:".format(v)
attrs_str = attrs_str[:-1]
line = "{} = {}[{}]".format(layer.outputs[0],
list(layer.inputs.values())[0],
attrs_str)
self.forward_lines.extend(gen_lines([line], indent=indent))
return
def get_dygraph_inputs(self, layers):
for layer_id, layer in layers.items(): for layer_id, layer in layers.items():
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get( if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0: layer_id, 0) == 0:
...@@ -399,11 +282,13 @@ class PaddleGraph(object): ...@@ -399,11 +282,13 @@ class PaddleGraph(object):
self.inputs.append(value) self.inputs.append(value)
if len(layer.blocks) > 0: if len(layer.blocks) > 0:
for block in layer.blocks: for block in layer.blocks:
block.get_dygraph_inputs(block.layers) block.get_dygraph_inputs()
self.inputs.extend(block.inputs) self.inputs.extend(block.inputs)
update(self.layers)
self.inputs = list(set(self.inputs))
def get_dygraph_outputs(self, layers): def get_dygraph_outputs(self):
for layer_id, layer in layers.items(): for layer_id, layer in self.layers.items():
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get( if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0: layer_id, 0) == 0:
continue continue
...@@ -413,24 +298,21 @@ class PaddleGraph(object): ...@@ -413,24 +298,21 @@ class PaddleGraph(object):
"_assert") or not output_name.startswith("x"): "_assert") or not output_name.startswith("x"):
continue continue
self.outputs.append(output_name) self.outputs.append(output_name)
self.outputs = list(set(self.outputs))
def gen_dygraph_code(self, code_dir=None, indent=2): def gen_dygraph_code(self, code_dir=None, indent=2):
def gen_lines(code_list, indent=0): def gen_codes(code_list, indent=0):
indent_blank = " " * indent indent_blank = " " * indent
lines = [] codes = []
for code_line in code_list: for code_line in code_list:
if code_line.strip() == "": if code_line.strip() == "":
lines.append('\n') codes.append('\n')
else: else:
lines.append(indent_blank + code_line + '\n') codes.append(indent_blank + code_line + '\n')
return lines return codes
self.init_lines = [] def gen_head():
# forward_func self.head = gen_codes(
self.forward_lines = []
# def gen_head
if indent == 2 and code_dir is not None:
start_lines = gen_lines(
[ [
"from paddle.fluid.initializer import Constant", "from paddle.fluid.initializer import Constant",
"from paddle.fluid.param_attr import ParamAttr", "from paddle.fluid.param_attr import ParamAttr",
...@@ -439,19 +321,40 @@ class PaddleGraph(object): ...@@ -439,19 +321,40 @@ class PaddleGraph(object):
"class {}(fluid.dygraph.Layer):".format(self.name), "class {}(fluid.dygraph.Layer):".format(self.name),
], ],
indent=0) indent=0)
self.get_dygraph_inputs(self.layers)
input_data_name = ', '.join(self.inputs) input_data_name = ', '.join(self.inputs)
self.init_lines.extend( self.init_func.extend(
gen_lines( gen_codes(
["def __init__(self, params):"], indent=1)) ["def __init__(self, params):"], indent=1))
self.init_lines.extend( self.init_func.extend(
gen_lines( gen_codes(
["super({}, self).__init__()".format(self.name)], indent=2)) ["super({}, self).__init__()".format(self.name)], indent=2))
self.forward_lines.extend( self.forward_func.extend(
gen_lines( gen_codes(
["def forward(self, {}):".format(input_data_name)], ["def forward(self, {}):".format(input_data_name)],
indent=1)) indent=1))
def write_code(code_dir):
f = open(os.path.join(code_dir, 'code.py'), 'w')
for code_line in self.head:
f.write(code_line)
init_writen_codes = []
for code_line in self.init_func:
if code_line in init_writen_codes:
continue
f.write(code_line)
init_writen_codes.append(code_line)
f.write("\n")
return_code = "return {}".format(", ".join(self.outputs))
self.forward_func.extend(gen_codes([return_code], indent=2))
for code_line in self.forward_func:
f.write(code_line)
f.close()
self.init_func = []
self.forward_func = []
if indent == 2 and code_dir is not None:
gen_head()
for layer_id, layer in self.layers.items(): for layer_id, layer in self.layers.items():
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get( if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0: layer_id, 0) == 0:
...@@ -470,10 +373,10 @@ class PaddleGraph(object): ...@@ -470,10 +373,10 @@ class PaddleGraph(object):
if layer.kernel == "fluid.dygraph.base.to_variable" and not layer.attrs[ if layer.kernel == "fluid.dygraph.base.to_variable" and not layer.attrs[
"value"].startswith("params["): "value"].startswith("params["):
self.forward_lines.extend(gen_lines([line], indent=indent)) self.forward_func.extend(gen_codes([line], indent=indent))
continue continue
else: else:
self.init_lines.extend(gen_lines([line], indent=2)) self.init_func.extend(gen_codes([line], indent=2))
if len(layer.outputs) == 1: if len(layer.outputs) == 1:
line = layer.outputs[0] line = layer.outputs[0]
...@@ -490,9 +393,12 @@ class PaddleGraph(object): ...@@ -490,9 +393,12 @@ class PaddleGraph(object):
line += "{}, ".format(v) line += "{}, ".format(v)
line = line.strip(", ") line = line.strip(", ")
line += ")" line += ")"
self.forward_lines.extend(gen_lines([line], indent=indent)) self.forward_func.extend(gen_codes([line], indent=indent))
elif "prim" in layer.kernel: elif "prim" in layer.kernel:
self.convert_prim(layer, indent=indent) from .convert_prim import convert_prim
convert_prim(layer, indent=indent,
init_func=self.init_func,
forward_func=self.forward_func)
else: else:
if len(layer.outputs) == 1: if len(layer.outputs) == 1:
line = layer.outputs[0] line = layer.outputs[0]
...@@ -505,26 +411,11 @@ class PaddleGraph(object): ...@@ -505,26 +411,11 @@ class PaddleGraph(object):
line += "{}={}, ".format(k, v) line += "{}={}, ".format(k, v)
line = line.strip(", ") line = line.strip(", ")
line += ")" line += ")"
self.forward_lines.extend(gen_lines([line], indent=indent)) self.forward_func.extend(gen_codes([line], indent=indent))
if indent == 2: if indent == 2:
f = open(os.path.join(code_dir, 'code.py'), 'w') write_code(code_dir)
for line in start_lines:
f.write(line)
init_writen_line = []
for line in self.init_lines:
if line in init_writen_line:
continue
f.write(line)
init_writen_line.append(line)
f.write("\n")
self.get_dygraph_outputs(self.layers)
return_line = "return {}".format(", ".join(self.outputs))
self.forward_lines.extend(gen_lines([return_line], indent=2))
for line in self.forward_lines:
f.write(line)
f.close()
else: else:
return self.init_lines, self.forward_lines return self.init_func, self.forward_func
def dump_dygraph_parameter(self, code_dir): def dump_dygraph_parameter(self, code_dir):
params_output = open(os.path.join(code_dir, 'model.pdparams'), 'wb') params_output = open(os.path.join(code_dir, 'model.pdparams'), 'wb')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册