提交 52c6ea75 编写于 作者: S SunAhong1993

modify optimizer

上级 c5f36434
...@@ -128,10 +128,30 @@ class PaddleGraph(object): ...@@ -128,10 +128,30 @@ class PaddleGraph(object):
for output in layer.outputs: for output in layer.outputs:
outputs_from_nodes[output] = layer_id outputs_from_nodes[output] = layer_id
# 将block的输出用于父图
if inputs is not None and outputs is not None and set(
layer.outputs).issubset(outputs):
if layer_id not in self.edges_out:
self.edges_out[layer_id] = list()
self.edges_out[layer_id].append(-1)
# 处理子图
if len(layer.blocks) > 0: if len(layer.blocks) > 0:
for block in layer.blocks: for block in layer.blocks:
block.build(layer.inputs, layer.outputs) block.build(layer.inputs, layer.outputs)
# 删除不必要的节点
invalid_list = list()
for layer_id, layer in self.layers.items():
if len(self.layers) > 1:
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0 and layer.kernel != "prim.assert" \
and layer.kernel != "prim.exception" \
and layer.kernel != "prim.warnings":
invalid_list.append(layer_id)
for layer_id in invalid_list:
self.layers.pop(layer_id)
if self.graph_type == "dygraph": if self.graph_type == "dygraph":
self.get_dygraph_inputs() self.get_dygraph_inputs()
if len(self.outputs) == 0: if len(self.outputs) == 0:
...@@ -244,7 +264,8 @@ class PaddleGraph(object): ...@@ -244,7 +264,8 @@ class PaddleGraph(object):
else: else:
self.gen_dygraph_code(save_dir) self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir) self.dump_dygraph_parameter(save_dir)
self.dygraph2static(save_dir, input_shapes) #[[None, 3, 224, 224]]
# self.dygraph2static(save_dir, input_shapes) #[[None, 3, 224, 224]]
def dump_parameter(self, param_name, param, save_dir): def dump_parameter(self, param_name, param, save_dir):
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
...@@ -367,13 +388,8 @@ class PaddleGraph(object): ...@@ -367,13 +388,8 @@ class PaddleGraph(object):
gen_head() gen_head()
for layer_id, layer in self.layers.items(): for layer_id, layer in self.layers.items():
if len(self.layers) > 1: if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get( ) or layer.kernel == "fluid.dygraph.base.to_variable":
layer_id, 0) == 0 and layer.kernel != "prim.assert" \
and layer.kernel != "prim.exception" \
and layer.kernel != "prim.warnings":
continue
if "paddle.nn" in layer.kernel or layer.kernel == "fluid.dygraph.base.to_variable":
line = "{}".format( line = "{}".format(
layer.outputs[0] layer.outputs[0]
) if layer.kernel == "fluid.dygraph.base.to_variable" and not layer.attrs[ ) if layer.kernel == "fluid.dygraph.base.to_variable" and not layer.attrs[
......
...@@ -3229,17 +3229,35 @@ def aten_upsample_bilinear2d(mapper, graph, node): ...@@ -3229,17 +3229,35 @@ def aten_upsample_bilinear2d(mapper, graph, node):
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%x.13 # 处理输入0,即%x.13
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs)
layer_inputs["input"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
# 处理输入1,即%4963 # 处理输入1,即%4963
if inputs_name[1] in mapper.attrs: if inputs_name[1] in mapper.attrs:
layer_attrs["out_shape"] = mapper.attrs[inputs_name[1]] layer_attrs["size"] = mapper.attrs[inputs_name[1]]
else: else:
mapper._check_input(graph, inputs_node[1], inputs_name[1], mapper._check_input(graph, inputs_node[1], inputs_name[1],
current_outputs) current_outputs)
layer_inputs["out_shape"] = inputs_name[1] layer_inputs["size"] = inputs_name[1]
current_inputs.append(inputs_name[1]) current_inputs.append(inputs_name[1])
graph.add_layer(
"prim.isinstance",
inputs={"input": inputs_name[1]},
outputs=[inputs_name[1] + "_isinstance"],
cls="paddle.fluid.Variable")
graph.add_layer(
"prim.if", {"input": inputs_name[1] + "_isinstance"},
outputs=[inputs_name[0] + "_if1"])
if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(if_layer, graph_type="dygraph")
block.add_layer(
"prim.var2list",
inputs={"input": inputs_name[1]},
outputs=[inputs_name[1]])
if_layer.add_block(block)
block = PaddleGraph(if_layer, graph_type="dygraph")
if_layer.add_block(block)
if_layer.inputs["input-0"] = inputs_name[1]
# 处理输入2,即%5421 # 处理输入2,即%5421
if inputs_name[2] in mapper.attrs: if inputs_name[2] in mapper.attrs:
layer_attrs["align_corners"] = mapper.attrs[inputs_name[2]] layer_attrs["align_corners"] = mapper.attrs[inputs_name[2]]
...@@ -3261,10 +3279,10 @@ def aten_upsample_bilinear2d(mapper, graph, node): ...@@ -3261,10 +3279,10 @@ def aten_upsample_bilinear2d(mapper, graph, node):
inputs=list_layer_inputs, inputs=list_layer_inputs,
outputs=[output_name + "_assert"], outputs=[output_name + "_assert"],
type="eq") type="eq")
layer_inputs["scale"] = inputs_name[3] layer_inputs["scale_factor"] = inputs_name[3]
layer_attrs["align_mode"] = 0 layer_attrs["align_mode"] = 0
graph.add_layer( graph.add_layer(
"fluid.layers.interpolate", "paddle.nn.functional.interpolate",
inputs=layer_inputs, inputs=layer_inputs,
outputs=layer_outputs, outputs=layer_outputs,
**layer_attrs) **layer_attrs)
......
...@@ -442,7 +442,8 @@ def prim_shape(mapper, graph, node): ...@@ -442,7 +442,8 @@ def prim_shape(mapper, graph, node):
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.shape", inputs=layer_inputs, outputs=layer_outputs) graph.add_layer(
"fluid.layers.shape", inputs=layer_inputs, outputs=layer_outputs)
return current_inputs, current_outputs return current_inputs, current_outputs
......
...@@ -172,10 +172,11 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[]): ...@@ -172,10 +172,11 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[]):
forward_func.extend(b_forward_lines) forward_func.extend(b_forward_lines)
block = layer.blocks[1] block = layer.blocks[1]
if len(block.layers) > 0: if len(block.layers) > 0:
line = "else:"
forward_func.extend(gen_codes([line], indent=indent))
b_init_lines, b_forward_lines = block.gen_dygraph_code( b_init_lines, b_forward_lines = block.gen_dygraph_code(
indent=indent + 1) indent=indent + 1)
if len(b_forward_lines) != 0:
line = "else:"
forward_func.extend(gen_codes([line], indent=indent))
init_func.extend(b_init_lines) init_func.extend(b_init_lines)
forward_func.extend(b_forward_lines) forward_func.extend(b_forward_lines)
...@@ -191,6 +192,13 @@ def prim_is(layer, indent=1, init_func=[], forward_func=[]): ...@@ -191,6 +192,13 @@ def prim_is(layer, indent=1, init_func=[], forward_func=[]):
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_isinstance(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = isinstance({}, {})".format(layer.outputs[0],
get_value(layer, "input"),
layer.attrs["cls"])
forward_func.extend(gen_codes([line], indent=indent))
def prim_isnot(layer, indent=1, init_func=[], forward_func=[]): def prim_isnot(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} is not {}".format(layer.outputs[0], line = "{} = {} is not {}".format(layer.outputs[0],
get_value(layer, "x"), get_value(layer, "x"),
...@@ -370,6 +378,12 @@ def prim_type(layer, indent=1, init_func=[], forward_func=[]): ...@@ -370,6 +378,12 @@ def prim_type(layer, indent=1, init_func=[], forward_func=[]):
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_var2list(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {}.numpy().tolist()".format(layer.outputs[0],
get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_warnings(layer, indent=1, init_func=[], forward_func=[]): def prim_warnings(layer, indent=1, init_func=[], forward_func=[]):
lines = ["import warnings"] lines = ["import warnings"]
line = "warnings.warn({}, stacklevel={})".format( line = "warnings.warn({}, stacklevel={})".format(
......
...@@ -81,6 +81,7 @@ class PyTorchOpMapper(OpMapper): ...@@ -81,6 +81,7 @@ class PyTorchOpMapper(OpMapper):
node = ivalue.node() node = ivalue.node()
if str(ivalue.type()) != "Tensor": if str(ivalue.type()) != "Tensor":
graph.set_name(str(ivalue.type()).split(".")[-1]) graph.set_name(str(ivalue.type()).split(".")[-1])
continue
inputs, outputs = self.data(graph, node, ivalue.unique()) inputs, outputs = self.data(graph, node, ivalue.unique())
# 转换中间节点 # 转换中间节点
for node in script_graph.nodes(): for node in script_graph.nodes():
......
...@@ -25,215 +25,122 @@ class BatchNorm2dFuser(FuseBase): ...@@ -25,215 +25,122 @@ class BatchNorm2dFuser(FuseBase):
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的batchnorm2d图结构。 """ 描述需要替换的batchnorm2d图结构。
batchnorm2d层模式python实现代码示例: batchnorm2d层模式python实现代码示例:
x2214 = fluid.layers.shape(x2207) x336 = fluid.layers.shape(input=x334)
x2214 = len(x2214) x336 = len(x336)
x2215 = x2214 != x2213 x337 = x336 != 4
if x2215 : if x337 :
raise RaiseException(x2212) raise RaiseException('Exception')
if x2218 : if False :
x2220 = self.x2220 x351 = fluid.layers.shape(input=x334)
x2221 = x2220 + x2209 x352 = x351[0]
self.x2220 = x2221 x353 = len(x351)
x2227 = False x354 = x353 - 2
if x2227 : x357 = x352
x2230 = fluid.layers.shape(x2207.shape) for _x356 in range(x354):
x2231 = 'Exception' x358 = _x356 + 2
x2236 = x2230[x2233] x359 = x351[x358]
x2237 = len(x2230) x360 = x357 * x359
x2238 = x2237 - x2234 x355 = x360
x2241 = x2236 x361 = x355 == 1
for _x2240 in range(x2238): if x361 :
x2242 = _x2240 + x2234 raise RaiseException('Exception')
x2243 = x2230[x2242] x364 = self.batchnorm7(x334)
x2244 = x2241 * x2243
x2239 = x2244
x2245 = x2239 == x2235
if x2245 :
raise RaiseException(x2231)
x2248 = self.batchnorm41(x2207)
""" """
def gen_name(id): def gen_name(id):
return "x" + str(id) return "x" + str(id)
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(0)], value=1)
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(1)], value=0.1)
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(2)], value=0.001)
# self.pattern.add_layer(
# "prim.constant",
# inputs={},
# outputs=[gen_name(3)],
# value="Exception")
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(4)], value=4)
self.pattern.add_layer( self.pattern.add_layer(
"fluid.layers.shape", "fluid.layers.shape",
inputs={'input': "bn-input-0"}, inputs={'input': "bn-input-0"},
outputs=[gen_name(5)]) outputs=[gen_name(0)])
self.pattern.add_layer( self.pattern.add_layer(
"prim.len", inputs={'input': gen_name(5)}, outputs=[gen_name(5)]) "prim.len", inputs={'input': gen_name(0)}, outputs=[gen_name(0)])
self.pattern.add_layer( self.pattern.add_layer(
"prim.ne", "prim.ne", inputs={"x": gen_name(0)}, outputs=[gen_name(1)], y=4)
inputs={"x": gen_name(5), self.pattern.add_layer("prim.if", {'input': gen_name(1)}, [gen_name(2)])
"y": "bn-input-9"},
outputs=[gen_name(6)])
self.pattern.add_layer("prim.if", {'input': gen_name(6)}, [gen_name(7)])
if_layer1 = self.pattern.layers[list(self.pattern.layers.keys())[-1]] if_layer1 = self.pattern.layers[list(self.pattern.layers.keys())[-1]]
pattern_block0 = PaddleGraph(if_layer1, graph_type="dygraph") pattern_block0 = PaddleGraph(if_layer1, graph_type="dygraph")
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.exception", "prim.exception",
inputs={"input": "bn-input-1"}, inputs={},
outputs=[gen_name(8)]) outputs=[gen_name(3)],
if_layer1.inputs["input-0"] = "bn-input-1" input="Exception")
if_layer1.add_block(pattern_block0) if_layer1.add_block(pattern_block0)
pattern_block1 = PaddleGraph(if_layer1, graph_type="dygraph") pattern_block1 = PaddleGraph(if_layer1, graph_type="dygraph")
if_layer1.add_block(pattern_block1) if_layer1.add_block(pattern_block1)
# self.pattern.add_layer( self.pattern.add_layer("prim.if", {}, [gen_name(4)], input=False)
# "prim.constant", inputs={}, outputs=[gen_name(9)], value=False)
self.pattern.add_layer("prim.if", {'input': "bn-input-2"},
[gen_name(10)])
if_layer2 = self.pattern.layers[list(self.pattern.layers.keys())[-1]] if_layer2 = self.pattern.layers[list(self.pattern.layers.keys())[-1]]
pattern_block0 = PaddleGraph(if_layer2, graph_type="dygraph") pattern_block0 = PaddleGraph(if_layer2, graph_type="dygraph")
pattern_block0.add_layer(
"fluid.dygraph.base.to_variable",
inputs={},
outputs=[gen_name(11)],
value="params[{}]".format(string(gen_name(11))))
pattern_block0.add_layer(
"prim.add",
inputs={"x": gen_name(11),
"y": "bn-input-3"},
outputs=[gen_name(12)])
pattern_block0.add_layer(
"prim.set_attr",
inputs={"input": gen_name(12)},
outputs=["self." + gen_name(11)])
if_layer2.inputs["input-0"] = "bn-input-3"
if_layer2.add_block(pattern_block0)
pattern_block1 = PaddleGraph(if_layer2, graph_type="dygraph")
if_layer2.add_block(pattern_block1)
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(13)], value=True)
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(14)], value=False)
self.pattern.add_layer("prim.if", {'input': "bn-input-4"},
[gen_name(15)])
if_layer3 = self.pattern.layers[list(self.pattern.layers.keys())[-1]]
pattern_block0 = PaddleGraph(if_layer3, graph_type="dygraph")
pattern_block0.add_layer( pattern_block0.add_layer(
"fluid.layers.shape", "fluid.layers.shape",
inputs={'input': "bn-input-0"}, inputs={'input': "bn-input-0"},
outputs=[gen_name(16)]) outputs=[gen_name(5)])
# pattern_block0.add_layer(
# "prim.constant",
# inputs={},
# outputs=[gen_name(17)],
# value="Exception")
# pattern_block0.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(18)], value=True)
# pattern_block0.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(19)], value=0)
# pattern_block0.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(20)], value=2)
# pattern_block0.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(21)], value=1)
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.getitem", "prim.getitem",
inputs={"list": gen_name(16), inputs={"list": gen_name(5)},
"index": "bn-input-6"}, outputs=[gen_name(6)],
outputs=[gen_name(22)]) index=0)
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.len", inputs={"input": gen_name(16)}, outputs=[gen_name(23)]) "prim.len", inputs={"input": gen_name(5)}, outputs=[gen_name(7)])
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.sub", "prim.sub", inputs={"x": gen_name(7)}, outputs=[gen_name(8)], y=2)
inputs={"x": gen_name(23),
"y": "bn-input-7"},
outputs=[gen_name(24)])
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.equal", "prim.equal", inputs={"input": gen_name(6)}, outputs=[gen_name(9)])
inputs={"input": gen_name(22)},
outputs=[gen_name(25)])
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.loop", "prim.loop",
inputs={"input": gen_name(24)}, inputs={"input": gen_name(8)},
outputs=[gen_name(26), gen_name(27)]) outputs=[gen_name(8.1), gen_name(10)])
loop_layer = pattern_block0.layers[list(pattern_block0.layers.keys())[ loop_layer = pattern_block0.layers[list(pattern_block0.layers.keys())[
-1]] -1]]
pattern_block0_block0 = PaddleGraph(loop_layer, graph_type="dygraph") pattern_block0_block0 = PaddleGraph(loop_layer, graph_type="dygraph")
pattern_block0_block0.add_layer( pattern_block0_block0.add_layer(
"prim.add", "prim.add", inputs={"x": gen_name(10)}, outputs=[gen_name(11)], y=2)
inputs={"x": gen_name(27),
"y": "bn-input-7"},
outputs=[gen_name(28)])
pattern_block0_block0.add_layer( pattern_block0_block0.add_layer(
"prim.getitem", "prim.getitem",
inputs={"list": gen_name(16), inputs={"list": gen_name(5),
"index": gen_name(28)}, "index": gen_name(11)},
outputs=[gen_name(29)]) outputs=[gen_name(12)])
pattern_block0_block0.add_layer( pattern_block0_block0.add_layer(
"prim.mul", "prim.mul",
inputs={"x": gen_name(25), inputs={"x": gen_name(9),
"y": gen_name(29)}, "y": gen_name(12)},
outputs=[gen_name(30)]) outputs=[gen_name(13)])
pattern_block0_block0.add_layer( pattern_block0_block0.add_layer(
"prim.equal", "prim.equal",
inputs={"input": gen_name(30)}, inputs={"input": gen_name(13)},
outputs=[gen_name(26)]) outputs=[gen_name(8.1)])
loop_layer.inputs["input-1"] = "bn-input-7" loop_layer.inputs["input-1"] = gen_name(5)
loop_layer.inputs["input-2"] = gen_name(16) loop_layer.inputs["input-2"] = gen_name(9)
loop_layer.inputs["input-3"] = gen_name(25)
loop_layer.add_block(pattern_block0_block0) loop_layer.add_block(pattern_block0_block0)
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.eq", "prim.eq", inputs={"x": gen_name(8.1)}, outputs=[gen_name(14)], y=1)
inputs={"x": gen_name(26),
"y": "bn-input-8"},
outputs=[gen_name(31)])
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.if", inputs={"input": gen_name(31)}, outputs=[gen_name(32)]) "prim.if", inputs={"input": gen_name(14)}, outputs=[gen_name(15)])
if_layer31 = pattern_block0.layers[list(pattern_block0.layers.keys())[ if_layer21 = pattern_block0.layers[list(pattern_block0.layers.keys())[
-1]] -1]]
pattern_block0_block0 = PaddleGraph(if_layer31, graph_type="dygraph") pattern_block0_block0 = PaddleGraph(if_layer21, graph_type="dygraph")
pattern_block0_block0.add_layer( pattern_block0_block0.add_layer(
"prim.exception", "prim.exception",
inputs={"input": "bn-input-5"}, inputs={},
outputs=[gen_name(33)]) outputs=[gen_name(15)],
if_layer31.inputs["input-0"] = "bn-input-5" input="Exception")
if_layer31.add_block(pattern_block0_block0) if_layer21.add_block(pattern_block0_block0)
pattern_block0_block1 = PaddleGraph(if_layer31, graph_type="dygraph") pattern_block0_block1 = PaddleGraph(if_layer21, graph_type="dygraph")
if_layer31.add_block(pattern_block0_block1) if_layer21.add_block(pattern_block0_block1)
if_layer3.add_block(pattern_block0) if_layer2.add_block(pattern_block0)
pattern_block1 = PaddleGraph(if_layer3, graph_type="dygraph") pattern_block1 = PaddleGraph(if_layer2, graph_type="dygraph")
if_layer3.add_block(pattern_block1) if_layer2.add_block(pattern_block1)
if_layer3.inputs["input-0"] = "bn-input-5" if_layer2.inputs["input-0"] = "bn-input-0"
if_layer3.inputs["input-1"] = "bn-input-6"
if_layer3.inputs["input-2"] = "bn-input-7"
if_layer3.inputs["input-3"] = "bn-input-7"
if_layer3.inputs["input-4"] = "bn-input-8"
if_layer3.inputs["input-5"] = "bn-input-0"
self.pattern.add_layer( self.pattern.add_layer(
"paddle.nn.BatchNorm", "paddle.nn.BatchNorm",
inputs={"input": "bn-input-0"}, inputs={"input": "bn-input-0"},
outputs=[gen_name(34), gen_name(35)], outputs=[gen_name(16), gen_name(17)],
is_test=True, is_test=True,
num_channels=160, num_channels=160,
momentum=0.1, momentum=0.1,
epsilon=0.001) epsilon=0.001)
self.pattern.build(inputs={ self.pattern.build(inputs={"input-0": "bn-input-0"})
"input-0": "bn-input-0",
"input-1": "bn-input-1",
"input-2": "bn-input-2",
"input-3": "bn-input-3",
"input-4": "bn-input-4",
"input-5": "bn-input-5",
"input-6": "bn-input-6",
"input-7": "bn-input-7",
"input-8": "bn-input-8",
"input-9": "bn-input-9"
})
def insert_new_layer(self, graph, parameters, matches): def insert_new_layer(self, graph, parameters, matches):
new_layer = self.gen_new_layer(parameters, matches) new_layer = self.gen_new_layer(parameters, matches)
...@@ -241,6 +148,10 @@ class BatchNorm2dFuser(FuseBase): ...@@ -241,6 +148,10 @@ class BatchNorm2dFuser(FuseBase):
graph.layers[new_layer_id] = new_layer graph.layers[new_layer_id] = new_layer
matches.pop(new_layer_id) matches.pop(new_layer_id)
# for layer in matches.values():
# print(layer.outputs)
# print("-------")
def gen_new_layer(self, parameters, matches): def gen_new_layer(self, parameters, matches):
layers_id = list(matches.keys()) layers_id = list(matches.keys())
layer = matches[layers_id[-1]] layer = matches[layers_id[-1]]
......
...@@ -28,7 +28,7 @@ class FcFuser(FuseBase): ...@@ -28,7 +28,7 @@ class FcFuser(FuseBase):
fc层模式python实现代码示例: fc层模式python实现代码示例:
x133 = x128.shape x133 = x128.shape
x133 = len(x133) x133 = len(x133)
x134 = x133 == x131 x134 = x133 == 2
if x134 : if x134 :
classifier_6_weight = self.classifier_6_weight classifier_6_weight = self.classifier_6_weight
x136 = fluid.layers.transpose(x=classifier_6_weight, perm=[1, 0]) x136 = fluid.layers.transpose(x=classifier_6_weight, perm=[1, 0])
...@@ -55,9 +55,9 @@ class FcFuser(FuseBase): ...@@ -55,9 +55,9 @@ class FcFuser(FuseBase):
"prim.len", inputs={'input': gen_name(2)}, outputs=[gen_name(2)]) "prim.len", inputs={'input': gen_name(2)}, outputs=[gen_name(2)])
self.pattern.add_layer( self.pattern.add_layer(
"prim.eq", "prim.eq",
inputs={"eq0": gen_name(2), inputs={"eq0": gen_name(2)},
"eq1": "fc-input-1"}, outputs=[gen_name(3)],
outputs=[gen_name(3)]) eq1=2)
self.pattern.add_layer("prim.if", {'input': gen_name(3)}, [gen_name(4)]) self.pattern.add_layer("prim.if", {'input': gen_name(3)}, [gen_name(4)])
self.pattern.outputs.append(gen_name(4)) self.pattern.outputs.append(gen_name(4))
if_layer1 = self.pattern.layers[list(self.pattern.layers.keys())[-1]] if_layer1 = self.pattern.layers[list(self.pattern.layers.keys())[-1]]
...@@ -122,9 +122,7 @@ class FcFuser(FuseBase): ...@@ -122,9 +122,7 @@ class FcFuser(FuseBase):
"prim.equal", inputs={'input': gen_name(13)}, "prim.equal", inputs={'input': gen_name(13)},
outputs=[gen_name(4)]) outputs=[gen_name(4)])
if_layer1.add_block(pattern_block1) if_layer1.add_block(pattern_block1)
self.pattern.build( self.pattern.build(inputs={"input-0": "fc-input-0"})
inputs={"input-0": "fc-input-0",
"input-1": "fc-input-1"})
def insert_new_layer(self, graph, parameters, matches): def insert_new_layer(self, graph, parameters, matches):
new_layer = self.gen_new_layer(parameters, matches) new_layer = self.gen_new_layer(parameters, matches)
......
...@@ -19,9 +19,14 @@ from x2paddle.optimizer.pass_manager import PassManager ...@@ -19,9 +19,14 @@ from x2paddle.optimizer.pass_manager import PassManager
class GraphOptimizer(object): class GraphOptimizer(object):
def __init__(self): def __init__(self):
self.passes = [ self.passes = [
"interpolate_bilinear_fuse_pass", "fc_fuse_pass", "constant_fuse_pass",
"adaptive_pool2d_fuse_pass", "batchnorm2d_fuse_pass", "batchnorm2d_fuse_pass",
"constant_fuse_pass", "reshape_fuse_pass", "dropout_fuse_pass" "interpolate_bilinear_fuse_pass",
"fc_fuse_pass",
# "interpolate_bilinear_fuse_pass",
# "fc_fuse_pass",
# "adaptive_pool2d_fuse_pass", "batchnorm2d_fuse_pass",
# "constant_fuse_pass", "reshape_fuse_pass", "dropout_fuse_pass"
] ]
def optimize(self, graph): def optimize(self, graph):
......
...@@ -34,7 +34,7 @@ class PatternMatcher(object): ...@@ -34,7 +34,7 @@ class PatternMatcher(object):
并将子图的id以拓扑排序存放到subgraph_id2layers。 并将子图的id以拓扑排序存放到subgraph_id2layers。
""" """
def get_subgraph(pattern, graph, start_index): def get_subgraph(pattern, graph, start_index, is_subblock=False):
pattern_index = 0 pattern_index = 0
pattern_id2layers = pattern.get_global_layers() pattern_id2layers = pattern.get_global_layers()
pattern_ids = list(pattern_id2layers.keys()) pattern_ids = list(pattern_id2layers.keys())
...@@ -49,13 +49,19 @@ class PatternMatcher(object): ...@@ -49,13 +49,19 @@ class PatternMatcher(object):
# 判断输入连接是否一致 # 判断输入连接是否一致
if layer_id in graph.edges_in: if layer_id in graph.edges_in:
if pattern_layer_id not in pattern.edges_in: if pattern_layer_id not in pattern.edges_in:
print("1--") if pattern_index == 0 or is_subblock:
return False return False
else:
subgraph_id2layers.pop(layer_id)
continue
else: else:
if len(graph.edges_in[layer_id]) != len( if len(graph.edges_in[layer_id]) != len(
pattern.edges_in[pattern_layer_id]): pattern.edges_in[pattern_layer_id]):
print("2--") if pattern_index == 0 or is_subblock:
return False return False
else:
subgraph_id2layers.pop(layer_id)
continue
layer_in = graph.edges_in[layer_id] layer_in = graph.edges_in[layer_id]
pattern_layer_in = pattern.edges_in[pattern_layer_id] pattern_layer_in = pattern.edges_in[pattern_layer_id]
for i in range(len(layer_in)): for i in range(len(layer_in)):
...@@ -70,16 +76,22 @@ class PatternMatcher(object): ...@@ -70,16 +76,22 @@ class PatternMatcher(object):
# 判断pattern输入在pattern_ids的索引 # 判断pattern输入在pattern_ids的索引
# 和graph输入在subgraph_ids的索引一致 # 和graph输入在subgraph_ids的索引一致
continue continue
print("3--") if pattern_index == 0 or is_subblock:
return False return False
else:
subgraph_id2layers.pop(layer_id)
continue
# 判断subgraph中的节点是否被外部图使用到(如若被使用到则无效) # 判断subgraph中的节点是否被外部图使用到(如若被使用到则无效)
if layer_id in graph.edges_out: if layer_id in graph.edges_out:
if pattern_layer_id not in pattern.edges_out: if pattern_layer_id not in pattern.edges_out:
if not set(pattern_layer.outputs).issubset( if not set(pattern_layer.outputs).issubset(
pattern.outputs): pattern.outputs):
# 若pattern当前layer的输出是pattern的输出,则是正确的 # 若pattern当前layer的输出是pattern的输出,则是正确的
print("4--") if pattern_index == 0 or is_subblock:
return False return False
else:
subgraph_id2layers.pop(layer_id)
continue
else: else:
if len(graph.edges_out[layer_id]) != len( if len(graph.edges_out[layer_id]) != len(
pattern.edges_out[pattern_layer_id]): pattern.edges_out[pattern_layer_id]):
...@@ -87,27 +99,49 @@ class PatternMatcher(object): ...@@ -87,27 +99,49 @@ class PatternMatcher(object):
if not set(pattern_layer.outputs).issubset( if not set(pattern_layer.outputs).issubset(
pattern.outputs): pattern.outputs):
# 若pattern当前layer的输出是pattern的输出,则是正确的 # 若pattern当前layer的输出是pattern的输出,则是正确的
# print("5--") if pattern_index == 0 or is_subblock:
return False return False
else:
subgraph_id2layers.pop(layer_id)
continue
# 当为控制流时的处理 # 当为控制流时的处理
if layer.kernel == "prim.if" or layer.kernel == "prim.loop": if layer.kernel == "prim.if" or layer.kernel == "prim.loop":
if len(pattern_layer.blocks) != len(layer.blocks): if len(pattern_layer.blocks) != len(layer.blocks):
print("6--") if pattern_index == 0 or is_subblock:
return False return False
else:
subgraph_id2layers.pop(layer_id)
continue
is_subblock_match = True
for i, b in enumerate(pattern_layer.blocks): for i, b in enumerate(pattern_layer.blocks):
match_info = get_subgraph(pattern_layer.blocks[i], match_info = get_subgraph(
layer.blocks[i], 0) pattern_layer.blocks[i],
layer.blocks[i],
0,
is_subblock=True)
if match_info is not False: if match_info is not False:
subgraph_id2layers.update(match_info) subgraph_id2layers.update(match_info)
else: else:
print("7--") is_subblock_match = False
break
if not is_subblock_match:
if pattern_index == 0 or is_subblock:
return False return False
else:
index = list(subgraph_id2layers.keys()).index(
layer_id)
for key in list(subgraph_id2layers.keys())[
index:]:
subgraph_id2layers.pop(key)
continue
pattern_index += 1 pattern_index += 1
if pattern_index == len(pattern.layers): if pattern_index == len(pattern.layers):
return subgraph_id2layers return subgraph_id2layers
else: else:
if pattern_index == 0: if pattern_index == 0 or is_subblock:
return False return False
else:
continue
if pattern_index == len(pattern.layers): if pattern_index == len(pattern.layers):
return subgraph_id2layers return subgraph_id2layers
return False return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册