提交 2ca2e71b 编写于 作者: C channingss

update onnx symbolic shape inference

上级 3aa8e577
...@@ -346,8 +346,12 @@ class ONNXGraph(Graph): ...@@ -346,8 +346,12 @@ class ONNXGraph(Graph):
#if len(value_info['shape']) == 0 or value_info[ #if len(value_info['shape']) == 0 or value_info[
# 'dtype'] is None or 0 in value_info['shape']: # 'dtype'] is None or 0 in value_info['shape']:
# #TODO add node shape inference # #TODO add node shape inference
shape = value_info['shape']
for idx in range(len(shape)):
if shape[idx] == 0:
shape[idx] = -1
node.out_shapes.append(shape)
node.dtype = value_info['dtype'] node.dtype = value_info['dtype']
node.out_shapes.append(value_info['shape'])
else: else:
node.out_shapes.append([]) node.out_shapes.append([])
......
...@@ -57,6 +57,7 @@ def _is_static_shape(shape): ...@@ -57,6 +57,7 @@ def _is_static_shape(shape):
return False return False
return True return True
def _get_same_padding(in_size, kernel_size, stride): def _get_same_padding(in_size, kernel_size, stride):
new_size = int(math.ceil(in_size * 1.0 / stride)) new_size = int(math.ceil(in_size * 1.0 / stride))
pad_size = (new_size - 1) * stride + kernel_size - in_size pad_size = (new_size - 1) * stride + kernel_size - in_size
...@@ -348,6 +349,7 @@ class OpSet9(): ...@@ -348,6 +349,7 @@ class OpSet9():
'Warnning: paddle not support op:resize wiht mode: linear, we use bilinear replace linear' 'Warnning: paddle not support op:resize wiht mode: linear, we use bilinear replace linear'
) )
fluid_op = 'resize_bilinear' fluid_op = 'resize_bilinear'
attr['align_corners'] = False
node.fluid_code.add_layer( node.fluid_code.add_layer(
fluid_op, inputs=inputs, output=node, param_attr=attr) fluid_op, inputs=inputs, output=node, param_attr=attr)
...@@ -736,53 +738,59 @@ class OpSet9(): ...@@ -736,53 +738,59 @@ class OpSet9():
param_attr=None) param_attr=None)
else: else:
input_inner_indices = node.layer_name + '_input_inner_indices' input_inner_indices = node.layer_name + '_input_inner_indices'
shape = val_x.out_shapes[0]
node.fluid_code.add_layer(
'reshape',
inputs=indices.layer_name,
output=indices.layer_name,
param_attr={'shape': indices.out_shapes[0]})
zeros_like_val_x = val_x.layer_name + '_zeros'
node.fluid_code.add_layer( node.fluid_code.add_layer(
'scatter_nd', 'zeros_like',
inputs=val_x,
output=zeros_like_val_x,
param_attr=None)
node.fluid_code.add_layer(
'scatter_nd_add',
inputs={ inputs={
'shape': val_x.out_shapes[0], 'ref': zeros_like_val_x,
'index': indices, 'index': indices,
'updates': updates 'updates': updates
}, },
output=input_inner_indices, output=input_inner_indices,
param_attr=None) param_attr=None)
indices_mask = node.layer_name + '_indices_mask'
constant_minus_one = node.layer_name + '_constant_minus_one' constant_minus_one = node.layer_name + '_constant_minus_one'
# full_like support create tensor shape like input tensor
node.fluid_code.add_layer( node.fluid_code.add_layer(
'fill_constant', 'full_like',
inputs=None, inputs=updates,
output=constant_minus_one, output=constant_minus_one,
param_attr={ param_attr={'dtype': string(updates.dtype),
'shape': updates.out_shapes[0], 'fill_value': -1})
'dtype': string(updates.dtype),
'value': -1
})
indices_mask = node.layer_name + '_indices_mask'
node.fluid_code.add_layer( node.fluid_code.add_layer(
'scatter_nd', 'scatter_nd_add',
inputs={ inputs={
'shape': val_x.out_shapes[0], 'ref': zeros_like_val_x,
'index': indices, 'index': indices,
'updates': constant_minus_one 'updates': constant_minus_one
}, },
output=indices_mask, output=indices_mask,
param_attr=None) param_attr=None)
constant_one = node.layer_name + '_constant_1'
constant_1 = node.layer_name + '_constant_1' # full_like support create tensor shape like input tensor
node.fluid_code.add_layer( node.fluid_code.add_layer(
'fill_constant', 'full_like',
inputs=None, inputs=val_x,
output=constant_1, output=constant_one,
param_attr={ param_attr={'dtype': string(val_x.dtype),
'shape': val_x.out_shapes[0], 'fill_value': 1})
'dtype': string(val_x.dtype),
'value': 1
})
input_out_indices_mask = node.layer_name + '_input_out_indices_mask' input_out_indices_mask = node.layer_name + '_input_out_indices_mask'
node.fluid_code.add_layer( node.fluid_code.add_layer(
"elementwise_add", "elementwise_add",
inputs={"x": indices_mask, inputs={"x": indices_mask,
"y": constant_1}, "y": constant_one},
output=input_out_indices_mask, output=input_out_indices_mask,
param_attr=None) param_attr=None)
...@@ -841,11 +849,15 @@ class OpSet9(): ...@@ -841,11 +849,15 @@ class OpSet9():
self.omit_nodes.append(ends.layer_name) self.omit_nodes.append(ends.layer_name)
starts_value = starts_value.copy() starts_value = starts_value.copy()
ends_value = ends_value.copy() ends_value = ends_value.copy()
#for idx in range(len(ends_value)):
# if ends_value[idx] > 2**31 - 1:
# ends_value[idx] = 2**31 - 1
#print(val_x.out_shapes)
for idx in range(len(ends_value)): for idx in range(len(ends_value)):
if starts_value[idx] > val_x.out_shapes[0][axes[idx]]: if starts_value[idx] >= val_x.out_shapes[0][axes[idx]]:
starts_value[idx] = val_x.out_shapes[0][axes[idx]]-1 starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
ends_value[idx] = val_x.out_shapes[0][axes[idx]] ends_value[idx] = val_x.out_shapes[0][axes[idx]]
starts_value[idx] = val_x.out_shapes[0][axes[idx]]-1 starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
elif ends_value[idx] > 2**31 - 1: elif ends_value[idx] > 2**31 - 1:
ends_value[idx] = 2**31 - 1 ends_value[idx] = 2**31 - 1
attr = { attr = {
...@@ -882,10 +894,10 @@ class OpSet9(): ...@@ -882,10 +894,10 @@ class OpSet9():
if steps is not None: if steps is not None:
attr['strides'] = steps attr['strides'] = steps
node.fluid_code.add_layer( node.fluid_code.add_layer(
'strided_slice', inputs=val_x, output=node, param_attr=attr) 'strided_slice', inputs=val_x, output=node, param_attr=attr)
else: else:
node.fluid_code.add_layer( node.fluid_code.add_layer(
'slice', inputs=val_x, output=node, param_attr=attr) 'slice', inputs=val_x, output=node, param_attr=attr)
@print_mapping_info @print_mapping_info
def ConstantOfShape(self, node): def ConstantOfShape(self, node):
...@@ -928,15 +940,12 @@ class OpSet9(): ...@@ -928,15 +940,12 @@ class OpSet9():
min_value = _const_weight_or_none(min_ipt) min_value = _const_weight_or_none(min_ipt)
self.omit_nodes.append(max_ipt.layer_name) self.omit_nodes.append(max_ipt.layer_name)
self.omit_nodes.append(min_ipt.layer_name) self.omit_nodes.append(min_ipt.layer_name)
if max_value.shape == (1,): if max_value.shape == (1, ):
max_value = max_value[0] max_value = max_value[0]
if min_value.shape == (1,): if min_value.shape == (1, ):
min_value = min_value[0] min_value = min_value[0]
if max_value is not None and min_value is not None: if max_value is not None and min_value is not None:
attr = { attr = {'max': max_value, 'min': min_value}
'max': max_value,
'min': min_value
}
node.fluid_code.add_layer( node.fluid_code.add_layer(
'clip', inputs=val_x, output=node, param_attr=attr) 'clip', inputs=val_x, output=node, param_attr=attr)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册