提交 fb4294b1 编写于 作者: S SunAhong1993

modify the program

上级 a4056d9c
......@@ -204,6 +204,7 @@ class PaddleGraph(object):
f.close()
def gen_model(self, save_dir):
if self.graph_type == "static":
code_dir = os.path.join(save_dir, 'model_with_code')
infer_dir = os.path.join(save_dir, 'inference_model')
self.gen_code(code_dir)
......@@ -235,6 +236,9 @@ class PaddleGraph(object):
feeded_var_names=[i.name for i in inputs],
target_vars=outputs,
executor=exe)
else:
self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir)
def dump_parameter(self, param_name, param, save_dir):
if not os.path.exists(save_dir):
......@@ -356,9 +360,11 @@ class PaddleGraph(object):
gen_head()
for layer_id, layer in self.layers.items():
if len(self.layers) > 1:
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0 and layer.kernel != "prim.assert" \
and layer.kernel != "prim.exception":
and layer.kernel != "prim.exception" \
and layer.kernel != "prim.warnings":
continue
if "dygraph" in layer.kernel:
line = "{}".format(
......
......@@ -138,7 +138,7 @@ def prim_Loop(mapper, graph, node):
node_outputs = mapper._get_outputs_name(node)
loop_inputs = {}
block = list(node.blocks())[0]
loop_outputs = node_outputs
loop_outputs = node_outputs.copy()
for i, block_input_ivalue in enumerate(block.inputs()):
if i == 0:
block_input_node_name = '_x' + str(mapper.output_index)
......
......@@ -179,7 +179,7 @@ def prim_mul(layer, indent=1, init_func=[], forward_func=[]):
def prim_ne(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} < {}".format(layer.outputs[0],
line = "{} = {} != {}".format(layer.outputs[0],
get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent))
......
......@@ -14,5 +14,9 @@
from .fc_fuser import FcFuser
from .fc_fuse_pass import FcFusePass
from .nn_adaptive_pool2d_fuser import NnAdaptivePool2dFuser
from .nn_adaptive_pool2d_fuse_pass import NnAdaptivePool2dFusePass
from .functional_adaptive_pool2d_fuser import FunctionalAdaptivePool2dFuser
from .functional_adaptive_pool2d_fuse_pass import FunctionalAdaptivePool2dFusePass
from .constant_fuser import ConstantFuser
from .constant_fuse_pass import ConstantFusePass
......@@ -21,7 +21,7 @@ from x2paddle.core.util import *
class FcFuser(FuseBase):
def __init__(self):
self.linear_index = 0
super(FcFuser, self).__init__()
super(FcFuser, self).__init__(graph_type="dygraph")
def build_pattern(self):
""" 描述需要替换的fc图结构。
......@@ -70,7 +70,7 @@ class FcFuser(FuseBase):
self.pattern.add_layer("prim.if", {'input': gen_name(3)}, [gen_name(4)])
self.pattern.outputs.append(gen_name(4))
if_layer1 = self.pattern.layers[list(self.pattern.layers.keys())[-1]]
pattern_block0 = PaddleGraph(if_layer1)
pattern_block0 = PaddleGraph(if_layer1, graph_type="dygraph")
pattern_block0.add_layer(
"fluid.dygraph.base.to_variable",
inputs={},
......@@ -99,7 +99,7 @@ class FcFuser(FuseBase):
pattern_block0.add_layer(
"prim.equal", inputs={'input': gen_name(8)}, outputs=[gen_name(4)])
if_layer1.add_block(pattern_block0)
pattern_block1 = PaddleGraph(if_layer1)
pattern_block1 = PaddleGraph(if_layer1, graph_type="dygraph")
pattern_block1.add_layer(
"fluid.dygraph.base.to_variable",
inputs={},
......@@ -122,7 +122,7 @@ class FcFuser(FuseBase):
[gen_name(11)])
if_layer2 = pattern_block1.layers[list(pattern_block1.layers.keys())[
-1]]
pattern_block1_block0 = PaddleGraph(if_layer2)
pattern_block1_block0 = PaddleGraph(if_layer2, graph_type="dygraph")
pattern_block1_block0.add_layer(
"fluid.dygraph.base.to_variable",
inputs={},
......@@ -140,7 +140,7 @@ class FcFuser(FuseBase):
inputs={'input': gen_name(13)},
outputs=[gen_name(11)])
if_layer2.add_block(pattern_block1_block0)
pattern_block1_block1 = PaddleGraph(if_layer2)
pattern_block1_block1 = PaddleGraph(if_layer2, graph_type="dygraph")
pattern_block1_block1.add_layer(
"prim.equal", inputs={'input': gen_name(9)},
outputs=[gen_name(11)])
......@@ -150,12 +150,9 @@ class FcFuser(FuseBase):
outputs=[gen_name(4)])
if_layer2.add_block(pattern_block1_block1)
if_layer1.add_block(pattern_block1)
self.pattern.build(
inputs={"input-0": "fc-input-0",
"input-1": "fc-input-0"})
self.pattern.build(inputs={"input-0": "fc-input-0"})
def insert_new_layer(self, graph, matches):
parameters = graph.parameters
def insert_new_layer(self, graph, parameters, matches):
new_layer = self.gen_new_layer(parameters, matches)
new_layer_id = list(matches.keys())[0]
graph.layers[new_layer_id] = new_layer
......@@ -171,7 +168,7 @@ class FcFuser(FuseBase):
weight_name = layer.attrs["value"][8:-2]
layer = matches[layers_id[8]]
bias_name = layer.attrs["value"][8:-2]
attrs = {}
attrs = dict()
attrs["input_dim"] = parameters[weight_name].shape[1]
attrs["output_dim"] = parameters[weight_name].shape[0]
linear_name = "linear{}".format(self.linear_index)
......
......@@ -18,7 +18,10 @@ from x2paddle.optimizer.pass_manager import PassManager
class GraphOptimizer(object):
def __init__(self):
self.passes = ["fc_fuse_pass", "constant_fuse_pass"]
self.passes = [
"fc_fuse_pass", "nn_adaptive_pool2d_fuse_pass",
"functional_adaptive_pool2d_fuse_pass", "constant_fuse_pass"
]
def optimize(self, graph):
for pass_name in self.passes:
......
......@@ -83,16 +83,13 @@ class PatternMatcher(object):
# 若pattern当前layer的输出是pattern的输出,则是正确的
return False
# 当为控制流时的处理
if layer.kernel == "prim.if":
match_info = get_subgraph(pattern_layer.blocks[0],
layer.blocks[0], 0)
if match_info:
subgraph_id2layers.update(match_info)
else:
if layer.kernel == "prim.if" or layer.kernel == "prim.loop":
if len(pattern_layer.blocks) != len(layer.blocks):
return False
match_info = get_subgraph(pattern_layer.blocks[1],
layer.blocks[1], 0)
if match_info:
for i, b in enumerate(pattern_layer.blocks):
match_info = get_subgraph(pattern_layer.blocks[i],
layer.blocks[i], 0)
if match_info is not False:
subgraph_id2layers.update(match_info)
else:
return False
......@@ -101,6 +98,7 @@ class PatternMatcher(object):
return subgraph_id2layers
else:
return False
return subgraph_id2layers
for i, (layer_id, layer) in enumerate(graph.layers.items()):
match_info = get_subgraph(self.pattern, graph, i)
......@@ -152,16 +150,17 @@ def get_subgraph(prefix_layer_id, suffix_layer_id, graph):
class FuseBase(object):
def __init__(self):
self.pattern = PaddleGraph()
def __init__(self, graph_type):
self.pattern = PaddleGraph(graph_type=graph_type)
def operate(self, graph, match_kind="topo"):
parameters = graph.parameters
self.build_pattern()
self.perform_pattern_matcher(graph, match_kind)
for match in self.matches:
first_layer_id = list(match.keys())[0]
subgraph = get_subgraph("", first_layer_id, graph)
self.insert_new_layer(subgraph, match)
self.insert_new_layer(subgraph, parameters, match)
self.delete_inter_layer(graph)
graph.build()
......@@ -183,6 +182,6 @@ class FuseBase(object):
param_name = layer.attrs["value"][8:-2]
if param_name in graph.parameters:
graph.parameters.pop(param_name)
if layer_id in graph.layers:
if layer_id in subgraph.layers:
# layer_id可能是属于子图的,此时删除父layer,即删除整个子图
subgraph.layers.pop(layer_id)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册