From 4d7134cc153911d886520fd99baeb76bb6120b76 Mon Sep 17 00:00:00 2001 From: SunAhong1993 Date: Thu, 17 Dec 2020 19:19:10 +0800 Subject: [PATCH] fix the tf --- x2paddle/decoder/tf_decoder.py | 2 +- .../elimination/dygraph/transpose_elimination.py | 9 ++------- .../elimination/static/transpose_elimination.py | 9 ++------- 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/x2paddle/decoder/tf_decoder.py b/x2paddle/decoder/tf_decoder.py index eb687ab..c8380bd 100644 --- a/x2paddle/decoder/tf_decoder.py +++ b/x2paddle/decoder/tf_decoder.py @@ -172,7 +172,7 @@ class TFGraph(Graph): self._remove_isolated_node() self._optimize_dialiation_conv() self._remove_identity_node() -# self._remove_cast_node() + self._remove_cast_node() def get_node(self, node_name, copy=False): diff --git a/x2paddle/optimizer/elimination/dygraph/transpose_elimination.py b/x2paddle/optimizer/elimination/dygraph/transpose_elimination.py index 22c59d7..4e27e4d 100644 --- a/x2paddle/optimizer/elimination/dygraph/transpose_elimination.py +++ b/x2paddle/optimizer/elimination/dygraph/transpose_elimination.py @@ -178,13 +178,13 @@ class DygraphTransposeElimination(FuseBase): if _graph.layers[ipt].outputs[ output_index] == _graph.layers[current_id].inputs[ 'x']: - if len(x_shape) <= 1: + if list(x_shape)==[1] or len(x_shape) < 1: elementwise_layers.append(current_id) continue elif _graph.layers[ipt].outputs[ output_index] == _graph.layers[current_id].inputs[ 'y']: - if len(y_shape) <= 1: + if list(y_shape)==[1] or len(y_shape) < 1: elementwise_layers.append(current_id) continue else: @@ -279,11 +279,6 @@ class DygraphTransposeElimination(FuseBase): for layer_id in list(set(optimized_concat_layers)): axis = graph.layers[layer_id].attrs.get('axis', 0) graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis] - for layer_id in list(set(optimized_elementwise_layers)): - axis = graph.layers[layer_id].attrs.get('axis', -1) - graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis] - if graph.layers[layer_id].kernel == "paddle.add": - graph.layers[layer_id].kernel = "fluid.layers.elementwise_add" current_transpose_num = self.get_transpose_num(graph) print( diff --git a/x2paddle/optimizer/elimination/static/transpose_elimination.py b/x2paddle/optimizer/elimination/static/transpose_elimination.py index 66819f7..ee68fbd 100644 --- a/x2paddle/optimizer/elimination/static/transpose_elimination.py +++ b/x2paddle/optimizer/elimination/static/transpose_elimination.py @@ -178,13 +178,13 @@ class StaticTransposeElimination(FuseBase): if _graph.layers[ipt].outputs[ output_index] == _graph.layers[current_id].inputs[ 'x']: - if len(x_shape) <= 1: + if list(x_shape)==[1] or len(x_shape) < 1: elementwise_layers.append(current_id) continue elif _graph.layers[ipt].outputs[ output_index] == _graph.layers[current_id].inputs[ 'y']: - if len(y_shape) <= 1: + if list(y_shape)==[1] or len(y_shape) < 1: elementwise_layers.append(current_id) continue else: @@ -279,11 +279,6 @@ class StaticTransposeElimination(FuseBase): for layer_id in list(set(optimized_concat_layers)): axis = graph.layers[layer_id].attrs.get('axis', 0) graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis] - for layer_id in list(set(optimized_elementwise_layers)): - axis = graph.layers[layer_id].attrs.get('axis', -1) - graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis] - if graph.layers[layer_id].kernel == "paddle.add": - graph.layers[layer_id].kernel = "fluid.layers.elementwise_add" current_transpose_num = self.get_transpose_num(graph) print( -- GitLab