未验证 提交 bdc67376 编写于 作者: J Jason 提交者: GitHub

Merge pull request #163 from PaddlePaddle/develop

pull
...@@ -129,9 +129,12 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto): ...@@ -129,9 +129,12 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto):
from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer
import google.protobuf as gpb import google.protobuf as gpb
ver_str = gpb.__version__.replace('.', '') ver_part = gpb.__version__.split('.')
ver_int = int(ver_str[0:2]) version_satisfy = False
assert ver_int >= 36, 'The version of protobuf must be larger than 3.6.0!' if (int(ver_part[0]) == 3 and int(ver_part[1]) >= 6) \
or (int(ver_part[0]) > 3):
version_satisfy = True
assert version_satisfy, 'google.protobuf >= 3.6.0 is required'
print("Now translating model from caffe to paddle.") print("Now translating model from caffe to paddle.")
model = CaffeDecoder(proto, weight, caffe_proto) model = CaffeDecoder(proto, weight, caffe_proto)
mapper = CaffeOpMapper(model) mapper = CaffeOpMapper(model)
......
...@@ -8,13 +8,7 @@ def shufflechannel_shape(input_shape): ...@@ -8,13 +8,7 @@ def shufflechannel_shape(input_shape):
def shufflechannel_layer(inputs, group=None, input_shape=None, name=None): def shufflechannel_layer(inputs, group=None, input_shape=None, name=None):
input = inputs[0] input = inputs[0]
c_fm = fluid.layers.split(input, num_or_sections=input_shape[0][1], dim=1) out = fluid.layers.shuffle_channel(x=input, group=group)
size = int(input_shape[0][1] / group)
new_c_fm = []
for i in range(size):
for j in range(group):
new_c_fm.append(c_fm[j * size + i])
out = fluid.layers.concat(new_c_fm, axis=1)
return out return out
......
...@@ -363,7 +363,7 @@ class CaffeOpMapper(OpMapper): ...@@ -363,7 +363,7 @@ class CaffeOpMapper(OpMapper):
input = self.graph.get_bottom_node(node, idx=0, copy=True) input = self.graph.get_bottom_node(node, idx=0, copy=True)
attr = { attr = {
'n': params.local_size, 'n': params.local_size,
'k': 1.0, 'k': params.k,
'alpha': alpha, 'alpha': alpha,
'beta': params.beta, 'beta': params.beta,
'name': string(node.layer_name) 'name': string(node.layer_name)
...@@ -450,35 +450,19 @@ class CaffeOpMapper(OpMapper): ...@@ -450,35 +450,19 @@ class CaffeOpMapper(OpMapper):
slice_dim = params.slice_dim slice_dim = params.slice_dim
if slice_dim != 1 and axis == 1: if slice_dim != 1 and axis == 1:
axis = slice_dim axis = slice_dim
points = list(params.slice_point) output_shape = node.output_shape
sections_list = []
if len(points) == 0: for s in output_shape:
dims = node.input_shape[0][axis] sections_list.append(s[axis])
assert dims % top_len == 0, "the parameter of Slice is wrong" attr = {
part = dims / top_len 'num_or_sections': sections_list,
t = part 'dim': axis,
while t < dims: 'name': string(node.layer_name)
points.append(int(t)) }
t += part node.fluid_code.add_layer("split",
maxint32 = 2147483647 inputs=input,
points = [0] + points output=node.layer_name,
points.append(maxint32) param_attr=attr)
i = 0
node.fluid_code.add_note('{} = []'.format(node.layer_name))
for i in range(len(points)):
attr = {
'axes': [axis],
'starts': [points[i]],
'ends': [points[i + 1]]
}
node.fluid_code.add_layer("slice",
inputs=input,
output=node.layer_name + '_' + str(i),
param_attr=attr)
node.fluid_code.add_note('{}.append({})'.format(
node.layer_name, node.layer_name + '_' + str(i)))
if i == len(points) - 2:
break
def Concat(self, node): def Concat(self, node):
assert len( assert len(
...@@ -649,7 +633,8 @@ class CaffeOpMapper(OpMapper): ...@@ -649,7 +633,8 @@ class CaffeOpMapper(OpMapper):
]).astype('float32') ]).astype('float32')
scale = 0 scale = 0
else: else:
node.data = [np.squeeze(i) for i in node.data]
node.data = [np.squeeze(i).astype('float32') for i in node.data]
mean, variance, scale = node.data mean, variance, scale = node.data
# Prescale the stats # Prescale the stats
scaling_factor = 1.0 / scale if scale != 0 else 0 scaling_factor = 1.0 / scale if scale != 0 else 0
...@@ -684,8 +669,10 @@ class CaffeOpMapper(OpMapper): ...@@ -684,8 +669,10 @@ class CaffeOpMapper(OpMapper):
input_c, input_c,
]).astype('float32') ]).astype('float32')
else: else:
self.weights[node.layer_name + '_scale'] = np.squeeze(node.data[0]) self.weights[node.layer_name + '_scale'] = np.squeeze(
self.weights[node.layer_name + '_offset'] = np.squeeze(node.data[1]) node.data[0]).astype('float32')
self.weights[node.layer_name + '_offset'] = np.squeeze(
node.data[1]).astype('float32')
params = node.layer.scale_param params = node.layer.scale_param
axis = params.axis axis = params.axis
num_axes = params.num_axes num_axes = params.num_axes
......
...@@ -168,7 +168,11 @@ class TFOpMapper(OpMapper): ...@@ -168,7 +168,11 @@ class TFOpMapper(OpMapper):
x_input = y x_input = y
y_input = x y_input = x
x_shape = y.out_shapes[0] x_shape = y.out_shapes[0]
if len(x_shape) == 0:
x_shape = [1]
y_shape = x.out_shapes[0] y_shape = x.out_shapes[0]
if len(y_shape) == 0:
y_shape = [1]
else: else:
if len(x_shape) == 1 and len(y_shape) == 4 and x_shape[ if len(x_shape) == 1 and len(y_shape) == 4 and x_shape[
0] == y_shape[-1] and y_shape.count(-1) < 1: 0] == y_shape[-1] and y_shape.count(-1) < 1:
......
...@@ -121,10 +121,29 @@ class TFOpMapperNHWC(OpMapper): ...@@ -121,10 +121,29 @@ class TFOpMapperNHWC(OpMapper):
pd_param_name = list(param.values())[0] pd_param_name = list(param.values())[0]
tf_param = node.get_attr(tf_param_name) tf_param = node.get_attr(tf_param_name)
attr[pd_param_name] = tf_param attr[pd_param_name] = tf_param
node.fluid_code.add_layer(op_info[0],
inputs=input, if len(input.out_shapes[0]) == 4 and op_info[0] != 'shape':
output=node, attr1 = {"perm": [0, 3, 1, 2]}
param_attr=attr) node.fluid_code.add_layer('transpose',
inputs=input,
output=node,
param_attr=attr1)
input = node
node.fluid_code.add_layer(op_info[0],
inputs=input,
output=node,
param_attr=attr)
input = node
attr2 = {"perm": [0, 2, 3, 1]}
node.fluid_code.add_layer('transpose',
inputs=input,
output=node,
param_attr=attr2)
else:
node.fluid_code.add_layer(op_info[0],
inputs=input,
output=node,
param_attr=attr)
def elementwise_map(self, node): def elementwise_map(self, node):
assert node.layer_type in self.elementwise_ops assert node.layer_type in self.elementwise_ops
...@@ -149,7 +168,11 @@ class TFOpMapperNHWC(OpMapper): ...@@ -149,7 +168,11 @@ class TFOpMapperNHWC(OpMapper):
x_input = y x_input = y
y_input = x y_input = x
x_shape = y.out_shapes[0] x_shape = y.out_shapes[0]
if len(x_shape) == 0:
x_shape = [1]
y_shape = x.out_shapes[0] y_shape = x.out_shapes[0]
if len(y_shape) == 0:
y_shape = [1]
else: else:
raise Exception("Unexpected situation happend") raise Exception("Unexpected situation happend")
...@@ -193,11 +216,30 @@ class TFOpMapperNHWC(OpMapper): ...@@ -193,11 +216,30 @@ class TFOpMapperNHWC(OpMapper):
output="y_tmp", output="y_tmp",
param_attr=attr) param_attr=attr)
y_input = "y_tmp" y_input = "y_tmp"
inputs = {"x": x_input, "y": y_input} if len(x_shape) == 4 and len(y_shape) == 4:
node.fluid_code.add_layer(op_type, node.fluid_code.add_layer("transpose",
inputs=inputs, inputs=x_input,
output=node, output=x_input,
param_attr=None) param_attr={'perm': [0, 3, 1, 2]})
node.fluid_code.add_layer("transpose",
inputs=y_input,
output=y_input,
param_attr={'perm': [0, 3, 1, 2]})
inputs = {"x": x_input, "y": y_input}
node.fluid_code.add_layer(op_type,
inputs=inputs,
output=node,
param_attr=None)
node.fluid_code.add_layer("transpose",
inputs=node,
output=node,
param_attr={'perm': [0, 2, 3, 1]})
else:
inputs = {"x": x_input, "y": y_input}
node.fluid_code.add_layer(op_type,
inputs=inputs,
output=node,
param_attr=None)
def Placeholder(self, node): def Placeholder(self, node):
shape = node.out_shapes[0] shape = node.out_shapes[0]
...@@ -978,9 +1020,7 @@ class TFOpMapperNHWC(OpMapper): ...@@ -978,9 +1020,7 @@ class TFOpMapperNHWC(OpMapper):
if pad_mode == "SAME": if pad_mode == "SAME":
if node.tf_data_format == "NHWC": if node.tf_data_format == "NHWC":
print(out_shape)
out_shape = [out_shape[i] for i in [0, 3, 1, 2]] out_shape = [out_shape[i] for i in [0, 3, 1, 2]]
print(out_shape)
for i in range(4): for i in range(4):
if out_shape[i] < 0: if out_shape[i] < 0:
out_shape[i] = 999999 out_shape[i] = 999999
......
...@@ -232,84 +232,35 @@ class TFOptimizer(object): ...@@ -232,84 +232,35 @@ class TFOptimizer(object):
'act'] 'act']
node.fluid_code.clear() node.fluid_code.clear()
self.graph.remove_node(node.layer_name) self.graph.remove_node(node.layer_name)
self.graph.identity_map[node.layer_name] = input.layer_name
def remove_transpose(self): def remove_transpose(self):
graph_copy = cp.deepcopy(self.graph) graph_copy = cp.deepcopy(self.graph)
nhwc_insensitive_ops = [ nhwc_insensitive_ops = [
'Relu', 'Relu6', 'Abs', 'Sigmoid', 'Exp', 'Rsqrt', 'swish_f32', 'Relu', 'Relu6', 'Abs', 'Sigmoid', 'Exp', 'Rsqrt', 'swish_f32',
'LeakyRelu', 'Cast' 'LeakyRelu', 'Cast', 'Tanh'
] ]
elementwise_ops = [ elementwise_ops = [
'Sub', 'Add', 'RealDiv', 'Maximum', 'Mul', 'FloorDiv', 'Sub', 'Add', 'RealDiv', 'Maximum', 'Mul', 'FloorDiv',
'GreaterEqual' 'GreaterEqual'
] ]
for node_name in self.graph.topo_sort:
node = graph_copy.get_node(node_name)
if node is None:
continue
if node.layer_type in nhwc_insensitive_ops:
graph_copy.remove_node(node_name)
optimize_ops = [ optimize_ops = [
'Conv2D', 'MaxPool', 'FusedBatchNorm', 'DepthwiseConv2dNative', 'Conv2D', 'MaxPool', 'FusedBatchNorm', 'DepthwiseConv2dNative',
'AvgPool', 'Pad', 'Conv2DBackpropInput', 'ResizeNearestNeighbor', 'AvgPool', 'Pad', 'Conv2DBackpropInput', 'ResizeNearestNeighbor',
'ResizeBilinear', "Placeholder" 'ResizeBilinear', "Placeholder"
] ]
can_be_optimized_ops = [
'Conv2D', 'MaxPool', 'FusedBatchNorm', 'DepthwiseConv2dNative',
'AvgPool', 'Pad', 'Conv2DBackpropInput', 'ResizeNearestNeighbor',
'ResizeBilinear', "Placeholder", 'Relu', 'Relu6', 'Abs', 'Sigmoid',
'Exp', 'Rsqrt', 'swish_f32', 'LeakyRelu', 'Cast', 'Tanh'
]
for node_name in self.graph.topo_sort: for node_name in self.graph.topo_sort:
node = graph_copy.get_node(node_name) node = graph_copy.get_node(node_name)
if node is None: if node is None:
continue continue
if node.layer_type in elementwise_ops: if node.layer_type in can_be_optimized_ops:
is_nhwc = True
for in_name in node.inputs:
in_node = graph_copy.get_node(in_name)
if hasattr(in_node, "is_nhwc"):
if not in_node.is_nhwc:
is_nhwc = False
else:
if len(in_node.fluid_code.layers) < 2:
is_nhwc = False
continue
if in_node.fluid_code.layers[
-1].op != "transpose" or in_node.fluid_code.layers[
-1].param_attr["perm"] != [0, 2, 3, 1]:
is_nhwc = False
continue
node.is_nhwc = is_nhwc
for i in range(len(self.graph.topo_sort)):
node_name = self.graph.topo_sort[-1 * i - 1]
node = graph_copy.get_node(node_name)
if node is None:
continue
if node.layer_type in elementwise_ops:
can_be_removed = True
if len(node.fluid_code.layers) > 1:
can_be_removed = False
if not node.is_nhwc:
can_be_removed = False
for out_name in node.outputs:
out_node = graph_copy.get_node(out_name)
if hasattr(out_node, "is_nhwc"):
if not out_node.is_nhwc:
can_be_removed = False
else:
if len(out_node.fluid_code.layers) < 2:
can_be_removed = False
break
if out_node.fluid_code.layers[
0].op != "transpose" or out_node.fluid_code.layers[
0].param_attr["perm"] != [0, 3, 1, 2]:
can_be_removed = False
break
node.can_be_removed = can_be_removed
for node_name in self.graph.topo_sort:
node = graph_copy.get_node(node_name)
if node is None:
continue
if node.layer_type in optimize_ops:
if node.fluid_code.layers[ if node.fluid_code.layers[
-1].op != "transpose" or node.fluid_code.layers[ -1].op != "transpose" or node.fluid_code.layers[
-1].param_attr["perm"] != [0, 2, 3, 1]: -1].param_attr["perm"] != [0, 2, 3, 1]:
...@@ -327,6 +278,9 @@ class TFOptimizer(object): ...@@ -327,6 +278,9 @@ class TFOptimizer(object):
0].param_attr["perm"] != [0, 3, 1, 2]: 0].param_attr["perm"] != [0, 3, 1, 2]:
can_be_removed = False can_be_removed = False
break break
elif out_node.layer_type in elementwise_ops:
can_be_removed = False
break
if can_be_removed and len(node.fluid_code.layers) > 1: if can_be_removed and len(node.fluid_code.layers) > 1:
true_node = self.graph.get_node(node_name) true_node = self.graph.get_node(node_name)
if true_node.layer_type == "Placeholder": if true_node.layer_type == "Placeholder":
...@@ -346,8 +300,6 @@ class TFOptimizer(object): ...@@ -346,8 +300,6 @@ class TFOptimizer(object):
del true_node.fluid_code.layers[-1] del true_node.fluid_code.layers[-1]
for out_name in output_names: for out_name in output_names:
out_node = self.graph.get_node(out_name) out_node = self.graph.get_node(out_name)
if out_node.layer_type in elementwise_ops:
continue
out_node.fluid_code.layers[ out_node.fluid_code.layers[
1].inputs = out_node.fluid_code.layers[0].inputs 1].inputs = out_node.fluid_code.layers[0].inputs
del out_node.fluid_code.layers[0] del out_node.fluid_code.layers[0]
...@@ -357,43 +309,241 @@ class TFOptimizer(object): ...@@ -357,43 +309,241 @@ class TFOptimizer(object):
if node is None: if node is None:
continue continue
if node.layer_type in elementwise_ops: if node.layer_type in elementwise_ops:
if not node.can_be_removed: can_be_removed = True
if node.fluid_code.layers[
-1].op != "transpose" or node.fluid_code.layers[
-1].param_attr["perm"] != [0, 2, 3, 1]:
continue
can_be_removed = True
output_names = node.outputs
for out_name in output_names:
out_node = graph_copy.get_node(out_name)
if len(out_node.fluid_code.layers) < 3:
can_be_removed = False
break
if hasattr(out_node, "can_be_removed"):
if not out_node.can_be_removed:
can_be_removed = False
break
if out_node.layer_type in can_be_optimized_ops:
if out_node.fluid_code.layers[
0].op != "transpose" or out_node.fluid_code.layers[
0].param_attr["perm"] != [0, 3, 1, 2]:
can_be_removed = False
break
elif out_node.layer_type in elementwise_ops:
if out_node.fluid_code.layers[
0].op != "transpose" and out_node.fluid_code.layers[
1].op != "transpose":
can_be_removed = False
break
if out_node.fluid_code.layers[0].op == "transpose":
if out_node.fluid_code.layers[0].param_attr[
"perm"] != [0, 3, 1, 2]:
can_be_removed = False
break
if out_node.fluid_code.layers[1].op == "transpose":
if out_node.fluid_code.layers[1].param_attr[
"perm"] != [0, 3, 1, 2]:
can_be_removed = False
break
if can_be_removed and len(node.fluid_code.layers) > 1:
true_node = self.graph.get_node(node_name) true_node = self.graph.get_node(node_name)
for i, in_name in enumerate(node.inputs): true_node.fluid_code.layers[
in_node = graph_copy.get_node(in_name) -2].output = true_node.fluid_code.layers[-1].output
if hasattr(in_node, "is_nhwc") and in_node.is_nhwc: del true_node.fluid_code.layers[-1]
if i == 0: for out_name in output_names:
l = Layer() out_node = self.graph.get_node(out_name)
l.op = "transpose" if out_node.layer_type in can_be_optimized_ops:
l.inputs = true_node.fluid_code.layers[ out_node.fluid_code.layers[
0].inputs["x"] 1].inputs = out_node.fluid_code.layers[0].inputs
l.param_attr = {"perm": [0, 2, 3, 1]} del out_node.fluid_code.layers[0]
l.output = "nhwc_" + l.inputs.layer_name elif out_node.layer_type in elementwise_ops:
true_node.fluid_code.layers[0].inputs[ if out_node.inputs[0] in node.layer_name:
"x"] = l.output if out_node.fluid_code.layers[
true_node.fluid_code.layers.insert(0, l) 1].op == 'transpose':
elif i == 1: out_node.fluid_code.layers[2].inputs[
l = Layer() 'x'] = out_node.fluid_code.layers[
l.op = "transpose" 0].inputs
l.inputs = true_node.fluid_code.layers[ del out_node.fluid_code.layers[0]
0].inputs["y"] else:
l.param_attr = {"perm": [0, 2, 3, 1]} out_node.fluid_code.layers[1].inputs[
l.output = "nhwc_" + l.inputs.layer_name 'x'] = out_node.fluid_code.layers[
true_node.fluid_code.layers[0].inputs[ 0].inputs
"y"] = l.output del out_node.fluid_code.layers[0]
true_node.fluid_code.layers.insert(0, l) elif out_node.inputs[1] in node.layer_name:
else: if out_node.fluid_code.layers[
raise Exception("Unexpected situation happend") 1].op == 'transpose':
out_node.fluid_code.layers[2].inputs[
'y'] = out_node.fluid_code.layers[
1].inputs
del out_node.fluid_code.layers[1]
else:
out_node.fluid_code.layers[1].inputs[
'y'] = out_node.fluid_code.layers[
0].inputs
del out_node.fluid_code.layers[0]
graph_copy = cp.deepcopy(self.graph)
for node_name in self.graph.topo_sort:
node = graph_copy.get_node(node_name)
if node is None or len(node.fluid_code.layers) < 2:
continue
if node.layer_type in can_be_optimized_ops and node.layer_type != "Placeholder":
if node.fluid_code.layers[
-1].op != "transpose" or node.fluid_code.layers[
-1].param_attr["perm"] != [0, 2, 3, 1]:
continue continue
else: can_be_removed = True
for out_name in node.outputs: output_names = node.outputs
for out_name in output_names:
out_node = graph_copy.get_node(out_name)
if hasattr(out_node, "can_be_removed"):
if not out_node.can_be_removed:
can_be_removed = False
break
if len(out_node.fluid_code.layers) < 2:
can_be_removed = False
break
if out_node.layer_type in can_be_optimized_ops:
if out_node.fluid_code.layers[
0].op != "transpose" or out_node.fluid_code.layers[
0].param_attr["perm"] != [0, 3, 1, 2]:
can_be_removed = False
break
elif out_node.layer_type in elementwise_ops:
if out_node.fluid_code.layers[
0].op != "transpose" and out_node.fluid_code.layers[
1].op != "transpose":
can_be_removed = False
break
if out_node.fluid_code.layers[
0].op == "expand" or out_node.fluid_code.layers[
1].op == "expand":
can_be_removed = False
break
if out_node.fluid_code.layers[0].op == "transpose":
if out_node.fluid_code.layers[0].param_attr[
"perm"] != [0, 3, 1, 2]:
can_be_removed = False
break
if out_node.fluid_code.layers[1].op == "transpose":
if out_node.fluid_code.layers[1].param_attr[
"perm"] != [0, 3, 1, 2]:
can_be_removed = False
break
elif out_node.layer_type not in elementwise_ops and out_node.layer_type not in can_be_optimized_ops:
can_be_removed = False
break
if can_be_removed:
true_node = self.graph.get_node(node_name)
if len(true_node.fluid_code.layers) < 2:
continue
true_node.fluid_code.layers[
-2].output = true_node.fluid_code.layers[-1].output
del true_node.fluid_code.layers[-1]
for out_name in output_names:
out_node = self.graph.get_node(out_name) out_node = self.graph.get_node(out_name)
if out_node.layer_type not in elementwise_ops: if out_node.layer_type in can_be_optimized_ops:
assert out_node.fluid_code.layers[
0].op == "transpose", "unexpected situation happend"
out_node.fluid_code.layers[ out_node.fluid_code.layers[
1].inputs = out_node.fluid_code.layers[0].inputs 1].inputs = out_node.fluid_code.layers[0].inputs
del out_node.fluid_code.layers[0] del out_node.fluid_code.layers[0]
elif out_node.layer_type in elementwise_ops:
if out_node.inputs[0] in node.layer_name:
if out_node.fluid_code.layers[
1].op == 'transpose':
if out_node.fluid_code.layers[
2].op == 'transpose':
out_node.fluid_code.layers[3].inputs[
'x'] = out_node.fluid_code.layers[
0].inputs
else:
out_node.fluid_code.layers[2].inputs[
'x'] = out_node.fluid_code.layers[
0].inputs
del out_node.fluid_code.layers[0]
else:
out_node.fluid_code.layers[1].inputs[
'x'] = out_node.fluid_code.layers[
0].inputs
del out_node.fluid_code.layers[0]
elif out_node.inputs[1] in node.layer_name:
if out_node.fluid_code.layers[
1].op == 'transpose':
out_node.fluid_code.layers[2].inputs[
'y'] = out_node.fluid_code.layers[
1].inputs
del out_node.fluid_code.layers[1]
else:
out_node.fluid_code.layers[1].inputs[
'y'] = out_node.fluid_code.layers[
0].inputs
del out_node.fluid_code.layers[0]
graph_copy = cp.deepcopy(self.graph)
for node_name in self.graph.topo_sort:
node = graph_copy.get_node(node_name)
if node is None:
continue
if node.layer_type in elementwise_ops:
can_be_removed = True
if len(node.fluid_code.layers) < 3:
continue
numTranspose = 0
numNotTranspose = 0
for i in range(len(node.fluid_code.layers)):
if node.fluid_code.layers[i].op == 'transpose':
numTranspose += 1
elif node.fluid_code.layers[i].op != 'expand':
numNotTranspose += 1
if numTranspose > numNotTranspose:
if node.fluid_code.layers[0].op == 'expand':
if node.fluid_code.layers[
1].op != 'transpose' or node.fluid_code.layers[
2].op != 'transpose':
continue
else:
true_node = self.graph.get_node(node_name)
true_node.fluid_code.layers[3].inputs[
'x'] = true_node.fluid_code.layers[1].inputs
true_node.fluid_code.layers[3].inputs[
'y'] = true_node.fluid_code.layers[2].inputs
l = Layer()
l.op = 'transpose'
l.inputs = true_node.fluid_code.layers[3].output
l.param_attr = {'perm': [0, 3, 1, 2]}
if type(l.inputs) == str:
l.output = l.inputs
else:
l.output = l.inputs.layer_name
true_node.fluid_code.layers.append(l)
del true_node.fluid_code.layers[1]
del true_node.fluid_code.layers[1]
else:
if node.fluid_code.layers[
0].op != 'transpose' or node.fluid_code.layers[
1].op != 'transpose':
continue
else:
true_node = self.graph.get_node(node_name)
true_node.fluid_code.layers[2].inputs[
'x'] = true_node.fluid_code.layers[0].inputs
true_node.fluid_code.layers[2].inputs[
'y'] = true_node.fluid_code.layers[1].inputs
l = Layer()
l.op = 'transpose'
l.inputs = true_node.fluid_code.layers[2].output
l.param_attr = {'perm': [0, 3, 1, 2]}
l.output = l.inputs.layer_name
true_node.fluid_code.layers.append(l)
del true_node.fluid_code.layers[0]
del true_node.fluid_code.layers[0]
def make_nchw_input_output(self): def make_nchw_input_output(self):
for i, name in enumerate(self.graph.input_nodes): for i, name in enumerate(self.graph.input_nodes):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册