Commit fb4294b1 authored by SunAhong1993

modify the program

Parent a4056d9c
@@ -204,6 +204,7 @@ class PaddleGraph(object):
        f.close()

    def gen_model(self, save_dir):
+        if self.graph_type == "static":
            code_dir = os.path.join(save_dir, 'model_with_code')
            infer_dir = os.path.join(save_dir, 'inference_model')
            self.gen_code(code_dir)
@@ -235,6 +236,9 @@ class PaddleGraph(object):
                feeded_var_names=[i.name for i in inputs],
                target_vars=outputs,
                executor=exe)
+        else:
+            self.gen_dygraph_code(save_dir)
+            self.dump_dygraph_parameter(save_dir)

    def dump_parameter(self, param_name, param, save_dir):
        if not os.path.exists(save_dir):
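As a rough, self-contained sketch of the dispatch this hunk introduces (simplified stand-in class and paths, not the x2paddle implementation), gen_model now branches on graph_type: the static path keeps the generated-code/inference-model export, while the dygraph path emits dygraph code and dumps parameters directly.

import os

class ToyGraph:
    def __init__(self, graph_type):
        self.graph_type = graph_type

    def gen_model(self, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        if self.graph_type == "static":
            # static graphs still export generated code (inference model omitted here)
            self.gen_code(os.path.join(save_dir, "model_with_code"))
        else:
            # dygraph graphs emit dygraph code and dump parameters directly
            self.gen_dygraph_code(save_dir)
            self.dump_dygraph_parameter(save_dir)

    def gen_code(self, code_dir):
        print("static export to", code_dir)

    def gen_dygraph_code(self, save_dir):
        print("dygraph code to", save_dir)

    def dump_dygraph_parameter(self, save_dir):
        print("dygraph params to", save_dir)

ToyGraph("dygraph").gen_model("/tmp/toy_model")  # hypothetical path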
@@ -356,9 +360,11 @@ class PaddleGraph(object):
        gen_head()

        for layer_id, layer in self.layers.items():
+            if len(self.layers) > 1:
                if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
                        layer_id, 0) == 0 and layer.kernel != "prim.assert" \
-                        and layer.kernel != "prim.exception":
+                        and layer.kernel != "prim.exception" \
+                        and layer.kernel != "prim.warnings":
                    continue
            if "dygraph" in layer.kernel:
                line = "{}".format(
......
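The skip rule above can be read as a small predicate; the sketch below is an illustrative reconstruction (function and variable names are hypothetical) of when an isolated layer is dropped from code generation: a layer with no incoming and no outgoing edges is skipped unless its kernel is one of the side-effect primitives, and a single-layer graph is always emitted.

SIDE_EFFECT_KERNELS = {"prim.assert", "prim.exception", "prim.warnings"}

def should_skip(layer_kernel, layer_id, edges_in, edges_out, num_layers):
    if num_layers <= 1:
        return False  # a one-layer graph is always emitted
    isolated = edges_in.get(layer_id, 0) == 0 and edges_out.get(layer_id, 0) == 0
    return isolated and layer_kernel not in SIDE_EFFECT_KERNELS

# an isolated constant is skipped; an isolated warning is kept
print(should_skip("prim.constant", "3", {}, {}, num_layers=5))  # True
print(should_skip("prim.warnings", "4", {}, {}, num_layers=5))  # False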
@@ -138,7 +138,7 @@ def prim_Loop(mapper, graph, node):
    node_outputs = mapper._get_outputs_name(node)
    loop_inputs = {}
    block = list(node.blocks())[0]
-    loop_outputs = node_outputs
+    loop_outputs = node_outputs.copy()
    for i, block_input_ivalue in enumerate(block.inputs()):
        if i == 0:
            block_input_node_name = '_x' + str(mapper.output_index)
......
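A minimal demonstration of the likely reason for the .copy() change above (an assumption about intent): assigning the list directly would alias it, so later edits to loop_outputs would leak back into node_outputs, while a shallow copy keeps the two lists independent.

node_outputs = ["out0", "out1"]

aliased = node_outputs            # old behaviour: both names point at one list
aliased.append("block_out")
print(node_outputs)               # ['out0', 'out1', 'block_out'] -- polluted

node_outputs = ["out0", "out1"]
copied = node_outputs.copy()      # new behaviour: independent shallow copy
copied.append("block_out")
print(node_outputs)               # ['out0', 'out1'] -- unchanged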
@@ -179,7 +179,7 @@ def prim_mul(layer, indent=1, init_func=[], forward_func=[]):

def prim_ne(layer, indent=1, init_func=[], forward_func=[]):
-    line = "{} = {} < {}".format(layer.outputs[0],
+    line = "{} = {} != {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))
......
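For context, a simplified stand-in for the template-based code generation these prim_* helpers use; gen_codes and the argument handling below are reduced to essentials (not the real signatures), and the template now correctly emits != for "not equal" instead of <.

def gen_codes(code_list, indent=1):
    # indent each generated line and terminate it with a newline
    return ["    " * indent + line + "\n" for line in code_list]

def prim_ne(outputs, x, y, indent=1, forward_func=None):
    forward_func = forward_func if forward_func is not None else []
    line = "{} = {} != {}".format(outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))
    return forward_func

print("".join(prim_ne(["cond"], "a", "b")))  # "    cond = a != b"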
@@ -14,5 +14,9 @@
from .fc_fuser import FcFuser
from .fc_fuse_pass import FcFusePass
+from .nn_adaptive_pool2d_fuser import NnAdaptivePool2dFuser
+from .nn_adaptive_pool2d_fuse_pass import NnAdaptivePool2dFusePass
+from .functional_adaptive_pool2d_fuser import FunctionalAdaptivePool2dFuser
+from .functional_adaptive_pool2d_fuse_pass import FunctionalAdaptivePool2dFusePass
from .constant_fuser import ConstantFuser
from .constant_fuse_pass import ConstantFusePass
@@ -21,7 +21,7 @@ from x2paddle.core.util import *

class FcFuser(FuseBase):
    def __init__(self):
        self.linear_index = 0
-        super(FcFuser, self).__init__()
+        super(FcFuser, self).__init__(graph_type="dygraph")

    def build_pattern(self):
        """ Describe the fc graph structure to be replaced.
@@ -70,7 +70,7 @@ class FcFuser(FuseBase):
        self.pattern.add_layer("prim.if", {'input': gen_name(3)}, [gen_name(4)])
        self.pattern.outputs.append(gen_name(4))
        if_layer1 = self.pattern.layers[list(self.pattern.layers.keys())[-1]]
-        pattern_block0 = PaddleGraph(if_layer1)
+        pattern_block0 = PaddleGraph(if_layer1, graph_type="dygraph")
        pattern_block0.add_layer(
            "fluid.dygraph.base.to_variable",
            inputs={},
@@ -99,7 +99,7 @@ class FcFuser(FuseBase):
        pattern_block0.add_layer(
            "prim.equal", inputs={'input': gen_name(8)}, outputs=[gen_name(4)])
        if_layer1.add_block(pattern_block0)
-        pattern_block1 = PaddleGraph(if_layer1)
+        pattern_block1 = PaddleGraph(if_layer1, graph_type="dygraph")
        pattern_block1.add_layer(
            "fluid.dygraph.base.to_variable",
            inputs={},
@@ -122,7 +122,7 @@ class FcFuser(FuseBase):
            [gen_name(11)])
        if_layer2 = pattern_block1.layers[list(pattern_block1.layers.keys())[
            -1]]
-        pattern_block1_block0 = PaddleGraph(if_layer2)
+        pattern_block1_block0 = PaddleGraph(if_layer2, graph_type="dygraph")
        pattern_block1_block0.add_layer(
            "fluid.dygraph.base.to_variable",
            inputs={},
@@ -140,7 +140,7 @@ class FcFuser(FuseBase):
            inputs={'input': gen_name(13)},
            outputs=[gen_name(11)])
        if_layer2.add_block(pattern_block1_block0)
-        pattern_block1_block1 = PaddleGraph(if_layer2)
+        pattern_block1_block1 = PaddleGraph(if_layer2, graph_type="dygraph")
        pattern_block1_block1.add_layer(
            "prim.equal", inputs={'input': gen_name(9)},
            outputs=[gen_name(11)])
@@ -150,12 +150,9 @@ class FcFuser(FuseBase):
            outputs=[gen_name(4)])
        if_layer2.add_block(pattern_block1_block1)
        if_layer1.add_block(pattern_block1)
-        self.pattern.build(
-            inputs={"input-0": "fc-input-0",
-                    "input-1": "fc-input-0"})
+        self.pattern.build(inputs={"input-0": "fc-input-0"})

-    def insert_new_layer(self, graph, matches):
-        parameters = graph.parameters
+    def insert_new_layer(self, graph, parameters, matches):
        new_layer = self.gen_new_layer(parameters, matches)
        new_layer_id = list(matches.keys())[0]
        graph.layers[new_layer_id] = new_layer
@@ -171,7 +168,7 @@ class FcFuser(FuseBase):
        weight_name = layer.attrs["value"][8:-2]
        layer = matches[layers_id[8]]
        bias_name = layer.attrs["value"][8:-2]
-        attrs = {}
+        attrs = dict()
        attrs["input_dim"] = parameters[weight_name].shape[1]
        attrs["output_dim"] = parameters[weight_name].shape[0]
        linear_name = "linear{}".format(self.linear_index)
......
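A hedged sketch of how the fused linear layer's attributes are read off the matched weight, following the shape[1]/shape[0] indexing shown above: for a weight of shape [out_dim, in_dim], the input and output sizes come straight from that shape (the example weight below is hypothetical).

import numpy as np

def fc_attrs_from_weight(weight):
    attrs = dict()
    attrs["input_dim"] = weight.shape[1]   # columns -> input features
    attrs["output_dim"] = weight.shape[0]  # rows    -> output features
    return attrs

weight = np.zeros((10, 784))  # hypothetical fc weight mapping 784 -> 10
print(fc_attrs_from_weight(weight))  # {'input_dim': 784, 'output_dim': 10}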
@@ -18,7 +18,10 @@ from x2paddle.optimizer.pass_manager import PassManager

class GraphOptimizer(object):
    def __init__(self):
-        self.passes = ["fc_fuse_pass", "constant_fuse_pass"]
+        self.passes = [
+            "fc_fuse_pass", "nn_adaptive_pool2d_fuse_pass",
+            "functional_adaptive_pool2d_fuse_pass", "constant_fuse_pass"
+        ]

    def optimize(self, graph):
        for pass_name in self.passes:
......
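The pass list above is consumed by name, in order. The sketch below assumes a simple registry-style manager (an illustrative structure, not the actual x2paddle PassManager API): each registered name is looked up and applied to the graph in sequence.

class ToyPassManager:
    _passes = {}

    @classmethod
    def register(cls, name, fn):
        cls._passes[name] = fn

    @classmethod
    def lookup(cls, name):
        return cls._passes[name]

ToyPassManager.register("fc_fuse_pass", lambda g: g + ["fc fused"])
ToyPassManager.register("constant_fuse_pass", lambda g: g + ["constants folded"])

def optimize(graph, pass_names):
    # apply each registered pass to the graph, in the declared order
    for name in pass_names:
        graph = ToyPassManager.lookup(name)(graph)
    return graph

print(optimize([], ["fc_fuse_pass", "constant_fuse_pass"]))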
@@ -83,16 +83,13 @@ class PatternMatcher(object):
                    # valid if the current pattern layer's output is one of the pattern's outputs
                    return False
            # handling of control-flow layers
-            if layer.kernel == "prim.if":
-                match_info = get_subgraph(pattern_layer.blocks[0],
-                                          layer.blocks[0], 0)
-                if match_info:
-                    subgraph_id2layers.update(match_info)
-                else:
-                    return False
-                match_info = get_subgraph(pattern_layer.blocks[1],
-                                          layer.blocks[1], 0)
-                if match_info:
+            if layer.kernel == "prim.if" or layer.kernel == "prim.loop":
+                if len(pattern_layer.blocks) != len(layer.blocks):
+                    return False
+                for i, b in enumerate(pattern_layer.blocks):
+                    match_info = get_subgraph(pattern_layer.blocks[i],
+                                              layer.blocks[i], 0)
+                    if match_info is not False:
                        subgraph_id2layers.update(match_info)
                    else:
                        return False
@@ -101,6 +98,7 @@ class PatternMatcher(object):
                    return subgraph_id2layers
                else:
                    return False
+            return subgraph_id2layers

        for i, (layer_id, layer) in enumerate(graph.layers.items()):
            match_info = get_subgraph(self.pattern, graph, i)
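The generalized control-flow matching above can be summarized as: the block counts must agree, then blocks are matched pairwise. A self-contained toy version (the block representation and the per-block matcher are stand-ins) follows.

def match_blocks(pattern_blocks, layer_blocks, match_block):
    # mismatched block counts can never match (e.g. prim.if vs prim.loop)
    if len(pattern_blocks) != len(layer_blocks):
        return False
    merged = {}
    for pattern_block, layer_block in zip(pattern_blocks, layer_blocks):
        match_info = match_block(pattern_block, layer_block)
        if match_info is False:
            return False
        merged.update(match_info)
    return merged

# toy matcher: blocks "match" when they carry the same kernel names
toy = lambda p, l: dict(zip(p, l)) if p == l else False
print(match_blocks([["prim.if"]], [["prim.if"]], toy))  # {'prim.if': 'prim.if'}
print(match_blocks([["a"], ["b"]], [["a"]], toy))       # False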
@@ -152,16 +150,17 @@ def get_subgraph(prefix_layer_id, suffix_layer_id, graph):

class FuseBase(object):
-    def __init__(self):
-        self.pattern = PaddleGraph()
+    def __init__(self, graph_type):
+        self.pattern = PaddleGraph(graph_type=graph_type)

    def operate(self, graph, match_kind="topo"):
+        parameters = graph.parameters
        self.build_pattern()
        self.perform_pattern_matcher(graph, match_kind)
        for match in self.matches:
            first_layer_id = list(match.keys())[0]
            subgraph = get_subgraph("", first_layer_id, graph)
-            self.insert_new_layer(subgraph, match)
+            self.insert_new_layer(subgraph, parameters, match)
        self.delete_inter_layer(graph)
        graph.build()
@@ -183,6 +182,6 @@ class FuseBase(object):
            param_name = layer.attrs["value"][8:-2]
            if param_name in graph.parameters:
                graph.parameters.pop(param_name)
-            if layer_id in graph.layers:
+            if layer_id in subgraph.layers:
                # layer_id may belong to a sub-graph; deleting the parent layer then removes the whole sub-graph
                subgraph.layers.pop(layer_id)
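A hedged sketch of the reworked operate() flow: graph.parameters is captured once up front and handed to insert_new_layer for every match, instead of each fuser reaching back into the graph. Class and method names mirror the diff, but the bodies are simplified placeholders.

class ToyFuser:
    def operate(self, graph, match_kind="topo"):
        parameters = graph.parameters            # captured once up front
        for match in self.find_matches(graph, match_kind):
            self.insert_new_layer(graph, parameters, match)

    def find_matches(self, graph, match_kind):
        return [{"0": "matmul"}, {"5": "matmul"}]  # placeholder matches

    def insert_new_layer(self, graph, parameters, match):
        first_layer_id = list(match.keys())[0]
        graph.layers[first_layer_id] = "fused_fc({} params)".format(len(parameters))

class ToyGraph:
    parameters = {"w": None, "b": None}
    layers = {}

g = ToyGraph()
ToyFuser().operate(g)
print(g.layers)  # {'0': 'fused_fc(2 params)', '5': 'fused_fc(2 params)'}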