提交 4d7134cc 编写于 作者: S SunAhong1993

fix the tf

上级 c48cf6d8
...@@ -172,7 +172,7 @@ class TFGraph(Graph): ...@@ -172,7 +172,7 @@ class TFGraph(Graph):
self._remove_isolated_node() self._remove_isolated_node()
self._optimize_dialiation_conv() self._optimize_dialiation_conv()
self._remove_identity_node() self._remove_identity_node()
# self._remove_cast_node() self._remove_cast_node()
def get_node(self, node_name, copy=False): def get_node(self, node_name, copy=False):
......
...@@ -178,13 +178,13 @@ class DygraphTransposeElimination(FuseBase): ...@@ -178,13 +178,13 @@ class DygraphTransposeElimination(FuseBase):
if _graph.layers[ipt].outputs[ if _graph.layers[ipt].outputs[
output_index] == _graph.layers[current_id].inputs[ output_index] == _graph.layers[current_id].inputs[
'x']: 'x']:
if len(x_shape) <= 1: if list(x_shape)==[1] or len(x_shape) < 1:
elementwise_layers.append(current_id) elementwise_layers.append(current_id)
continue continue
elif _graph.layers[ipt].outputs[ elif _graph.layers[ipt].outputs[
output_index] == _graph.layers[current_id].inputs[ output_index] == _graph.layers[current_id].inputs[
'y']: 'y']:
if len(y_shape) <= 1: if list(y_shape)==[1] or len(y_shape) < 1:
elementwise_layers.append(current_id) elementwise_layers.append(current_id)
continue continue
else: else:
...@@ -279,11 +279,6 @@ class DygraphTransposeElimination(FuseBase): ...@@ -279,11 +279,6 @@ class DygraphTransposeElimination(FuseBase):
for layer_id in list(set(optimized_concat_layers)): for layer_id in list(set(optimized_concat_layers)):
axis = graph.layers[layer_id].attrs.get('axis', 0) axis = graph.layers[layer_id].attrs.get('axis', 0)
graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis] graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis]
for layer_id in list(set(optimized_elementwise_layers)):
axis = graph.layers[layer_id].attrs.get('axis', -1)
graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis]
if graph.layers[layer_id].kernel == "paddle.add":
graph.layers[layer_id].kernel = "fluid.layers.elementwise_add"
current_transpose_num = self.get_transpose_num(graph) current_transpose_num = self.get_transpose_num(graph)
print( print(
......
...@@ -178,13 +178,13 @@ class StaticTransposeElimination(FuseBase): ...@@ -178,13 +178,13 @@ class StaticTransposeElimination(FuseBase):
if _graph.layers[ipt].outputs[ if _graph.layers[ipt].outputs[
output_index] == _graph.layers[current_id].inputs[ output_index] == _graph.layers[current_id].inputs[
'x']: 'x']:
if len(x_shape) <= 1: if list(x_shape)==[1] or len(x_shape) < 1:
elementwise_layers.append(current_id) elementwise_layers.append(current_id)
continue continue
elif _graph.layers[ipt].outputs[ elif _graph.layers[ipt].outputs[
output_index] == _graph.layers[current_id].inputs[ output_index] == _graph.layers[current_id].inputs[
'y']: 'y']:
if len(y_shape) <= 1: if list(y_shape)==[1] or len(y_shape) < 1:
elementwise_layers.append(current_id) elementwise_layers.append(current_id)
continue continue
else: else:
...@@ -279,11 +279,6 @@ class StaticTransposeElimination(FuseBase): ...@@ -279,11 +279,6 @@ class StaticTransposeElimination(FuseBase):
for layer_id in list(set(optimized_concat_layers)): for layer_id in list(set(optimized_concat_layers)):
axis = graph.layers[layer_id].attrs.get('axis', 0) axis = graph.layers[layer_id].attrs.get('axis', 0)
graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis] graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis]
for layer_id in list(set(optimized_elementwise_layers)):
axis = graph.layers[layer_id].attrs.get('axis', -1)
graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis]
if graph.layers[layer_id].kernel == "paddle.add":
graph.layers[layer_id].kernel = "fluid.layers.elementwise_add"
current_transpose_num = self.get_transpose_num(graph) current_transpose_num = self.get_transpose_num(graph)
print( print(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册