提交 8f9bd9b6 编写于 作者: C Channingss

update Reshape&elementwise_map

上级 9fe56c11
...@@ -350,7 +350,6 @@ class ONNXGraph(Graph): ...@@ -350,7 +350,6 @@ class ONNXGraph(Graph):
node.out_shapes.append(value_info['shape']) node.out_shapes.append(value_info['shape'])
else: else:
node.out_shapes.append([]) node.out_shapes.append([])
print(layer.name, node.out_shapes)
class ONNXDecoder(object): class ONNXDecoder(object):
......
...@@ -40,6 +40,21 @@ def _const_weight_or_none(node): ...@@ -40,6 +40,21 @@ def _const_weight_or_none(node):
return None return None
def _is_static_shape(shape):
negtive_dims = 0
error_dims = 0
for dim in shape:
if dim < 0:
negtive_dims += 1
if dim != -1:
error_dims += 1
if negtive_dims > 1:
return False
if error_dims > 0:
return False
return True
def _get_same_padding(in_size, kernel_size, stride): def _get_same_padding(in_size, kernel_size, stride):
new_size = int(math.ceil(in_size * 1.0 / stride)) new_size = int(math.ceil(in_size * 1.0 / stride))
pad_size = (new_size - 1) * stride + kernel_size - in_size pad_size = (new_size - 1) * stride + kernel_size - in_size
...@@ -230,11 +245,35 @@ class OpSet9(): ...@@ -230,11 +245,35 @@ class OpSet9():
val_y = self.graph.get_input_node(node, idx=1, copy=True) val_y = self.graph.get_input_node(node, idx=1, copy=True)
val_y_shape = val_y.out_shapes[0] val_y_shape = val_y.out_shapes[0]
val_x_shape = val_x.out_shapes[0] val_x_shape = val_x.out_shapes[0]
inputs = {}
if len(val_x_shape) < len(val_y_shape): if len(val_x_shape) < len(val_y_shape):
val_x, val_y = val_y, val_x if node.layer_type in ['Mul', 'Add']:
val_y_shape, val_x_shape = val_x_shape, val_y_shape val_x, val_y = val_y, val_x
val_y_shape, val_x_shape = val_x_shape, val_y_shape
inputs = {'x': val_x, 'y': val_y}
elif node.layer_type in ['Sub', 'Div', 'Pow']:
val_x_expand = val_x.layer_name + '_expand'
x_value = _const_weight_or_none(val_x)
if (val_x_shape == [1] or len(val_x_shape) == 0) and x_value:
attr = {
'shape': val_y_shape,
'dtype': string(val_x.dtype),
'value': x_value
if len(val_x_shape) == 0 else x_value[0]
}
node.fluid_code.add_layer(
'fill_constant',
inputs=None,
output=val_x_expand,
param_attr=attr)
val_x_shape = val_y_shape
inputs = {'x': val_x_expand, 'y': val_y}
else:
assert 'Unsupported situation happened.'
else:
inputs = {'x': val_x, 'y': val_y}
print(node.layer_name)
print(val_x_shape, val_y_shape)
str_y_shape = ','.join(str(e) for e in val_y_shape) str_y_shape = ','.join(str(e) for e in val_y_shape)
str_x_shape = ','.join(str(e) for e in val_x_shape) str_x_shape = ','.join(str(e) for e in val_x_shape)
slice_idx = 0 slice_idx = 0
...@@ -244,7 +283,6 @@ class OpSet9(): ...@@ -244,7 +283,6 @@ class OpSet9():
slice_idx += 1 slice_idx += 1
else: else:
break break
attr = {"name": string(node.layer_name)}
if slice_idx < len(val_y_shape) and slice_idx > 0: if slice_idx < len(val_y_shape) and slice_idx > 0:
val_y_reshaped = val_y_shape[slice_idx:] val_y_reshaped = val_y_shape[slice_idx:]
var_y_reshaped = val_y.layer_name + '_reshaped' var_y_reshaped = val_y.layer_name + '_reshaped'
...@@ -257,13 +295,12 @@ class OpSet9(): ...@@ -257,13 +295,12 @@ class OpSet9():
inputs=val_y, inputs=val_y,
output=var_y_reshaped, output=var_y_reshaped,
param_attr=attr_reshaped) param_attr=attr_reshaped)
inputs = {'x': val_x, 'y': var_y_reshaped} inputs['y'] = var_y_reshaped
node.fluid_code.add_layer( node.fluid_code.add_layer(
op_type, inputs=inputs, output=node, param_attr=attr) op_type, inputs=inputs, output=node, param_attr=None)
else: else:
inputs = {'x': val_x, 'y': val_y}
node.fluid_code.add_layer( node.fluid_code.add_layer(
op_type, inputs=inputs, output=node, param_attr=attr) op_type, inputs=inputs, output=node, param_attr=None)
@print_mapping_info @print_mapping_info
def place_holder(self, node): def place_holder(self, node):
...@@ -941,7 +978,8 @@ class OpSet9(): ...@@ -941,7 +978,8 @@ class OpSet9():
inputs={'x': val_x}, inputs={'x': val_x},
output=node, output=node,
param_attr={'shape': shape_value.tolist()}) param_attr={'shape': shape_value.tolist()})
elif len(node.out_shapes[0]) > 0: elif len(node.out_shapes[0]) > 0 and _is_static_shape(node.out_shapes[
0]):
node.fluid_code.add_layer( node.fluid_code.add_layer(
'reshape', 'reshape',
inputs={'x': val_x, inputs={'x': val_x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册