提交 3aa8e577 编写于 作者: C Channingss

fix bug of shape infer

上级 09d35587
...@@ -1419,7 +1419,7 @@ class SymbolicShapeInference: ...@@ -1419,7 +1419,7 @@ class SymbolicShapeInference:
if self.verbose_ > 2: if self.verbose_ > 2:
print(node.op_type + ': ' + node.name) print(node.op_type + ': ' + node.name)
for i, name in enumerate(node.input): for i, name in enumerate(node.input):
print(' Input {}: {} {}€5€5€5€5€5'.format( print(' Input {}: {} {}'.format(
i, name, 'initializer' i, name, 'initializer'
if name in self.initializers_ else '')) if name in self.initializers_ else ''))
...@@ -1544,7 +1544,7 @@ class SymbolicShapeInference: ...@@ -1544,7 +1544,7 @@ class SymbolicShapeInference:
continue # continue the inference after guess, no need to stop as no merge is needed continue # continue the inference after guess, no need to stop as no merge is needed
if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined: if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined:
print('Stopping at incomplete shape inference at ' + print('Stopping at incomplete symbolic shape inference at ' +
node.op_type + ': ' + node.name) node.op_type + ': ' + node.name)
print('node inputs:') print('node inputs:')
for i in node.input: for i in node.input:
...@@ -1579,6 +1579,7 @@ class SymbolicShapeInference: ...@@ -1579,6 +1579,7 @@ class SymbolicShapeInference:
all_shapes_inferred = False all_shapes_inferred = False
symbolic_shape_inference._preprocess( symbolic_shape_inference._preprocess(
in_mp, input_shapes=fixed_input_shape) in_mp, input_shapes=fixed_input_shape)
try: try:
while symbolic_shape_inference.run_: while symbolic_shape_inference.run_:
all_shapes_inferred = symbolic_shape_inference._infer_impl( all_shapes_inferred = symbolic_shape_inference._infer_impl(
...@@ -1588,9 +1589,8 @@ class SymbolicShapeInference: ...@@ -1588,9 +1589,8 @@ class SymbolicShapeInference:
print('!' * 10) print('!' * 10)
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes( symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
symbolic_shape_inference.out_mp_) symbolic_shape_inference.out_mp_)
#onnx.save(symbolic_shape_inference.out_mp_, 'tmp.onnx')
except: except:
print('Stopping at incomplete shape inference') print('Stopping at incomplete symbolic shape inference')
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes( symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
symbolic_shape_inference.out_mp_) in_mp)
return symbolic_shape_inference.out_mp_.graph return symbolic_shape_inference.out_mp_.graph
...@@ -57,7 +57,6 @@ def _is_static_shape(shape): ...@@ -57,7 +57,6 @@ def _is_static_shape(shape):
return False return False
return True return True
def _get_same_padding(in_size, kernel_size, stride): def _get_same_padding(in_size, kernel_size, stride):
new_size = int(math.ceil(in_size * 1.0 / stride)) new_size = int(math.ceil(in_size * 1.0 / stride))
pad_size = (new_size - 1) * stride + kernel_size - in_size pad_size = (new_size - 1) * stride + kernel_size - in_size
...@@ -104,14 +103,6 @@ class OpSet9(): ...@@ -104,14 +103,6 @@ class OpSet9():
default_op_mapping = { default_op_mapping = {
'Shape': ['shape', ['X'], ['Out']], 'Shape': ['shape', ['X'], ['Out']],
'Clip': [
'clip', ['X'], ['Out'], dict(), dict(
min=(np.asarray(
[255, 255, 127, 255], dtype=np.uint8).view(np.float32)[0]),
max=(np.asarray(
[255, 255, 127, 127], dtype=np.uint8).view(np.float32)[0]),
)
],
'Erf': ['erf', ['X'], ['Out']], 'Erf': ['erf', ['X'], ['Out']],
'Ceil': ['ceil', ['X'], ['Out']], 'Ceil': ['ceil', ['X'], ['Out']],
'ReduceMean': [ 'ReduceMean': [
...@@ -831,27 +822,31 @@ class OpSet9(): ...@@ -831,27 +822,31 @@ class OpSet9():
if len(node.inputs) > 1: if len(node.inputs) > 1:
starts = self.graph.get_input_node(node, idx=1, copy=True) starts = self.graph.get_input_node(node, idx=1, copy=True)
ends = self.graph.get_input_node(node, idx=2, copy=True) ends = self.graph.get_input_node(node, idx=2, copy=True)
starts_value = _const_weight_or_none(starts)
ends_value = _const_weight_or_none(ends)
if len(node.inputs) > 3: if len(node.inputs) > 3:
axes = self.graph.get_input_node(node, idx=3, copy=True) axes = self.graph.get_input_node(node, idx=3, copy=True)
axes = _const_weight_or_none(axes, necessary=True) axes = _const_weight_or_none(axes, necessary=True)
if len(node.inputs) > 4: if len(node.inputs) > 4:
steps = self.graph.get_input_node(node, idx=4, copy=True) steps = self.graph.get_input_node(node, idx=4, copy=True)
steps = _const_weight_or_none(steps) steps = _const_weight_or_none(steps)
if steps is not None:
assert steps == 1, "Only support convert op:Slice, which attribute:steps == 1"
attr = { attr = {
"axes": axes, "axes": axes,
"starts": starts.layer_name, "starts": starts.layer_name,
"ends": ends.layer_name "ends": ends.layer_name
} }
starts_value = _const_weight_or_none(starts)
ends_value = _const_weight_or_none(ends)
if starts_value is not None and ends_value is not None: if starts_value is not None and ends_value is not None:
self.omit_nodes.append(starts.layer_name) self.omit_nodes.append(starts.layer_name)
self.omit_nodes.append(ends.layer_name) self.omit_nodes.append(ends.layer_name)
starts_value = starts_value.copy()
ends_value = ends_value.copy() ends_value = ends_value.copy()
for idx in range(len(ends_value)): for idx in range(len(ends_value)):
if ends_value[idx] > 2**31 - 1: if starts_value[idx] > val_x.out_shapes[0][axes[idx]]:
starts_value[idx] = val_x.out_shapes[0][axes[idx]]-1
ends_value[idx] = val_x.out_shapes[0][axes[idx]]
starts_value[idx] = val_x.out_shapes[0][axes[idx]]-1
elif ends_value[idx] > 2**31 - 1:
ends_value[idx] = 2**31 - 1 ends_value[idx] = 2**31 - 1
attr = { attr = {
"axes": axes, "axes": axes,
...@@ -869,12 +864,12 @@ class OpSet9(): ...@@ -869,12 +864,12 @@ class OpSet9():
attr['starts'] = starts_cast attr['starts'] = starts_cast
if ends.dtype != 'int32': if ends.dtype != 'int32':
ends_cast = ends.layer_name + '_cast' ends_cast = ends.layer_name + '_cast'
node.fluid_code.add_layer( node.fluid_code.add_layer(
'cast', 'cast',
inputs=ends, inputs=ends,
output=ends_cast, output=ends_cast,
param_attr={'dtype': string('int32')}) param_attr={'dtype': string('int32')})
attr['ends'] = ends_cast attr['ends'] = ends_cast
else: else:
starts = node.get_attr('starts') starts = node.get_attr('starts')
ends = node.get_attr('ends') ends = node.get_attr('ends')
...@@ -884,7 +879,12 @@ class OpSet9(): ...@@ -884,7 +879,12 @@ class OpSet9():
ends[idx] = 2**31 - 1 ends[idx] = 2**31 - 1
attr = {"axes": axes, "starts": starts, "ends": ends} attr = {"axes": axes, "starts": starts, "ends": ends}
node.fluid_code.add_layer( if steps is not None:
attr['strides'] = steps
node.fluid_code.add_layer(
'strided_slice', inputs=val_x, output=node, param_attr=attr)
else:
node.fluid_code.add_layer(
'slice', inputs=val_x, output=node, param_attr=attr) 'slice', inputs=val_x, output=node, param_attr=attr)
@print_mapping_info @print_mapping_info
...@@ -907,6 +907,41 @@ class OpSet9(): ...@@ -907,6 +907,41 @@ class OpSet9():
node.fluid_code.add_layer( node.fluid_code.add_layer(
'fill_constant', inputs=None, output=node, param_attr=attr) 'fill_constant', inputs=None, output=node, param_attr=attr)
@print_mapping_info
def Clip(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
max_value, min_value = None, None
if len(node.inputs) == 1:
max_value = node.get_attr('max')
min_value = node.get_attr('min')
attr = {
'max': max_value,
'min': min_value,
}
node.fluid_code.add_layer(
'clip', inputs=val_x, output=node, param_attr=attr)
else:
max_ipt = self.graph.get_input_node(node, idx=1, copy=True)
min_ipt = self.graph.get_input_node(node, idx=2, copy=True)
max_value = _const_weight_or_none(max_ipt)
min_value = _const_weight_or_none(min_ipt)
self.omit_nodes.append(max_ipt.layer_name)
self.omit_nodes.append(min_ipt.layer_name)
if max_value.shape == (1,):
max_value = max_value[0]
if min_value.shape == (1,):
min_value = min_value[0]
if max_value is not None and min_value is not None:
attr = {
'max': max_value,
'min': min_value
}
node.fluid_code.add_layer(
'clip', inputs=val_x, output=node, param_attr=attr)
else:
raise
@print_mapping_info @print_mapping_info
def Split(self, node): def Split(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册