提交 3786c539 编写于 作者: S SunAhong1993

fix the transpose

上级 7626cd8a
...@@ -144,9 +144,9 @@ def tf2paddle(model_path, ...@@ -144,9 +144,9 @@ def tf2paddle(model_path,
bias_opt = BiasOpt() bias_opt = BiasOpt()
transpose_opt = TransposeOpt() transpose_opt = TransposeOpt()
batch_norm_opt = BatchNormOpt() batch_norm_opt = BatchNormOpt()
bias_opt.run(program) bias_opt.run(mapper.paddle_graph)
batch_norm_opt.run(program) batch_norm_opt.run(mapper.paddle_graph)
transpose_opt.run(program) transpose_opt.run(mapper.paddle_graph)
mapper.paddle_graph.gen_model(save_dir) mapper.paddle_graph.gen_model(save_dir)
......
...@@ -176,11 +176,12 @@ class TFOpMapper(OpMapper): ...@@ -176,11 +176,12 @@ class TFOpMapper(OpMapper):
x_shape = x.out_shapes[0] x_shape = x.out_shapes[0]
y_shape = y.out_shapes[0] y_shape = y.out_shapes[0]
self.paddle_graph.add_layer( layer_id = self.paddle_graph.add_layer(
kernel=op_type, kernel=op_type,
inputs={"x": x.name, inputs={"x": x.name,
"y": y.name}, "y": y.name},
outputs=[node.name]) outputs=[node.name])
self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape}
def NotEqual(self, node): def NotEqual(self, node):
x = self.graph.get_node(node.layer.input[0]) x = self.graph.get_node(node.layer.input[0])
...@@ -1241,13 +1242,15 @@ class TFOpMapper(OpMapper): ...@@ -1241,13 +1242,15 @@ class TFOpMapper(OpMapper):
x_shape = x.out_shapes[0] x_shape = x.out_shapes[0]
y_shape = y.out_shapes[0] y_shape = y.out_shapes[0]
layer_id = self.paddle_graph.add_layer( layer_id = self.paddle_graph.add_layer(
"paddle.fluid.layers.elementwise_sub", inputs=inputs, outputs=[node.name]) "fluid.layers.elementwise_sub", inputs=inputs, outputs=[node.name])
self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape}
inputs = {"x": node.name, "y": node.name} inputs = {"x": node.name, "y": node.name}
x_shape = node.out_shapes[0] x_shape = node.out_shapes[0]
y_shape = node.out_shapes[0] y_shape = node.out_shapes[0]
layer_id = self.paddle_graph.add_layer( layer_id = self.paddle_graph.add_layer(
"paddle.multiply", inputs=inputs, outputs=[node.name]) "paddle.multiply", inputs=inputs, outputs=[node.name])
self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape}
def OneHot(self, node): def OneHot(self, node):
input = self.graph.get_node(node.layer.input[0]) input = self.graph.get_node(node.layer.input[0])
......
...@@ -12,5 +12,5 @@ ...@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .transpose_elimination import Dygraph_TransposeElimination from .transpose_elimination import DygraphTransposeElimination
from .transpose_eliminate_pass import Dygraph_TransposeEliminatePass from .transpose_eliminate_pass import DygraphTransposeEliminatePass
\ No newline at end of file \ No newline at end of file
...@@ -13,21 +13,21 @@ ...@@ -13,21 +13,21 @@
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.elimination.dygraph import Dygraph_TransposeElimination from x2paddle.optimizer.elimination.dygraph import DygraphTransposeElimination
from x2paddle.optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class Dygraph_TransposeEliminatePass(Pass): class DygraphTransposeEliminatePass(Pass):
name = "transpose_eliminate_pass" name = "transpose_eliminate_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = Dygraph_TransposeElimination() fuser = DygraphTransposeElimination()
fuser.operate(graph) fuser.operate(graph)
# 用于注册 # 用于注册
transpose_eliminate_pass = Dygraph_TransposeEliminatePass() transpose_eliminate_pass = DygraphTransposeEliminatePass()
\ No newline at end of file \ No newline at end of file
...@@ -13,15 +13,16 @@ ...@@ -13,15 +13,16 @@
# limitations under the License. # limitations under the License.
import copy import copy
import sys
import numpy as np import numpy as np
from x2paddle.optimizer.pattern_matcher import FuseBase from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class Dygraph_TransposeElimination(FuseBase): class DygraphTransposeElimination(FuseBase):
def __init__(self): def __init__(self):
super(Dygraph_TransposeElimination, self).__init__(graph_type="dygraph") super(DygraphTransposeElimination, self).__init__(graph_type="dygraph")
self.direct_layers = [ self.direct_layers = [
'paddle.nn.ReLU', 'paddle.nn.ReLU6', 'paddle.abs', 'paddle.nn.ReLU', 'paddle.nn.ReLU6', 'paddle.abs',
'paddle.nn.Sigmoid', 'paddle.exp', 'paddle.rsqrt', 'paddle.nn.Sigmoid', 'paddle.exp', 'paddle.rsqrt',
...@@ -53,6 +54,12 @@ class Dygraph_TransposeElimination(FuseBase): ...@@ -53,6 +54,12 @@ class Dygraph_TransposeElimination(FuseBase):
optimized_reduce_layers = list() optimized_reduce_layers = list()
optimized_concat_layers = list() optimized_concat_layers = list()
optimized_elementwise_layers = list() optimized_elementwise_layers = list()
def get_index(layer):
    """Pick which entry of ``layer.outputs`` the caller should inspect.

    Args:
        layer: an object exposing a string ``kernel`` attribute.

    Returns:
        1 when ``layer.kernel`` starts with "paddle.nn" and does not
        contain "functional"; 0 for every other kernel.
    """
    # NOTE(review): presumably module-style paddle.nn layers carry an extra
    # leading entry in ``outputs`` ahead of the tensor name -- confirm
    # against the code that builds those layers.
    is_nn_module = (layer.kernel.startswith("paddle.nn")
                    and "functional" not in layer.kernel)
    return 1 if is_nn_module else 0
def strip_transpose(_graph): def strip_transpose(_graph):
layers = copy.deepcopy(_graph.layers) layers = copy.deepcopy(_graph.layers)
...@@ -61,7 +68,7 @@ class Dygraph_TransposeElimination(FuseBase): ...@@ -61,7 +68,7 @@ class Dygraph_TransposeElimination(FuseBase):
continue continue
scanned_layers.add(layer_id) scanned_layers.add(layer_id)
percent = round(len(scanned_layers) / total_layer_num * 100, 2) percent = round(len(scanned_layers) / total_layer_num * 100, 2)
print("\rOptimize Transpose Layers...{}%".format( sys.stderr.write("\rOptimize Transpose Layers...{}%".format(
percent)) percent))
if layer.kernel != "paddle.transpose": if layer.kernel != "paddle.transpose":
...@@ -84,13 +91,14 @@ class Dygraph_TransposeElimination(FuseBase): ...@@ -84,13 +91,14 @@ class Dygraph_TransposeElimination(FuseBase):
elif _graph.layers[out].kernel in self.elementwise_layers: elif _graph.layers[out].kernel in self.elementwise_layers:
propagate_layers.append(out) propagate_layers.append(out)
elif _graph.layers[out].kernel in self.direct_layers: elif _graph.layers[out].kernel in self.direct_layers:
ouput_index = 1 if _graph.layers[out].kernel.startswith("paddle.nn.") else 0 ouput_index = get_index(_graph.layers[out])
if _graph.layers[out].outputs[ouput_index] in _graph.outputs: if _graph.layers[out].outputs[ouput_index] in _graph.outputs:
can_be_optimized = False can_be_optimized = False
break break
propagate_layers.append(out) propagate_layers.append(out)
elif _graph.layers[out].kernel in self.reduce_layers: elif _graph.layers[out].kernel in self.reduce_layers:
if _graph.layers[out].outputs[0] in _graph.outputs: ouput_index = get_index(_graph.layers[out])
if _graph.layers[out].outputs[ouput_index] in _graph.outputs:
can_be_optimized = False can_be_optimized = False
break break
if _graph.layers[out].attrs.get('keepdim', False): if _graph.layers[out].attrs.get('keepdim', False):
...@@ -99,7 +107,8 @@ class Dygraph_TransposeElimination(FuseBase): ...@@ -99,7 +107,8 @@ class Dygraph_TransposeElimination(FuseBase):
propagate_layers.append(out) propagate_layers.append(out)
reduce_layers.append(out) reduce_layers.append(out)
elif _graph.layers[out].kernel == "paddle.concat": elif _graph.layers[out].kernel == "paddle.concat":
if _graph.layers[out].outputs[0] in _graph.outputs: ouput_index = get_index(_graph.layers[out])
if _graph.layers[out].outputs[ouput_index] in _graph.outputs:
can_be_optimized = False can_be_optimized = False
break break
propagate_layers.append(out) propagate_layers.append(out)
...@@ -121,20 +130,22 @@ class Dygraph_TransposeElimination(FuseBase): ...@@ -121,20 +130,22 @@ class Dygraph_TransposeElimination(FuseBase):
transpose_layers.append(out) transpose_layers.append(out)
elif _graph.layers[ elif _graph.layers[
out].kernel in self.elementwise_layers: out].kernel in self.elementwise_layers:
if _graph.layers[out].outputs[0] in _graph.outputs: output_index = get_index(_graph.layers[out])
if _graph.layers[out].outputs[output_index] in _graph.outputs:
can_be_optimized = False can_be_optimized = False
break break
if out not in visited_layers: if out not in visited_layers:
propagate_layers.append(out) propagate_layers.append(out)
elif _graph.layers[out].kernel in self.direct_layers: elif _graph.layers[out].kernel in self.direct_layers:
ouput_index = 1 if _graph.layers[out].kernel.startswith("paddle.nn.") else 0 output_index = get_index(_graph.layers[out])
if _graph.layers[out].outputs[ouput_index] in _graph.outputs: if _graph.layers[out].outputs[output_index] in _graph.outputs:
can_be_optimized = False can_be_optimized = False
break break
if out not in visited_layers: if out not in visited_layers:
propagate_layers.append(out) propagate_layers.append(out)
elif _graph.layers[out].kernel in self.reduce_layers: elif _graph.layers[out].kernel in self.reduce_layers:
if _graph.layers[out].outputs[0] in _graph.outputs: output_index = get_index(_graph.layers[out])
if _graph.layers[out].outputs[output_index] in _graph.outputs:
can_be_optimized = False can_be_optimized = False
break break
if _graph.layers[out].attrs.get('keepdim', if _graph.layers[out].attrs.get('keepdim',
...@@ -145,7 +156,8 @@ class Dygraph_TransposeElimination(FuseBase): ...@@ -145,7 +156,8 @@ class Dygraph_TransposeElimination(FuseBase):
propagate_layers.append(out) propagate_layers.append(out)
reduce_layers.append(out) reduce_layers.append(out)
elif _graph.layers[out].kernel == "paddle.concat": elif _graph.layers[out].kernel == "paddle.concat":
if _graph.layers[out].outputs[0] in _graph.outputs: output_index = get_index(_graph.layers[out])
if _graph.layers[out].outputs[output_index] in _graph.outputs:
can_be_optimized = False can_be_optimized = False
break break
if out not in visited_layers: if out not in visited_layers:
...@@ -162,14 +174,15 @@ class Dygraph_TransposeElimination(FuseBase): ...@@ -162,14 +174,15 @@ class Dygraph_TransposeElimination(FuseBase):
current_id].input_shapes['x'] current_id].input_shapes['x']
y_shape = _graph.layers[ y_shape = _graph.layers[
current_id].input_shapes['y'] current_id].input_shapes['y']
output_index = get_index(_graph.layers[ipt])
if _graph.layers[ipt].outputs[ if _graph.layers[ipt].outputs[
0] == _graph.layers[current_id].inputs[ output_index] == _graph.layers[current_id].inputs[
'x']: 'x']:
if len(x_shape) <= 1: if len(x_shape) <= 1:
elementwise_layers.append(current_id) elementwise_layers.append(current_id)
continue continue
elif _graph.layers[ipt].outputs[ elif _graph.layers[ipt].outputs[
0] == _graph.layers[current_id].inputs[ output_index] == _graph.layers[current_id].inputs[
'y']: 'y']:
if len(y_shape) <= 1: if len(y_shape) <= 1:
elementwise_layers.append(current_id) elementwise_layers.append(current_id)
...@@ -181,6 +194,7 @@ class Dygraph_TransposeElimination(FuseBase): ...@@ -181,6 +194,7 @@ class Dygraph_TransposeElimination(FuseBase):
except Exception as e: except Exception as e:
can_be_optimized = False can_be_optimized = False
break break
output_index = get_index(_graph.layers[ipt])
if _graph.layers[ if _graph.layers[
ipt].kernel == "paddle.transpose": ipt].kernel == "paddle.transpose":
if _graph.layers[ipt].attrs["perm"] != [0, 2, 3, 1]: if _graph.layers[ipt].attrs["perm"] != [0, 2, 3, 1]:
...@@ -190,20 +204,19 @@ class Dygraph_TransposeElimination(FuseBase): ...@@ -190,20 +204,19 @@ class Dygraph_TransposeElimination(FuseBase):
transpose_layers.append(ipt) transpose_layers.append(ipt)
elif _graph.layers[ elif _graph.layers[
ipt].kernel in self.elementwise_layers: ipt].kernel in self.elementwise_layers:
if _graph.layers[ipt].outputs[0] in _graph.outputs: if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
can_be_optimized = False can_be_optimized = False
break break
if ipt not in visited_layers: if ipt not in visited_layers:
propagate_layers.append(ipt) propagate_layers.append(ipt)
elif _graph.layers[ipt].kernel in self.direct_layers: elif _graph.layers[ipt].kernel in self.direct_layers:
ouput_index = 1 if _graph.layers[ipt].kernel.startswith("paddle.nn.") else 0 if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
if _graph.layers[ipt].outputs[ouput_index] in _graph.outputs:
can_be_optimized = False can_be_optimized = False
break break
if ipt not in visited_layers: if ipt not in visited_layers:
propagate_layers.append(ipt) propagate_layers.append(ipt)
elif _graph.layers[ipt].kernel in self.reduce_layers: elif _graph.layers[ipt].kernel in self.reduce_layers:
if _graph.layers[ipt].outputs[0] in _graph.outputs: if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
can_be_optimized = False can_be_optimized = False
break break
if _graph.layers[ipt].attrs.get('keepdim', if _graph.layers[ipt].attrs.get('keepdim',
...@@ -214,7 +227,7 @@ class Dygraph_TransposeElimination(FuseBase): ...@@ -214,7 +227,7 @@ class Dygraph_TransposeElimination(FuseBase):
propagate_layers.append(ipt) propagate_layers.append(ipt)
reduce_layers.append(ipt) reduce_layers.append(ipt)
elif _graph.layers[ipt].kernel == "paddle.concat": elif _graph.layers[ipt].kernel == "paddle.concat":
if _graph.layers[ipt].outputs[0] in _graph.outputs: if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
can_be_optimized = False can_be_optimized = False
break break
if ipt not in visited_layers: if ipt not in visited_layers:
...@@ -231,7 +244,8 @@ class Dygraph_TransposeElimination(FuseBase): ...@@ -231,7 +244,8 @@ class Dygraph_TransposeElimination(FuseBase):
transpose_layers.append(layer_id) transpose_layers.append(layer_id)
transpose_layers = list(set(transpose_layers)) transpose_layers = list(set(transpose_layers))
for l in transpose_layers: for l in transpose_layers:
if graph.layers[l].outputs[0] in graph.outputs: output_index = get_index(graph.layers[l])
if graph.layers[l].outputs[output_index] in graph.outputs:
can_be_optimized = False can_be_optimized = False
break break
if not can_be_optimized: if not can_be_optimized:
...@@ -254,6 +268,7 @@ class Dygraph_TransposeElimination(FuseBase): ...@@ -254,6 +268,7 @@ class Dygraph_TransposeElimination(FuseBase):
while strip_transpose(opt_graph): while strip_transpose(opt_graph):
pass pass
for layer_id in list(set(optimized_transpose_layers)): for layer_id in list(set(optimized_transpose_layers)):
self.delete_layer_with_associated(graph, layer_id) self.delete_layer_with_associated(graph, layer_id)
for layer_id in list(set(optimized_reduce_layers)): for layer_id in list(set(optimized_reduce_layers)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册