Commit 2f01932d authored by yeliang2258

fix onnx convert

Parent: 2a420846
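Note: among other changes, this commit registers Sin/Cos as directly mapped ops and adds Neg, SpaceToDepth and GatherElements mappers. The new SpaceToDepth mapper lowers the ONNX op to a reshape → transpose → reshape sequence (see the hunk at @@ -1974,6 +2015,59 @@ below). As a rough reference only, here is a NumPy sketch of that same decomposition, with hypothetical names and assuming a static NCHW input whose height and width are divisible by blocksize:

```python
import numpy as np


def space_to_depth_reference(x, blocksize):
    """NumPy reference for the reshape/transpose/reshape lowering (NCHW input)."""
    b, c, h, w = x.shape
    # Split each spatial dim into (output_dim, blocksize).
    y = x.reshape(b, c, h // blocksize, blocksize, w // blocksize, blocksize)
    # Move the two block dims ahead of the channel dim
    # (same perm=[0, 3, 5, 1, 2, 4] as in the mapper below).
    y = y.transpose(0, 3, 5, 1, 2, 4)
    # Fold the block dims into the channel dim.
    return y.reshape(b, c * blocksize ** 2, h // blocksize, w // blocksize)


x = np.arange(32, dtype=np.float32).reshape(1, 2, 4, 4)
print(space_to_depth_reference(x, 2).shape)  # (1, 8, 2, 2)
```

Because the mapper bakes the concrete b, c, h, w values into the generated reshape calls, this decomposition only applies when the input shape is static.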
@@ -61,14 +61,16 @@ def _rename_or_remove_weight(weights,
    '''
    if origin_name not in weights:
        raise KeyError('{} not a key in {}'.format(origin_name, weights.keys()))
-    if is_remove:
-        # remove weight
-        data = weights.pop(origin_name)
-    else:
-        data = weights[origin_name]
+    # if is_remove:
+    #     # remove weight
+    #     data = weights.pop(origin_name)
+    # else:
+    data = weights[origin_name]
    if target_name is not None:
        # rename weight
        weights[target_name] = data
+    if "x2paddle_297" in weights.keys():
+        print("keep")


def _is_static_shape(shape):
@@ -169,6 +171,8 @@ class OpSet9():
        'Floor': ['paddle.floor'],
        'Abs': ['paddle.abs'],
        'Erf': ['paddle.erf'],
+        'Sin': ['paddle.sin'],
+        'Cos': ['paddle.cos'],
    }

    def __init__(self, decoder, paddle_graph):
@@ -229,6 +233,8 @@ class OpSet9():
    @print_mapping_info
    def place_holder(self, node):
+        if node.name in ["297", "x2paddle_297"]:
+            print("!!!!!!!find! 1123")
        shape = node.out_shapes[0]
        for i, dim_shape in enumerate(shape):
            if dim_shape == 0 and i == 0:
@@ -248,6 +254,7 @@ class OpSet9():
            node = parameter
        dtype = node.dtype
        shape = node.out_shapes[0]
+
        if hasattr(node.weight, "shape") and len(node.weight.shape) == 0:
            self.paddle_graph.add_layer(
                "paddle.full",
@@ -257,6 +264,7 @@ class OpSet9():
                shape=[1],
                fill_value=node.weight)
        else:
+            print("test point:", node.name)
            self.weights[node.name] = node.weight
            self.paddle_graph.add_layer(
                "self.create_parameter",
@@ -385,15 +393,21 @@ class OpSet9():
                    **attrs)
                return
        elif node.layer_type == 'Upsample':
-            val_scales = self.graph.get_input_node(node, idx=1, copy=True)
-            self.paddle_graph.add_layer(
-                "paddle.slice",
-                inputs={"input": val_scales.name},
-                outputs=[val_scales.name],
-                axes=[0],
-                starts=[2],
-                ends=[4])
-            inputs['scale_factor'] = val_scales.name
+            if len(node.layer.input) == 2:
+                val_scales = self.graph.get_input_node(node, idx=1, copy=True)
+                self.paddle_graph.add_layer(
+                    "paddle.slice",
+                    inputs={"input": val_scales.name},
+                    outputs=[val_scales.name],
+                    axes=[0],
+                    starts=[2],
+                    ends=[4])
+                inputs['scale_factor'] = val_scales.name
+            else:
+                val_scales = node.get_attr('scales')[2:]
+                print(type(val_scales))
+                print(val_scales)
+                # inputs['scale_factor'] = val_scales
        mode = node.get_attr('mode', 'nearest')
        attrs.update({
@@ -401,6 +415,8 @@ class OpSet9():
            "mode": string(mode),
            "align_mode": 1
        })
+        if len(node.layer.input) == 1:
+            attrs["scale_factor"] = val_scales
        val_x_shape = val_x.out_shapes[0]
        if mode == "linear" and len(val_x_shape) == 4:
            attrs["mode"] = string("bilinear")
@@ -680,27 +696,28 @@ class OpSet9():
        axes = node.get_attr('axes')
        if axes is None:
            axes = self.graph.get_input_node(node, idx=1, copy=True)
-        if len(val_x.out_shapes[0]) == 0:
-            if node.name:
-                self.paddle_graph.add_layer(
-                    'paddle.reshape',
-                    inputs={"x": val_x.name},
-                    outputs=[node.name],
-                    shape=[1])
-        else:
-            if isinstance(axes, list) or isinstance(axes, tuple):
-                self.paddle_graph.add_layer(
-                    'paddle.unsqueeze',
-                    inputs={"x": val_x.name},
-                    axis=axes,
-                    outputs=[node.name])
-            else:
-                self.paddle_graph.add_layer(
-                    'paddle.unsqueeze',
-                    inputs={"x": val_x.name,
-                            "axis": axes.name},
-                    outputs=[node.name])
+        if node.name in ["x2paddle_vis_local_cost_volume_3d_0_ExpandDims_5_0"]:
+            print("output_shape:", val_x.out_shapes[0])
+        # if len(val_x.out_shapes[0]) == 0:
+        #     if node.name:
+        #         self.paddle_graph.add_layer(
+        #             'paddle.reshape',
+        #             inputs={"x": val_x.name},
+        #             outputs=[node.name],
+        #             shape=[1])
+        # else:
+        if isinstance(axes, list) or isinstance(axes, tuple):
+            self.paddle_graph.add_layer(
+                'paddle.unsqueeze',
+                inputs={"x": val_x.name},
+                axis=axes,
+                outputs=[node.name])
+        else:
+            self.paddle_graph.add_layer(
+                'paddle.unsqueeze',
+                inputs={"x": val_x.name,
+                        "axis": axes.name},
+                outputs=[node.name])

    @print_mapping_info
    def Shrink(self, node):
@@ -716,6 +733,8 @@ class OpSet9():
    @print_mapping_info
    def Constant(self, node):
+        if node.name in ["297", "x2paddle_297"]:
+            print("!!!!!!!find!")
        val_output = self.graph.get_node(node.layer.output[0], copy=True)

        value = node.get_attr('value')
@@ -802,11 +821,21 @@ class OpSet9():
        val_shape = self.graph.get_input_node(node, idx=1, copy=True)
        val_x_dtype = val_x.dtype
        name_ones = node.name + '_ones'
-        attr_ones = {
-            'shape': val_shape.name,
-            'dtype': string(val_x_dtype),
-            'fill_value': 1
-        }
+        shape_values = _const_weight_or_none(val_shape)
+        if shape_values is None:
+            attr_ones = {
+                'shape': val_shape.name,
+                'dtype': string(val_x_dtype),
+                'fill_value': 1
+            }
+        else:
+            print("test:", type(shape_values))
+            print(shape_values.tolist())
+            attr_ones = {
+                'shape': shape_values.tolist(),
+                'dtype': string(val_x_dtype),
+                'fill_value': 1
+            }
        self.paddle_graph.add_layer(
            'paddle.full', inputs={}, outputs=[name_ones], **attr_ones)
        inputs_dict = {'x': name_ones, 'y': val_x.name}
@@ -826,6 +855,8 @@ class OpSet9():
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        indices = self.graph.get_input_node(node, idx=1, copy=True)
        indices_shape = indices.out_shapes[0]
+        print("indices_shape:", node.name, " ", indices_shape, " ",
+              val_x.out_shapes[0])
        axis = node.get_attr('axis', 0)
        #assert len(
        #    indices_shape) <= 2, "Gather op don't support dim of indice >2 "
@@ -838,6 +869,11 @@ class OpSet9():
                    outputs=[node.name])
            elif len(val_x.out_shapes[0]) > 1:
                if len(indices_shape) == 0:
+                    self.paddle_graph.add_layer(
+                        'paddle.reshape',
+                        inputs={"x": indices.name},
+                        outputs=[indices.name],
+                        shape=[-1, ])
                    gather_ = node.name + '_1'
                    self.paddle_graph.add_layer(
                        'paddle.gather',
@@ -1140,6 +1176,11 @@ class OpSet9():
            starts = node.get_attr('starts')
            ends = node.get_attr('ends')
            axes = node.get_attr('axes')
+            output_shape = val_x.out_shapes[0]
+            if axes is None:
+                axes = [i for i in range(len(starts))]
+            print("axes:", axes)
            for idx in range(len(ends)):
                if ends[idx] > 2**31 - 1:
                    ends[idx] = 2**31 - 1
@@ -1974,6 +2015,59 @@ class OpSet9():
            outputs=layer_outputs,
            output_size=output_shape[2:])

+    @print_mapping_info
+    def Neg(self, node):
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_y = node.name + "_y"
+        dtype = np.dtype(val_x.dtype)
+        self.paddle_graph.add_layer(
+            "paddle.full",
+            inputs={},
+            outputs=[val_y],
+            dtype=string(dtype),
+            shape=[1],
+            fill_value=-1)
+        self.paddle_graph.add_layer(
+            "paddle.multiply",
+            inputs={'x': val_x.name,
+                    'y': val_y},
+            outputs=[node.name])
+
+    @print_mapping_info
+    def SpaceToDepth(self, node):
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        blocksize = node.get_attr('blocksize')
+        print(blocksize)
+        val_x_shape = val_x.out_shapes[0]
+        b, c, h, w = val_x_shape
+        self.paddle_graph.add_layer(
+            'paddle.reshape',
+            inputs={"x": val_x.name},
+            outputs=[node.name],
+            shape=[b, c, h // blocksize, blocksize, w // blocksize, blocksize])
+        self.paddle_graph.add_layer(
+            'paddle.transpose',
+            inputs={"x": node.name},
+            outputs=[node.name],
+            perm=[0, 3, 5, 1, 2, 4])
+        self.paddle_graph.add_layer(
+            'paddle.reshape',
+            inputs={"x": node.name},
+            outputs=[node.name],
+            shape=[b, c * (blocksize**2), h // blocksize, w // blocksize])
+
+    @print_mapping_info
+    def GatherElements(self, node):
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        indices = self.graph.get_input_node(node, idx=1, copy=True)
+        dtype = np.dtype(val_x.dtype)
+        self.paddle_graph.add_layer(
+            "paddle.gather",
+            inputs={'x': val_x.name,
+                    'index': indices.name},
+            axis=node.get_attr('axis'),
+            outputs=[node.name])
+
    @print_mapping_info
    def GlobalAveragePool(self, node):
        op_name = name_generator("pool", self.nn_name2id)
@@ -2072,6 +2166,7 @@ class OpSet9():
                                     remove_weight)
            if has_bias:
                remove_bias = True if val_b.name in self.done_weight_list else False
+                remove_bias = False
                if remove_bias:
                    self.done_weight_list.append(val_b_name)
                _rename_or_remove_weight(self.weights, val_b.name,
......
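Side note on the Expand hunk (@@ -802,11 +821,21 @@): the lowering itself is unchanged, namely build a ones tensor of the target shape and broadcast-multiply the input into it; what changes is that the target shape is now inlined as a Python list whenever the ONNX shape input is a constant (via `_const_weight_or_none`). A minimal NumPy sketch of that lowering, with hypothetical names that are not part of the commit:

```python
import numpy as np


def expand_reference(x, target_shape):
    """NumPy sketch of the ones-then-multiply Expand lowering."""
    # Corresponds to paddle.full(shape=..., fill_value=1, dtype=...);
    # the shape is a plain Python list when it comes from a constant input.
    ones = np.full(list(target_shape), 1, dtype=x.dtype)
    # The paddle.multiply step relies on broadcasting, as does this sketch.
    return ones * x


x = np.array([[1.0], [2.0], [3.0]], dtype=np.float32)  # shape (3, 1)
print(expand_reference(x, (3, 4)).shape)               # broadcast to (3, 4)
```

The input shape therefore has to be broadcastable to the target shape; the sketch makes no attempt to validate that.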