Commit de1648a8 authored by Macrobull

fix bugs in ONNX optimization, op conversion, and empty-list attributes in OpDesc.attrs

Parent f484a779
......@@ -35,6 +35,7 @@ idx = 0
#
#
#model = Model()
#model.eval()
#xb = torch.rand((2, 3, 4))
#yp = model(xb)
#idx += 1
......@@ -56,6 +57,7 @@ idx = 0
#
#
#model = Model()
#model.eval()
#xb = torch.rand((2, 3))
#yp = model(xb)
#idx += 1
......@@ -79,6 +81,7 @@ class Model(nn.Module):
model = Model()
model.eval()
xb = torch.rand((2, 3))
yp = model(xb)
idx += 1
......@@ -105,6 +108,7 @@ class Model(nn.Module):
model = Model()
model.eval()
xb0 = torch.rand((2, 3))
xb1 = torch.rand((2, 3))
ya, yb, yc = model(xb0, xb1)
......@@ -129,6 +133,7 @@ class Model(nn.Module):
model = Model()
model.eval()
theta = torch.rand((2, 2, 3))
grid = model(theta)
idx += 1
......@@ -156,6 +161,7 @@ class Model(nn.Module):
model = Model()
model.eval()
xb = torch.rand((2, 3, 4, 5))
yp = model(xb)
idx += 1
......@@ -185,6 +191,7 @@ class Model(nn.Module):
model = Model()
model.eval()
xb = torch.rand((2, 3, 4, 5))
yp = model(xb)
idx += 1
......@@ -209,6 +216,7 @@ export_onnx_with_validation(
#
#
#model = Model()
#model.eval()
#xb = torch.rand((2, 3, 4, 5))
#yp = model(xb)
#idx += 1
......@@ -229,6 +237,7 @@ class Model(nn.Module):
model = Model()
model.eval()
xb = torch.rand((2, 3))
yp = model(xb)
idx += 1
......
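Note on the hunks above: each sample generator now calls model.eval() before tracing, so dropout and batch-norm run in inference mode and the exported ONNX graph behaves deterministically. A minimal sketch of the effect (the ToyModel module below is illustrative, not part of the repository):

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.fc = nn.Linear(3, 3)
        self.drop = nn.Dropout(p=0.5)  # stochastic in train mode only

    def forward(self, x):
        return self.drop(self.fc(x))

model = ToyModel()
model.eval()  # freeze dropout/batch-norm behavior before export
xb = torch.rand((2, 3))
with torch.no_grad():
    y1 = model(xb)
    y2 = model(xb)
assert torch.allclose(y1, y2)  # identical outputs once eval() is set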
......@@ -77,6 +77,7 @@ def convert(onnx_model_filename,
logger.warning('the ONNX model sanity checking error is suppressed')
logger.warning('value_info inferring may be uncompleted')
# onnx model optimization
logger.info('model has %d ops', len(onnx_model.graph.node))
logger.info('optimizing model ...')
onnx_model = optimize_model_skip_op_for_inference(onnx_model)
onnx_model = optimize_model_strip_initializer(onnx_model)
......@@ -142,7 +143,8 @@ def convert(onnx_model_filename,
raise e
op_codes = fluid_program.codes
fluid_program.codes = []
logger.info('%d ops converted', len(fluid_program.op_descs))
logger.info('%d ops in, %d ops out', len(onnx_graph.node),
len(fluid_program.op_descs))
# weight writer
for name, weight in graph_weights(onnx_graph):
......
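The conversion hunks add an op count before optimization and report both the ONNX node count and the number of emitted fluid op descs, which makes shrinkage from optimization and op fusion visible in the log. A short sketch of the first count (the model path is hypothetical):

import onnx

onnx_model = onnx.load('/path/to/model.onnx')  # hypothetical path
print('model has %d ops' % len(onnx_model.graph.node))  # same count the new log line reports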
......@@ -127,8 +127,10 @@ def node_topo(nodes, topo='default'):
return list(range(len(nodes)))
node_topo = []
node_in_degrees = [len(node.input) for node in nodes]
node_out_degrees = [len(node.output) for node in nodes]
node_in_degrees = [len(set(node.input))
for node in nodes] # merge multiple references
node_out_degrees = [len(set(node.output))
for node in nodes] # merge multiple references
input_refs, output_refs = build_value_refs(nodes)
if topo == 'forward':
......@@ -395,7 +397,7 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
ret_inputs.add().CopyFrom(item)
else:
logger.debug('input %s(%s%s) stripped', name, tensor_dtype(item),
tensor_shape(item))
tuple(tensor_shape(item)))
return ret
......@@ -422,7 +424,7 @@ def optimize_model_cast(model):
attrs = node_attrs(node)
output_dtype = TENSOR_TYPE_TO_NP_TYPE[attrs['to']]
input_name = node.input[0]
info = value_info.get('input_name', None) # relax for un-inferrable
info = value_info.get(input_name, None) # relax for un-inferrable
if info is None:
continue
input_dtype = info.get('dtype', None)
......
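The onnx_utils hunks fix two issues: the topological-sort degrees now deduplicate value names per node, and optimize_model_cast looks up the actual input_name instead of the literal string 'input_name'. A small sketch of why the degree deduplication matters (FakeNode is a stand-in for onnx NodeProto, for illustration only):

class FakeNode(object):
    def __init__(self, inputs, outputs):
        self.input = inputs
        self.output = outputs

# a node such as Mul(x, x) references the same value twice
node = FakeNode(['x', 'x'], ['y'])

in_degree_old = len(node.input)       # 2: waits for 'x' to be resolved twice
in_degree_new = len(set(node.input))  # 1: 'x' has a single producer

# With the old count, resolving 'x' once leaves the degree at 1 and a
# Kahn-style topological sort never schedules this node.
assert in_degree_new == 1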
......@@ -83,7 +83,7 @@ DEFAULT_OP_MAPPING = {
'And': ['logical_and', ['X', 'Y'], ['Out']],
'Div': ['elementwise_div', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Equal': ['equal', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False],
'Greater': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False],
'Greater': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), [1, 0], None, False],
'Less': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False],
'MatMul': ['matmul', ['X', 'Y'], ['Out']], # defaults excluded for transpose_x vs transpose_X
'Max': ['elementwise_max', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
......@@ -444,7 +444,7 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
name_attr = ', name={}'.format(repr(name)) if name else ''
# generation
prog.Code('{}{} = layers.{}({}, exclusive=True'
prog.Code('{} = layers.{}({}, exclusive=True'
', pool_size={}'
', pool_type={}'
', pool_stride={}'
......@@ -452,7 +452,6 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
', ceil_mode={}'
'{})'.format(
var_y,
', {}'.format(var_indices) if has_indices else '',
fluid_op,
var_x,
# attrs
......@@ -529,7 +528,7 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
))
prog.VarDesc(var_y)
if is_max_pool:
var_argmax = _make_var_name(name + '.argmax') # implicit variable
var_argmax = _make_var_name(name + '.argmax') # hidden variable
prog.VarDesc(var_argmax)
prog.OpDesc(
fluid_op,
......@@ -664,7 +663,7 @@ def BatchNormalization(prog,
repr(var_scale), repr(var_b), repr(var_mean),
repr(var_var))
# generationvalue_infos
# generation
prog.Code('{} = layers.{}({}, is_test=True, data_layout="NCHW"'
', momentum={}'
', epsilon={}'
......@@ -804,7 +803,8 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
'using value as 1-D tensor may lead to fails', outputs, val_output)
# generation
if value.size == 1: # scalar
value = value.tolist()
if len(value) == 1: # scalar
value = value[0]
fluid_op = 'fill_constant'
prog.Code('{} = layers.{}(shape={}, dtype={}, value={})'.format(
......@@ -815,7 +815,6 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
repr(dtype.name),
value,
))
value_infos[val_output]['const_value'] = value
prog.VarDesc(var_output)
prog.OpDesc(
fluid_op,
......@@ -823,16 +822,15 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
([var_output], 'Out'),
dict(
shape=shape,
dtype=dtype.name,
dtype=prog.Dtype(dtype),
value=value,
),
)
else: # list parameter -> const_value
prog.Code('# {} = {} # passed directly as literal'.format(
var_output,
value.tolist(),
))
value_infos[val_output]['const_value'] = value.tolist()
var_output, value))
value_infos[val_output]['const_value'] = value
def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
......@@ -1553,7 +1551,7 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
prog.VarDesc(var_output)
prog.OpDesc(
fluid_op,
([var_data], 'X'),
([var_data], 'Input'),
([var_output], 'Out'),
dict(
axes=axes,
......@@ -1615,17 +1613,13 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
prog.Code('# repeats:{}={} # const as literal'.format(var_repeats, repeats))
prog.Code('{} = layers.{}({}'
', expand_times={}'
'{})'
' # {} = {}'.format(
'{})'.format(
var_output,
fluid_op,
var_input,
# attrs
repeats,
name_attr,
# comment
_make_var_name(val_repeats),
repeats,
))
prog.VarDesc(var_output)
prog.OpDesc(
......
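In the op-mapping hunks, Greater is still lowered to fluid's less_than, but the new [1, 0] entry (which I read as an operand permutation in the DEFAULT_OP_MAPPING format) swaps the inputs so the comparison direction is preserved; the Slice hunk likewise renames the op-desc input slot from 'X' to 'Input'. A quick check of the operand-swap identity behind the Greater fix:

import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 2.0, 2.0])

# Greater(a, b) emulated via less_than with swapped operands: a > b  <=>  b < a
assert ((a > b) == (b < a)).all()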
......@@ -54,7 +54,7 @@ def validate(fluid_model_filename,
# load model
fluid_model_dir, basename = os.path.split(fluid_model_filename)
if basename == '__model__': # is desc model
if basename == '__model__': # is desc program
logger.debug('using desc file %s', basename)
prog, _, var_outs = fluid.io.load_inference_model(fluid_model_dir, exe)
out_names = var_outs # HINT: pass var if fetch ops already created
......
......@@ -139,8 +139,10 @@ class Program(object):
elif isinstance(value, str):
od_attr.type = framework_pb2.STRING
od_attr.s = value
elif isinstance(value, list) and len(value) > 0:
if isinstance(value, bool): # bool.mro() = [bool, int, object]
elif isinstance(value, list):
if len(value) > 0:
if isinstance(value,
bool): # bool.mro() = [bool, int, object]
od_attr.type = framework_pb2.BOOLEANS
od_attr.bools.extend(value)
elif isinstance(value[0], int): # only cast to int32 list
......@@ -152,6 +154,16 @@ class Program(object):
elif isinstance(value[0], str):
od_attr.type = framework_pb2.STRINGS
od_attr.strings.extend(value)
else:
raise ValueError('unsupported attribute {} = {}'.format(
key, value))
else: # WORKAROUND: shape of scalars is []
od_attr.type = framework_pb2.INTS
logger.warning('using attribute %s = %s as INTS', key,
value)
else:
raise ValueError('unsupported attribute {} = {}'.format(
key, value))
od_attrs.append(od_attr)
return od_attrs
......
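The writer hunk implements the empty-list workaround named in the commit title: a list attribute is typed by its first element, and an empty list (for example the shape [] of a scalar) falls back to INTS with a warning instead of raising. A sketch of the dispatch rule (the helper below is illustrative, not the project's actual method):

import logging

logger = logging.getLogger(__name__)

def list_attr_type(value):
    if len(value) == 0:
        # WORKAROUND: an empty list carries no element type (e.g. shape=[]
        # for a scalar); fall back to INTS and warn, as the new code does.
        logger.warning('using empty attribute %s as INTS', value)
        return 'INTS'
    if isinstance(value[0], bool):   # bool first: bool.mro() = [bool, int, object]
        return 'BOOLEANS'
    if isinstance(value[0], int):
        return 'INTS'
    if isinstance(value[0], float):
        return 'FLOATS'
    if isinstance(value[0], str):
        return 'STRINGS'
    raise ValueError('unsupported attribute {}'.format(value))

assert list_attr_type([]) == 'INTS'
assert list_attr_type([True, False]) == 'BOOLEANS'
assert list_attr_type([2, 3, 4]) == 'INTS'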