提交 31206399 编写于 作者: S SunAhong1993

add tf optimizer

上级 328ac1a2
......@@ -128,19 +128,25 @@ def tf2paddle(model_path,
else:
from x2paddle.op_mapper.static.tf2paddle.tf_op_mapper import TFOpMapper
from x2paddle.optimizer.tensorflow.bias import BiasOpt
from x2paddle.optimizer.tensorflow.transpose import TransposeOpt
from x2paddle.optimizer.tensorflow.batch_norm import BatchNormOpt
print("Now translating model from tensorflow to paddle.")
model = TFDecoder(model_path, define_input_shape=define_input_shape)
mapper = TFOpMapper(model)
mapper.paddle_graph.build()
bias_opt = BiasOpt()
transpose_opt = TransposeOpt()
batch_norm_opt = BatchNormOpt()
bias_opt.run(program)
batch_norm_opt.run(program)
transpose_opt.run(program)
if paddle_type == "dygraph":
from x2paddle.optimizer.optimizer import GraphOptimizer
graph_opt = GraphOptimizer(source_frame="tf", paddle_type=paddle_type)
graph_opt.optimize(mapper.paddle_graph)
else:
from x2paddle.optimizer.tensorflow.bias import BiasOpt
from x2paddle.optimizer.tensorflow.transpose import TransposeOpt
from x2paddle.optimizer.tensorflow.batch_norm import BatchNormOpt
bias_opt = BiasOpt()
transpose_opt = TransposeOpt()
batch_norm_opt = BatchNormOpt()
bias_opt.run(program)
batch_norm_opt.run(program)
transpose_opt.run(program)
mapper.paddle_graph.gen_model(save_dir)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .transpose_elimination import Dygraph_TransposeElimination
from .transpose_eliminate_pass import Dygraph_TransposeEliminatePass
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.elimination.dygraph import Dygraph_TransposeElimination
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class Dygraph_TransposeEliminatePass(Pass):
    """Pass wrapper that runs ``Dygraph_TransposeElimination`` on a graph.

    The class is handed to ``pass_register`` and carries the pass name
    ``transpose_eliminate_pass``.
    """

    name = "transpose_eliminate_pass"

    def __init__(self):
        Pass.__init__(self)

    def apply(self, graph):
        """Run the transpose elimination on ``graph``."""
        Dygraph_TransposeElimination().operate(graph)


# Used for registration.
transpose_eliminate_pass = Dygraph_TransposeEliminatePass()
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class Dygraph_TransposeElimination(FuseBase):
    """Remove groups of mutually cancelling ``paddle.transpose`` layers.

    A group starts at a transpose with ``perm=[0, 2, 3, 1]`` and is closed by
    transposes with ``perm=[0, 3, 1, 2]`` (presumably NCHW->NHWC and back --
    confirm against the op mapper). In between, only layers whose kernels are
    listed in ``direct_layers``, ``elementwise_layers``, ``reduce_layers`` or
    ``paddle.concat`` may appear. When such a group is found all its
    transposes are deleted and the axis attributes of the reduce, concat and
    low-rank elementwise layers inside it are rewritten for the new layout.
    """

    def __init__(self):
        super(Dygraph_TransposeElimination, self).__init__(graph_type="dygraph")
        # Kernels a transpose may be moved across without any attribute change.
        self.direct_layers = [
            'paddle.nn.ReLU', 'paddle.nn.ReLU6', 'paddle.abs',
            'paddle.nn.Sigmoid', 'paddle.exp', 'paddle.rsqrt',
            'paddle.nn.Swish', 'paddle.nn.Tanh',
            'paddle.nn.Softplus', 'paddle.nn.LeakyReLU',
            'paddle.nn.Softmax', 'paddle.erf', 'paddle.square'
        ]
        # Binary kernels; they get special handling when one operand has
        # rank <= 1 (see ``operate``).
        self.elementwise_layers = [
            'paddle.add', 'fluid.layers.elementwise_sub',
            'paddle.multiply', 'paddle.divide'
        ]
        # Reductions; only accepted when they keep the reduced dimensions.
        self.reduce_layers = [
            'paddle.mean', 'paddle.all',
            'paddle.max', 'paddle.any',
            'paddle.sum', 'paddle.prod'
        ]

    @staticmethod
    def _remap_axis(axis):
        """Translate an axis (int, or list/tuple of ints) through [0, 2, 3, 1].

        Negative axes work as well because the table is indexed with normal
        Python indexing.
        """
        mapping = [0, 2, 3, 1]
        if isinstance(axis, (list, tuple)):
            return [mapping[a] for a in axis]
        return mapping[axis]

    def get_transpose_num(self, graph):
        """Return how many layers of ``graph`` use the ``paddle.transpose`` kernel."""
        return sum(1 for layer in graph.layers.values()
                   if layer.kernel == "paddle.transpose")

    def operate(self, graph):
        """Eliminate removable transpose groups from ``graph`` in place.

        The search runs on a deep copy of ``graph``; the layers collected there
        are then deleted from ``graph`` itself and the affected axis attributes
        are remapped. Progress and a before/after transpose count are printed.

        Args:
            graph: the graph to optimise. Its ``layers``, ``edges_in``,
                ``edges_out`` and ``outputs`` members are read and modified.
        """
        before_transpose_num = self.get_transpose_num(graph)
        opt_graph = copy.deepcopy(graph)
        total_layer_num = len(opt_graph.layers)
        scanned_layers = set()
        optimized_transpose_layers = list()
        optimized_reduce_layers = list()
        optimized_concat_layers = list()
        optimized_elementwise_layers = list()

        def strip_transpose(_graph):
            """Remove one transpose group from ``_graph``.

            Returns True as soon as a group has been deleted (the caller is
            expected to call again) and False when a scan of all not yet
            scanned layers found nothing to remove.
            """
            layers = copy.deepcopy(_graph.layers)
            for layer_id, layer in layers.items():
                if layer_id in scanned_layers:
                    continue
                scanned_layers.add(layer_id)
                percent = round(len(scanned_layers) / total_layer_num * 100, 2)
                print("\rOptimize Transpose Layers...{}%".format(
                    percent))

                if layer.kernel != "paddle.transpose":
                    continue
                if layer.attrs["perm"] != [0, 2, 3, 1]:
                    continue
                transpose_layers = list()
                propagate_layers = list()
                reduce_layers = list()
                concat_layers = list()
                # This list is reserved for elementwise layers of the form
                # shape(4) + shape(1), i.e. one operand has rank <= 1.
                elementwise_layers = list()
                can_be_optimized = True
                # Classify the direct consumers of the starting transpose.
                for out in _graph.edges_out.get(layer_id, []):
                    if _graph.layers[out].kernel == "paddle.transpose":
                        if _graph.layers[out].attrs["perm"] != [0, 3, 1, 2]:
                            can_be_optimized = False
                            break
                        transpose_layers.append(out)
                    elif _graph.layers[out].kernel in self.elementwise_layers:
                        propagate_layers.append(out)
                    elif _graph.layers[out].kernel in self.direct_layers:
                        # paddle.nn.* layers carry their tensor output at
                        # index 1 (index 0 presumably names the sub-layer).
                        output_index = 1 if _graph.layers[
                            out].kernel.startswith("paddle.nn.") else 0
                        if _graph.layers[out].outputs[
                                output_index] in _graph.outputs:
                            can_be_optimized = False
                            break
                        propagate_layers.append(out)
                    elif _graph.layers[out].kernel in self.reduce_layers:
                        if _graph.layers[out].outputs[0] in _graph.outputs:
                            can_be_optimized = False
                            break
                        if not _graph.layers[out].attrs.get('keepdim', False):
                            can_be_optimized = False
                            break
                        propagate_layers.append(out)
                        reduce_layers.append(out)
                    elif _graph.layers[out].kernel == "paddle.concat":
                        if _graph.layers[out].outputs[0] in _graph.outputs:
                            can_be_optimized = False
                            break
                        propagate_layers.append(out)
                        concat_layers.append(out)
                    else:
                        can_be_optimized = False
                        break

                # Walk outwards (consumers and producers) from every layer the
                # transpose can be moved across until only transposes remain.
                visited_layers = set()
                while len(propagate_layers) > 0 and can_be_optimized:
                    current_id = propagate_layers.pop(0)
                    visited_layers.add(current_id)
                    for out in _graph.edges_out.get(current_id, []):
                        if _graph.layers[out].kernel == "paddle.transpose":
                            if _graph.layers[out].attrs["perm"] != [0, 3, 1, 2]:
                                can_be_optimized = False
                                break
                            transpose_layers.append(out)
                        elif _graph.layers[
                                out].kernel in self.elementwise_layers:
                            if _graph.layers[out].outputs[0] in _graph.outputs:
                                can_be_optimized = False
                                break
                            if out not in visited_layers:
                                propagate_layers.append(out)
                        elif _graph.layers[out].kernel in self.direct_layers:
                            output_index = 1 if _graph.layers[
                                out].kernel.startswith("paddle.nn.") else 0
                            if _graph.layers[out].outputs[
                                    output_index] in _graph.outputs:
                                can_be_optimized = False
                                break
                            if out not in visited_layers:
                                propagate_layers.append(out)
                        elif _graph.layers[out].kernel in self.reduce_layers:
                            if _graph.layers[out].outputs[0] in _graph.outputs:
                                can_be_optimized = False
                                break
                            if not _graph.layers[out].attrs.get('keepdim',
                                                                False):
                                can_be_optimized = False
                                break
                            if out not in visited_layers:
                                propagate_layers.append(out)
                                reduce_layers.append(out)
                        elif _graph.layers[out].kernel == "paddle.concat":
                            if _graph.layers[out].outputs[0] in _graph.outputs:
                                can_be_optimized = False
                                break
                            if out not in visited_layers:
                                propagate_layers.append(out)
                                concat_layers.append(out)
                        else:
                            can_be_optimized = False
                            break
                    for ipt in _graph.edges_in.get(current_id, []):
                        if _graph.layers[
                                current_id].kernel in self.elementwise_layers:
                            # Any failure while inspecting the operand shapes
                            # (missing shape info, unexpected wiring, ...)
                            # deliberately just disables the optimisation.
                            try:
                                x_shape = _graph.layers[
                                    current_id].input_shapes['x']
                                y_shape = _graph.layers[
                                    current_id].input_shapes['y']
                                if _graph.layers[ipt].outputs[
                                        0] == _graph.layers[current_id].inputs[
                                            'x']:
                                    if len(x_shape) <= 1:
                                        # Low-rank operand: remember the layer
                                        # and do not follow this producer.
                                        elementwise_layers.append(current_id)
                                        continue
                                elif _graph.layers[ipt].outputs[
                                        0] == _graph.layers[current_id].inputs[
                                            'y']:
                                    if len(y_shape) <= 1:
                                        elementwise_layers.append(current_id)
                                        continue
                                else:
                                    raise Exception(
                                        "Unexcepted situation happend while optimizing transpose"
                                    )
                            except Exception:
                                can_be_optimized = False
                                break
                        if _graph.layers[ipt].kernel == "paddle.transpose":
                            if _graph.layers[ipt].attrs["perm"] != [0, 2, 3, 1]:
                                can_be_optimized = False
                                break
                            if ipt not in visited_layers:
                                transpose_layers.append(ipt)
                        elif _graph.layers[
                                ipt].kernel in self.elementwise_layers:
                            if _graph.layers[ipt].outputs[0] in _graph.outputs:
                                can_be_optimized = False
                                break
                            if ipt not in visited_layers:
                                propagate_layers.append(ipt)
                        elif _graph.layers[ipt].kernel in self.direct_layers:
                            output_index = 1 if _graph.layers[
                                ipt].kernel.startswith("paddle.nn.") else 0
                            if _graph.layers[ipt].outputs[
                                    output_index] in _graph.outputs:
                                can_be_optimized = False
                                break
                            if ipt not in visited_layers:
                                propagate_layers.append(ipt)
                        elif _graph.layers[ipt].kernel in self.reduce_layers:
                            if _graph.layers[ipt].outputs[0] in _graph.outputs:
                                can_be_optimized = False
                                break
                            if not _graph.layers[ipt].attrs.get('keepdim',
                                                                False):
                                can_be_optimized = False
                                break
                            if ipt not in visited_layers:
                                propagate_layers.append(ipt)
                                reduce_layers.append(ipt)
                        elif _graph.layers[ipt].kernel == "paddle.concat":
                            if _graph.layers[ipt].outputs[0] in _graph.outputs:
                                can_be_optimized = False
                                break
                            if ipt not in visited_layers:
                                propagate_layers.append(ipt)
                                concat_layers.append(ipt)
                        else:
                            can_be_optimized = False
                            break
                    if not can_be_optimized:
                        break
                if not can_be_optimized:
                    continue

                transpose_layers.append(layer_id)
                transpose_layers = list(set(transpose_layers))
                # A transpose that produces a graph output must stay.
                for l in transpose_layers:
                    if graph.layers[l].outputs[0] in graph.outputs:
                        can_be_optimized = False
                        break
                if not can_be_optimized:
                    continue

                for l in transpose_layers:
                    self.delete_layer_with_associated(_graph, l)

                optimized_transpose_layers.extend(transpose_layers)
                optimized_reduce_layers.extend(reduce_layers)
                optimized_concat_layers.extend(concat_layers)
                optimized_elementwise_layers.extend(elementwise_layers)
                return True
            return False

        while strip_transpose(opt_graph):
            pass

        for layer_id in set(optimized_transpose_layers):
            self.delete_layer_with_associated(graph, layer_id)

        for layer_id in set(optimized_reduce_layers):
            attrs = graph.layers[layer_id].attrs
            # The paddle reduce kernels listed above take the reduced
            # dimensions as ``axis``; a legacy ``dim`` attribute is still
            # remapped when present. ``None`` means "reduce everything" and
            # needs no change.
            for key in ("axis", "dim"):
                if attrs.get(key, None) is not None:
                    attrs[key] = self._remap_axis(attrs[key])

        for layer_id in set(optimized_concat_layers):
            axis = graph.layers[layer_id].attrs.get('axis', 0)
            graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis]

        for layer_id in set(optimized_elementwise_layers):
            # Default -1 maps to 1.
            axis = graph.layers[layer_id].attrs.get('axis', -1)
            graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis]

        current_transpose_num = self.get_transpose_num(graph)
        print(
            "\nTranspose layers optimized, before: transpose_num={}, after: transpose_num={}".
            format(before_transpose_num, current_transpose_num))
......@@ -20,6 +20,8 @@ from .bn_scale_fuser import Dygraph_BNScaleFuser
from .bn_scale_fuse_pass import Dygraph_BNScaleFusePass
from .constant_fuser import Dygraph_ConstantFuser
from .constant_fuse_pass import Dygraph_ConstantFusePass
from .conv2d_add_fuser import Dygraph_Conv2D_AddFuser
from .conv2d_add_fuse_pass import Dygraph_Conv2D_AddFusePass
from .dropout_fuser import Dygraph_DropoutFuser
from .dropout_fuse_pass import Dygraph_DropoutFusePass
from .fc_fuser import Dygraph_FcFuser
......@@ -28,3 +30,5 @@ from .interpolate_bilinear_fuser import Dygraph_InterpolateBilinearFuser
from .interpolate_bilinear_fuse_pass import Dygraph_InterpolateBilinearFusePass
from .reshape_fuser import Dygraph_ReshapeFuser
from .reshape_fuse_pass import Dygraph_ReshapeFusePass
from .tf_batchnorm_fuser import Dygraph_TF_BatchNormFuser
from .tf_batchnorm_fuse_pass import Dygraph_TF_BatchNormFusePass
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_Conv2D_AddFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class Dygraph_Conv2D_AddFusePass(Pass):
    """Pass wrapper that runs ``Dygraph_Conv2D_AddFuser`` with edge matching.

    The class is handed to ``pass_register`` and carries the pass name
    ``dygraph_conv2d_add_fuse_pass``.
    """

    name = "dygraph_conv2d_add_fuse_pass"

    def __init__(self):
        Pass.__init__(self)

    def apply(self, graph):
        """Run the conv2d+add fuser on ``graph`` using ``match_kind="edge"``."""
        Dygraph_Conv2D_AddFuser().operate(graph, match_kind="edge")


# Used for registration.
dygraph_conv2d_add_fuse_pass = Dygraph_Conv2D_AddFusePass()
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class Dygraph_Conv2D_AddFuser(FuseBase):
    """Fuse a bias ``paddle.add`` into the preceding ``paddle.nn.Conv2D``.

    Two patterns are registered in ``self.patterns``: one with transposes
    around the convolution and one without.
    """

    def __init__(self):
        super(Dygraph_Conv2D_AddFuser, self).__init__(graph_type="dygraph")
        self.patterns = list()

    def build_pattern(self):
        """ Describe the conv2d+add graph structures that should be replaced.
        Example Python code of the conv2d+add layer patterns:
        Pattern 1:
        MobilenetV1_Logits_Conv2d_1c_1x1_biases = self.MobilenetV1_Logits_Conv2d_1c_1x1_biases
        conv2d_transpose_14 = paddle.transpose(x=MobilenetV1_Logits_AvgPool_1a_AvgPool, perm=[0, 3, 1, 2])
        MobilenetV1_Logits_Conv2d_1c_1x1_Conv2D = self.conv27(conv2d_transpose_14)
        MobilenetV1_Logits_Conv2d_1c_1x1_Conv2D = paddle.transpose(x=MobilenetV1_Logits_Conv2d_1c_1x1_Conv2D, perm=[0, 2, 3, 1])
        MobilenetV1_Logits_Conv2d_1c_1x1_BiasAdd = paddle.add(x=MobilenetV1_Logits_Conv2d_1c_1x1_Conv2D, y=MobilenetV1_Logits_Conv2d_1c_1x1_biases)
        Pattern 2:
        MobilenetV1_Logits_Conv2d_1c_1x1_biases = self.MobilenetV1_Logits_Conv2d_1c_1x1_biases
        MobilenetV1_Logits_Conv2d_1c_1x1_Conv2D = self.conv27(conv2d_transpose_14)
        MobilenetV1_Logits_Conv2d_1c_1x1_BiasAdd = paddle.add(x=MobilenetV1_Logits_Conv2d_1c_1x1_Conv2D, y=MobilenetV1_Logits_Conv2d_1c_1x1_biases)
        """

        def gen_name(idx):
            return "x" + str(idx)

        # Pattern 1: parameter, transpose -> Conv2D -> transpose, add.
        transposed = PaddleGraph(graph_type="dygraph")
        transposed.add_layer(
            kernel="self.create_parameter",
            inputs={},
            outputs=[gen_name(0)])
        transposed.add_layer(
            kernel="paddle.transpose",
            inputs={"x": "conv-input-0"},
            outputs=[gen_name(1)],
            perm=[0, 3, 1, 2])
        transposed.add_layer(
            kernel="paddle.nn.Conv2D",
            inputs={"input": gen_name(1)},
            outputs=[gen_name(2)])
        # The trailing transpose reuses its input name as its output name.
        transposed.add_layer(
            kernel="paddle.transpose",
            inputs={"x": gen_name(2)},
            outputs=[gen_name(2)],
            perm=[0, 2, 3, 1])
        transposed.add_layer(
            kernel="paddle.add",
            inputs={"x": gen_name(2),
                    "y": gen_name(0)},
            outputs=[gen_name(3)])
        transposed.build(inputs={"input-0": "conv-input-0", })
        self.patterns.append(transposed)

        # Pattern 2: parameter, Conv2D, add.
        plain = PaddleGraph(graph_type="dygraph")
        plain.add_layer(
            kernel="self.create_parameter",
            inputs={},
            outputs=[gen_name(0)])
        plain.add_layer(
            kernel="paddle.nn.Conv2D",
            inputs={"input": "conv-input-0"},
            outputs=[gen_name(1)])
        plain.add_layer(
            kernel="paddle.add",
            inputs={"x": gen_name(1),
                    "y": gen_name(0)},
            outputs=[gen_name(2)])
        plain.build(inputs={"input-0": "conv-input-0", })
        self.patterns.append(plain)

    def insert_new_layer(self, graph, parameters, matches):
        """Rewrite the matched layers and trim ``matches`` in place.

        After the rewrite only the ``self.create_parameter`` and ``paddle.add``
        entries stay in ``matches`` (presumably those are the layers the base
        class removes afterwards -- confirm in ``FuseBase``). ``parameters``
        is not used here.
        """
        self.gen_new_layer(matches, graph)
        kept_kernels = ["self.create_parameter", "paddle.add"]
        for matched_id, matched in copy.deepcopy(matches).items():
            if matched.kernel in kept_kernels:
                continue
            matches.pop(matched_id)

    def gen_new_layer(self, matches, graph):
        """Fold the bias into the matched convolution and rewire its output.

        The first sweep collects the bias name, the output name of the add,
        the id of the Conv2D layer and whether a transpose was matched. The
        second sweep sets the bias on the convolution and makes either the
        convolution or the transpose that follows it produce the former
        output of the add.
        """
        has_transpose = False
        for matched_id, matched in matches.items():
            kernel = matched.kernel
            if kernel == "self.create_parameter":
                bias_name = matched.attrs["attr"]
            elif kernel == "paddle.transpose":
                has_transpose = True
            elif kernel == "paddle.add":
                output_name = matched.outputs[0]
            elif kernel == "paddle.nn.Conv2D":
                conv_id = matched_id

        for matched_id, matched in matches.items():
            kernel = matched.kernel
            if kernel == "paddle.nn.functional.conv2d_transpose":
                matched.bias = bias_name
                if not has_transpose:
                    matched.outputs[0] = output_name
            elif kernel == "paddle.nn.Conv2D":
                matched.attrs["bias_attr"] = bias_name
                if not has_transpose:
                    # Index 1 is the slot that is renamed for Conv2D layers.
                    matched.outputs[1] = output_name
            elif kernel == "paddle.transpose":
                # Only the transpose fed by the convolution takes the name.
                if conv_id in graph.edges_in[matched_id]:
                    matched.outputs[0] = output_name
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_TF_BatchNormFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class Dygraph_TF_BatchNormFusePass(Pass):
    """Pass wrapper that runs ``Dygraph_TF_BatchNormFuser`` with edge matching.

    The class is handed to ``pass_register`` and carries the pass name
    ``dygraph_tf_batchnorm_fuse_pass``.
    """

    name = "dygraph_tf_batchnorm_fuse_pass"

    def __init__(self):
        Pass.__init__(self)

    def apply(self, graph):
        """Run the batch-norm fuser on ``graph`` using ``match_kind="edge"``."""
        Dygraph_TF_BatchNormFuser().operate(graph, match_kind="edge")


# Used for registration.
dygraph_tf_batchnorm_fuse_pass = Dygraph_TF_BatchNormFusePass()
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
from collections import OrderedDict
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class Dygraph_TF_BatchNormFuser(FuseBase):
    """Replace a decomposed batch-norm subgraph by ``paddle.nn.BatchNorm``.

    The replacement is wrapped in two transposes (``[0, 3, 1, 2]`` before and
    ``[0, 2, 3, 1]`` after the BatchNorm layer).
    """

    def __init__(self):
        # Running counter used to name the generated layers "merge_bn<N>".
        self.bn_index = 0
        super(Dygraph_TF_BatchNormFuser, self).__init__(graph_type="dygraph")

    def build_pattern(self):
        """ Describe the batchnorm graph structure that should be replaced.

        In terms of the pattern's tensor names the matched computation is:
            x2 = x0 + x1          (x0: parameter, x1: paddle.full, shape=[1])
            x3 = rsqrt(x2)
            x5 = x3 * x4          (x4: parameter)
            x7 = x6 * x5          (x6: parameter)
            x9 = x8 - x7          (x8: parameter)
            x10 = bn-input-0 * x5
            x11 = x10 + x9
        """

        def gen_name(idx):
            return "x" + str(idx)

        # (kernel, inputs, outputs, extra attributes), in insertion order.
        layer_specs = [
            ("self.create_parameter", {}, [gen_name(0)], {}),
            ("paddle.full", {}, [gen_name(1)], {"shape": [1]}),
            ("paddle.add", {"x": gen_name(0),
                            "y": gen_name(1)}, [gen_name(2)], {}),
            ("paddle.rsqrt", {"x": gen_name(2)}, [gen_name(3)], {}),
            ("self.create_parameter", {}, [gen_name(4)], {}),
            ("paddle.multiply", {"x": gen_name(3),
                                 "y": gen_name(4)}, [gen_name(5)], {}),
            ("self.create_parameter", {}, [gen_name(6)], {}),
            ("paddle.multiply", {"x": gen_name(6),
                                 "y": gen_name(5)}, [gen_name(7)], {}),
            ("self.create_parameter", {}, [gen_name(8)], {}),
            ("fluid.layers.elementwise_sub", {"x": gen_name(8),
                                              "y": gen_name(7)},
             [gen_name(9)], {}),
            ("paddle.multiply", {"x": "bn-input-0",
                                 "y": gen_name(5)}, [gen_name(10)], {}),
            ("paddle.add", {"x": gen_name(10),
                            "y": gen_name(9)}, [gen_name(11)], {}),
        ]
        for kernel, inputs, outputs, attrs in layer_specs:
            self.pattern.add_layer(
                kernel, inputs=inputs, outputs=outputs, **attrs)
        self.pattern.build(inputs={"input-0": "bn-input-0", })

    def insert_new_layer(self, graph, parameters, matches):
        """Insert the three replacement layers into ``graph.layers``.

        The new layers are placed directly after the matched layer with the
        largest id; ``graph.layers`` is rebound to a new ``OrderedDict``.
        Entries of ``matches`` whose id equals the id of a new layer are
        dropped from ``matches``.
        """
        new_layers, last_layer_id = self.gen_new_layer(matches, parameters,
                                                       graph)
        for matched_id in copy.deepcopy(matches):
            for pos in range(3):
                if matched_id == new_layers[pos].id:
                    matches.pop(new_layers[pos].id)

        reordered = OrderedDict()
        for layer_id, layer in graph.layers.items():
            reordered[layer_id] = layer
            if layer_id == last_layer_id:
                for pos in range(3):
                    reordered[new_layers[pos].id] = new_layers[pos]
        graph.layers = reordered

    def gen_new_layer(self, matches, parameters, graph):
        """Build the transpose / BatchNorm / transpose replacement layers.

        Returns:
            A pair ``([transpose0, bn, transpose1], last_id)`` where
            ``last_id`` is the largest matched layer id (ids are compared as
            integers). The new layers get the ids ``last_id + "_1"``,
            ``"_2"`` and ``"_3"``.
        """
        sorted_ids = sorted(matches.keys(), key=int)
        # Locate the layers that hold epsilon, variance, gamma, beta and mean,
        # plus the final add and the name of the external input.
        for layer_id, layer in matches.items():
            kernel = layer.kernel
            if kernel == "paddle.full":
                full_layer = layer
                consumer_id = graph.edges_out[layer_id][0]
                if matches[consumer_id].kernel == "paddle.add":
                    var_layer = matches[graph.edges_in[consumer_id][0]]
            elif kernel == "paddle.rsqrt":
                consumer_id = graph.edges_out[layer_id][0]
                if matches[consumer_id].kernel == "paddle.multiply":
                    gamma_layer = matches[graph.edges_in[consumer_id][1]]
            elif kernel == "fluid.layers.elementwise_sub":
                beta_layer = matches[graph.edges_in[layer_id][0]]
                scaled_mean_id = graph.edges_in[layer_id][1]
                mean_layer = matches[graph.edges_in[scaled_mean_id][0]]
                add_layer = matches[graph.edges_out[layer_id][0]]
            elif kernel == "paddle.multiply":
                second = matches[graph.edges_in[layer_id][1]]
                if second.kernel == "paddle.multiply":
                    # The multiply whose first operand comes from outside the
                    # match is the one that consumes the external input.
                    if graph.edges_in[layer_id][0] not in matches:
                        input_name = layer.inputs["x"]

        transposed_name = "{}_transpose_for_bn".format(input_name)
        transpose0 = PaddleLayer(
            id=sorted_ids[-1] + "_1",
            kernel="paddle.transpose",
            inputs={"x": input_name},
            outputs=[transposed_name],
            perm=[0, 3, 1, 2])

        bn_name = "merge_bn{}".format(self.bn_index)
        self.bn_index += 1
        gamma_value = parameters[gamma_layer.outputs[0]]
        channels = gamma_value.shape[0]
        bn = PaddleLayer(
            id=sorted_ids[-1] + "_2",
            kernel="paddle.nn.BatchNorm",
            inputs={"input": "{}_transpose_for_bn".format(input_name)},
            outputs=[bn_name, "{}_bn".format(input_name)],
            num_channels=channels,
            epsilon=full_layer.attrs["fill_value"],
            param_attr=string(gamma_layer.outputs[0]),
            bias_attr=string(beta_layer.outputs[0]),
            moving_mean_name=string(mean_layer.outputs[0]),
            moving_variance_name=string(var_layer.outputs[0]),
            is_test=True)

        # The closing transpose takes over the outputs of the matched add.
        transpose1 = PaddleLayer(
            id=sorted_ids[-1] + "_3",
            kernel="paddle.transpose",
            inputs={"x": "{}_bn".format(input_name)},
            outputs=add_layer.outputs,
            perm=[0, 2, 3, 1])
        return [transpose0, bn, transpose1], sorted_ids[-1]
......@@ -15,6 +15,7 @@
from x2paddle.optimizer.pass_manager import PassManager
from x2paddle.optimizer.fusion.dygraph import *
from x2paddle.optimizer.fusion.static import *
from x2paddle.optimizer.elimination.dygraph import *
class GraphOptimizer(object):
def __init__(self, source_frame, paddle_type="dygraph"):
......@@ -30,6 +31,12 @@ class GraphOptimizer(object):
self.passes = ["dygraph_bn_scale_fuse_pass"]
else:
self.passes = ["static_bn_scale_fuse_pass"]
elif source_frame == "tf":
self.passes = [
"dygraph_conv2d_add_fuse_pass",
"dygraph_tf_batchnorm_fuse_pass",
"transpose_eliminate_pass"
]
else:
# TODO
pass
......@@ -37,11 +44,14 @@ class GraphOptimizer(object):
def optimize(self, graph):
for pass_name in self.passes:
pass_ = PassManager.lookup(pass_name)()
while True:
before_len = len(graph.layers)
if pass_name.endswith("_eliminate_pass"):
pass_.apply(graph)
after_len = len(graph.layers)
if before_len == after_len:
break
else:
while True:
before_len = len(graph.layers)
pass_.apply(graph)
after_len = len(graph.layers)
if before_len == after_len:
break
print("{} done!".format(pass_name))
return graph
......@@ -19,6 +19,7 @@ class PatternMatcher(object):
def __init__(self, pattern):
self.pattern = pattern
# matches的每个match是按照拓扑排序组成layer的dict
self.matches = list()
def operate(self, graph, match_kind="topo"):
......@@ -154,7 +155,7 @@ class PatternMatcher(object):
if len(block.layers) > 0:
self.detect_patterns_by_topo(layer.blocks[j])
def detect_patterns_by_edge(self, graph, ignore_list_inputs=True):
def detect_patterns_by_edge(self, graph):
"""当遇见顺序没有强制规定的pattern时使用该方式
"""
......@@ -163,8 +164,8 @@ class PatternMatcher(object):
pattern_ids = list(pattern_id2layers.keys())
pattern_layer_id = pattern_ids[0]
subgraph_id2layers = dict()
graph_layers = dict(list(graph.layers.items())[start_index:])
layer_id = list(graph_layers.keys())[0]
layer_id = list(graph.layers.keys())[start_index]
graph_layers = graph.layers
def update(layer_id, pattern_layer_id):
layer = graph_layers[layer_id]
......@@ -172,14 +173,25 @@ class PatternMatcher(object):
if layer.kernel != pattern_layer.kernel:
return False
subgraph_id2layers[layer_id] = layer
for i, pattern_layer_id_in in enumerate(pattern.edges_in[
pattern_layer_id]):
if pattern_layer_id_in == -1 or ignore_list_inputs:
continue
layer_id_in = graph.edges_in[layer_id][i]
subgraph_ids = list(subgraph_id2layers.keys())
if layer_id_in not in subgraph_ids:
# for k, v in subgraph_id2layers.items():
# print(k)
# print(v.kernel)
# print(v.outputs)
# print("=========")
if pattern.edges_in.get(pattern_layer_id, 0) != 0:
if len(pattern.edges_in[pattern_layer_id]) != \
len(graph.edges_in[layer_id]):
return False
for i, pattern_layer_id_in in enumerate(pattern.edges_in[
pattern_layer_id]):
if pattern_layer_id_in == -1:
continue
if pattern_layer_id_in in pattern_ids:
new_layer_id_in = graph.edges_in[layer_id][i]
if new_layer_id_in in subgraph_id2layers:
continue
update(new_layer_id_in, pattern_layer_id_in)
if pattern.edges_out.get(pattern_layer_id, 0) != 0:
if len(pattern.edges_out[pattern_layer_id]) != \
len(graph.edges_out[layer_id]):
......@@ -188,17 +200,8 @@ class PatternMatcher(object):
pattern_layer_id]):
if pattern_layer_id_out in pattern_ids:
new_layer_id_out = graph.edges_out[layer_id][i]
for j, new_new_layer_id_in in enumerate(
graph.edges_in[new_layer_id_out]):
if new_new_layer_id_in not in subgraph_id2layers:
if ignore_list_inputs:
continue
new_new_pattern_layer_id_in = pattern.edges_in[
pattern_layer_id_out][j]
if new_new_pattern_layer_id_in == -1:
continue
update(new_new_layer_id_in,
new_new_pattern_layer_id_in)
if new_layer_id_out in subgraph_id2layers:
continue
update(new_layer_id_out, pattern_layer_id_out)
while len(subgraph_id2layers) != len(pattern_id2layers):
......@@ -258,6 +261,7 @@ def get_subgraph(prefix_layer_id, suffix_layer_id, graph):
class FuseBase(object):
def __init__(self, graph_type):
self.pattern = PaddleGraph(graph_type=graph_type)
self.patterns = list()
def operate(self, graph, match_kind="topo"):
parameters = graph.parameters
......@@ -267,16 +271,22 @@ class FuseBase(object):
first_layer_id = list(match.keys())[0]
subgraph = get_subgraph("", first_layer_id, graph)
self.insert_new_layer(subgraph, parameters, match)
self.delete_inter_layer(graph)
self.delete_layer(graph)
graph.build()
def perform_pattern_matcher(self, graph, match_kind="topo"):
""" 执行模式匹配,找到匹配的子图。
"""
pattern_matcher = PatternMatcher(self.pattern)
self.matches = pattern_matcher.operate(graph, match_kind)
if len(self.patterns) > 0:
self.matches = list()
for pattern in self.patterns:
pattern_matcher = PatternMatcher(pattern)
self.matches.extend(pattern_matcher.operate(graph, match_kind))
else:
pattern_matcher = PatternMatcher(self.pattern)
self.matches = pattern_matcher.operate(graph, match_kind)
def delete_inter_layer(self, graph):
def delete_layer(self, graph):
""" 删除不需要的中间layer及其对应参数。
"""
for match in self.matches:
......@@ -291,3 +301,52 @@ class FuseBase(object):
if layer_id in subgraph.layers:
# layer_id可能是属于子图的,此时删除父layer,即删除整个子图
subgraph.layers.pop(layer_id)
def delete_layer_with_associated(self, graph, layer_id):
    """ Delete an unneeded intermediate layer together with its connections.

    The layer may have at most one producer. Without a producer, every
    consumer loses its edge to the layer and the input entries that
    referenced the layer's first output. With a producer, the consumers are
    re-attached to that producer and their matching input entries are
    renamed to the deleted layer's first input.

    Args:
        graph: object with ``layers``, ``edges_in`` and ``edges_out`` mappings.
        layer_id: id of the layer to remove.
    """
    target = graph.layers[layer_id]
    consumers = graph.edges_out.get(layer_id, [])
    producers = graph.edges_in.get(layer_id, [])

    assert len(
        producers) <= 1, "There should be 0 or 1 input for deleted layer."

    if not producers:
        for consumer_id in consumers:
            # Drop every edge that pointed back at the deleted layer.
            edges = graph.edges_in[consumer_id]
            edges[:] = [e for e in edges if e != layer_id]

            consumer_inputs = graph.layers[consumer_id].inputs
            stale_keys = [
                key for key in list(consumer_inputs.keys())
                if consumer_inputs[key] == target.outputs[0]
            ]
            for key in stale_keys:
                del consumer_inputs[key]
    else:
        parent_id = producers[0]
        # Replace the deleted layer by its producer in every consumer's
        # incoming edge list.
        for consumer_id in consumers:
            edges = graph.edges_in[consumer_id]
            edges[:] = [parent_id if e == layer_id else e for e in edges]

        # Hand the consumers over to the producer, keeping the position the
        # deleted layer had among the producer's outgoing edges.
        siblings = graph.edges_out[parent_id]
        slot = siblings.index(layer_id)
        del siblings[slot]
        for offset, consumer_id in enumerate(consumers):
            siblings.insert(slot + offset, consumer_id)
            consumer_inputs = graph.layers[consumer_id].inputs
            for key, value in consumer_inputs.items():
                if value == target.outputs[0]:
                    consumer_inputs[key] = list(target.inputs.values())[0]

    del graph.layers[layer_id]
    graph.edges_out.pop(layer_id, None)
    graph.edges_in.pop(layer_id, None)
......@@ -178,3 +178,4 @@ class BatchNormOpt:
graph.layers[bn.id] = bn
graph.layers[transpose1.id] = transpose1
graph.build()
print("=============")
......@@ -6,11 +6,6 @@ class BiasOpt:
self.conv_layers = [
'fluid.layers.conv2d', 'fluid.layers.conv2d_transpose'
]
self.act_layers = [
'fluid.layers.relu', 'fluid.layers.relu6', 'fluid.layers.sigmoid',
'fluid.layers.exp', 'fluid.layers.tanh', 'fluid.layers.softplus',
'fluid.layers.leaky_relu'
]
def run(self, graph):
print("Optimize: BiasOpt...")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册