提交 3786c539 编写于 作者: S SunAhong1993

fix the transepose

上级 7626cd8a
......@@ -144,9 +144,9 @@ def tf2paddle(model_path,
bias_opt = BiasOpt()
transpose_opt = TransposeOpt()
batch_norm_opt = BatchNormOpt()
bias_opt.run(program)
batch_norm_opt.run(program)
transpose_opt.run(program)
bias_opt.run(mapper.paddle_graph)
batch_norm_opt.run(mapper.paddle_graph)
transpose_opt.run(mapper.paddle_graph)
mapper.paddle_graph.gen_model(save_dir)
......
......@@ -176,11 +176,12 @@ class TFOpMapper(OpMapper):
x_shape = x.out_shapes[0]
y_shape = y.out_shapes[0]
self.paddle_graph.add_layer(
layer_id = self.paddle_graph.add_layer(
kernel=op_type,
inputs={"x": x.name,
"y": y.name},
outputs=[node.name])
self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape}
def NotEqual(self, node):
x = self.graph.get_node(node.layer.input[0])
......@@ -1241,13 +1242,15 @@ class TFOpMapper(OpMapper):
x_shape = x.out_shapes[0]
y_shape = y.out_shapes[0]
layer_id = self.paddle_graph.add_layer(
"paddle.fluid.layers.elementwise_sub", inputs=inputs, outputs=[node.name])
"fluid.layers.elementwise_sub", inputs=inputs, outputs=[node.name])
self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape}
inputs = {"x": node.name, "y": node.name}
x_shape = node.out_shapes[0]
y_shape = node.out_shapes[0]
layer_id = self.paddle_graph.add_layer(
"paddle.multiply", inputs=inputs, outputs=[node.name])
self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape}
def OneHot(self, node):
input = self.graph.get_node(node.layer.input[0])
......
......@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .transpose_elimination import Dygraph_TransposeElimination
from .transpose_eliminate_pass import Dygraph_TransposeEliminatePass
\ No newline at end of file
from .transpose_elimination import DygraphTransposeElimination
from .transpose_eliminate_pass import DygraphTransposeEliminatePass
\ No newline at end of file
......@@ -13,21 +13,21 @@
# limitations under the License.
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.elimination.dygraph import Dygraph_TransposeElimination
from x2paddle.optimizer.elimination.dygraph import DygraphTransposeElimination
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class Dygraph_TransposeEliminatePass(Pass):
class DygraphTransposeEliminatePass(Pass):
name = "transpose_eliminate_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = Dygraph_TransposeElimination()
fuser = DygraphTransposeElimination()
fuser.operate(graph)
# 用于注册
transpose_eliminate_pass = Dygraph_TransposeEliminatePass()
\ No newline at end of file
transpose_eliminate_pass = DygraphTransposeEliminatePass()
\ No newline at end of file
......@@ -13,15 +13,16 @@
# limitations under the License.
import copy
import sys
import numpy as np
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class Dygraph_TransposeElimination(FuseBase):
class DygraphTransposeElimination(FuseBase):
def __init__(self):
super(Dygraph_TransposeElimination, self).__init__(graph_type="dygraph")
super(DygraphTransposeElimination, self).__init__(graph_type="dygraph")
self.direct_layers = [
'paddle.nn.ReLU', 'paddle.nn.ReLU6', 'paddle.abs',
'paddle.nn.Sigmoid', 'paddle.exp', 'paddle.rsqrt',
......@@ -54,6 +55,12 @@ class Dygraph_TransposeElimination(FuseBase):
optimized_concat_layers = list()
optimized_elementwise_layers = list()
def get_index(layer):
if layer.kernel.startswith("paddle.nn") and "functional" not in layer.kernel:
return 1
else:
return 0
def strip_transpose(_graph):
layers = copy.deepcopy(_graph.layers)
for layer_id, layer in layers.items():
......@@ -61,7 +68,7 @@ class Dygraph_TransposeElimination(FuseBase):
continue
scanned_layers.add(layer_id)
percent = round(len(scanned_layers) / total_layer_num * 100, 2)
print("\rOptimize Transpose Layers...{}%".format(
sys.stderr.write("\rOptimize Transpose Layers...{}%".format(
percent))
if layer.kernel != "paddle.transpose":
......@@ -84,13 +91,14 @@ class Dygraph_TransposeElimination(FuseBase):
elif _graph.layers[out].kernel in self.elementwise_layers:
propagate_layers.append(out)
elif _graph.layers[out].kernel in self.direct_layers:
ouput_index = 1 if _graph.layers[out].kernel.startswith("paddle.nn.") else 0
ouput_index = get_index(_graph.layers[out])
if _graph.layers[out].outputs[ouput_index] in _graph.outputs:
can_be_optimized = False
break
propagate_layers.append(out)
elif _graph.layers[out].kernel in self.reduce_layers:
if _graph.layers[out].outputs[0] in _graph.outputs:
ouput_index = get_index(_graph.layers[out])
if _graph.layers[out].outputs[ouput_index] in _graph.outputs:
can_be_optimized = False
break
if _graph.layers[out].attrs.get('keepdim', False):
......@@ -99,7 +107,8 @@ class Dygraph_TransposeElimination(FuseBase):
propagate_layers.append(out)
reduce_layers.append(out)
elif _graph.layers[out].kernel == "paddle.concat":
if _graph.layers[out].outputs[0] in _graph.outputs:
ouput_index = get_index(_graph.layers[out])
if _graph.layers[out].outputs[ouput_index] in _graph.outputs:
can_be_optimized = False
break
propagate_layers.append(out)
......@@ -121,20 +130,22 @@ class Dygraph_TransposeElimination(FuseBase):
transpose_layers.append(out)
elif _graph.layers[
out].kernel in self.elementwise_layers:
if _graph.layers[out].outputs[0] in _graph.outputs:
output_index = get_index(_graph.layers[out])
if _graph.layers[out].outputs[output_index] in _graph.outputs:
can_be_optimized = False
break
if out not in visited_layers:
propagate_layers.append(out)
elif _graph.layers[out].kernel in self.direct_layers:
ouput_index = 1 if _graph.layers[out].kernel.startswith("paddle.nn.") else 0
if _graph.layers[out].outputs[ouput_index] in _graph.outputs:
output_index = get_index(_graph.layers[out])
if _graph.layers[out].outputs[output_index] in _graph.outputs:
can_be_optimized = False
break
if out not in visited_layers:
propagate_layers.append(out)
elif _graph.layers[out].kernel in self.reduce_layers:
if _graph.layers[out].outputs[0] in _graph.outputs:
output_index = get_index(_graph.layers[out])
if _graph.layers[out].outputs[output_index] in _graph.outputs:
can_be_optimized = False
break
if _graph.layers[out].attrs.get('keepdim',
......@@ -145,7 +156,8 @@ class Dygraph_TransposeElimination(FuseBase):
propagate_layers.append(out)
reduce_layers.append(out)
elif _graph.layers[out].kernel == "paddle.concat":
if _graph.layers[out].outputs[0] in _graph.outputs:
output_index = get_index(_graph.layers[out])
if _graph.layers[out].outputs[output_index] in _graph.outputs:
can_be_optimized = False
break
if out not in visited_layers:
......@@ -162,14 +174,15 @@ class Dygraph_TransposeElimination(FuseBase):
current_id].input_shapes['x']
y_shape = _graph.layers[
current_id].input_shapes['y']
output_index = get_index(_graph.layers[ipt])
if _graph.layers[ipt].outputs[
0] == _graph.layers[current_id].inputs[
output_index] == _graph.layers[current_id].inputs[
'x']:
if len(x_shape) <= 1:
elementwise_layers.append(current_id)
continue
elif _graph.layers[ipt].outputs[
0] == _graph.layers[current_id].inputs[
output_index] == _graph.layers[current_id].inputs[
'y']:
if len(y_shape) <= 1:
elementwise_layers.append(current_id)
......@@ -181,6 +194,7 @@ class Dygraph_TransposeElimination(FuseBase):
except Exception as e:
can_be_optimized = False
break
output_index = get_index(_graph.layers[ipt])
if _graph.layers[
ipt].kernel == "paddle.transpose":
if _graph.layers[ipt].attrs["perm"] != [0, 2, 3, 1]:
......@@ -190,20 +204,19 @@ class Dygraph_TransposeElimination(FuseBase):
transpose_layers.append(ipt)
elif _graph.layers[
ipt].kernel in self.elementwise_layers:
if _graph.layers[ipt].outputs[0] in _graph.outputs:
if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
can_be_optimized = False
break
if ipt not in visited_layers:
propagate_layers.append(ipt)
elif _graph.layers[ipt].kernel in self.direct_layers:
ouput_index = 1 if _graph.layers[ipt].kernel.startswith("paddle.nn.") else 0
if _graph.layers[ipt].outputs[ouput_index] in _graph.outputs:
if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
can_be_optimized = False
break
if ipt not in visited_layers:
propagate_layers.append(ipt)
elif _graph.layers[ipt].kernel in self.reduce_layers:
if _graph.layers[ipt].outputs[0] in _graph.outputs:
if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
can_be_optimized = False
break
if _graph.layers[ipt].attrs.get('keepdim',
......@@ -214,7 +227,7 @@ class Dygraph_TransposeElimination(FuseBase):
propagate_layers.append(ipt)
reduce_layers.append(ipt)
elif _graph.layers[ipt].kernel == "paddle.concat":
if _graph.layers[ipt].outputs[0] in _graph.outputs:
if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
can_be_optimized = False
break
if ipt not in visited_layers:
......@@ -231,7 +244,8 @@ class Dygraph_TransposeElimination(FuseBase):
transpose_layers.append(layer_id)
transpose_layers = list(set(transpose_layers))
for l in transpose_layers:
if graph.layers[l].outputs[0] in graph.outputs:
output_index = get_index(graph.layers[l])
if graph.layers[l].outputs[output_index] in graph.outputs:
can_be_optimized = False
break
if not can_be_optimized:
......@@ -254,6 +268,7 @@ class Dygraph_TransposeElimination(FuseBase):
while strip_transpose(opt_graph):
pass
for layer_id in list(set(optimized_transpose_layers)):
self.delete_layer_with_associated(graph, layer_id)
for layer_id in list(set(optimized_reduce_layers)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册