Commit 826481c4 authored by Macrobull

add type shape inference

Parent 7c3e9379
......@@ -14,14 +14,14 @@ from collections import OrderedDict as Dict
def make_var_name(name):
"""
make a valid variable name in Python code
"""
if name == '':
return '_'
if name[0].isdigit():
return 'var_' + name
for s in ' \\|/:': #
for s in ' \\|/:-': #
name = name.replace(s, '_')
if name.startswith('_'):
name = 'var' + name
......
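For reference, a minimal sketch of the sanitizer after this change (the `-` character is newly mapped to `_`; the demo names are hypothetical):

```python
def make_var_name(name):
    """make a valid variable name in Python code"""
    if name == '':
        return '_'
    if name[0].isdigit():
        return 'var_' + name
    for s in ' \\|/:-':  # '-' newly handled by this commit
        name = name.replace(s, '_')
    if name.startswith('_'):
        name = 'var' + name
    return name

assert make_var_name('fc6-1') == 'fc6_1'      # dashed ONNX names become legal
assert make_var_name('0_out') == 'var_0_out'  # digit-first names get a prefix
```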
......@@ -17,14 +17,14 @@ from glob import glob
def make_var_name(name):
"""
make a valid variable name in Python code
"""
if name == '':
return '_'
if name[0].isdigit():
return 'var_' + name
for s in ' \\|/:': #
for s in ' \\|/:-': #
name = name.replace(s, '_')
if name.startswith('_'):
name = 'var' + name
......
......@@ -311,7 +311,7 @@ resnet100_arcface()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid -o /tmp/export/ "$fn_model" -y
python -m onnx2fluid $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar"/*/
do
echo "converting $pb_dir ..."
......
......@@ -95,6 +95,14 @@ parser.add_argument(
default=1e-2,
help='assertion relative tolerance for validation',
)
parser.add_argument(
'--infer_inputs',
'-i',
nargs='?',
default=None,
const='',
help='perform type-shape inference with given input names and re-save model',
)
args = parser.parse_args()
logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
......
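The `nargs='?'` / `const=''` combination gives the flag three states; a small self-contained demo of the semantics (argument names as in the hunk above):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--infer_inputs', '-i', nargs='?', default=None, const='')

print(parser.parse_args([]).infer_inputs)             # None  -> no inference
print(parser.parse_args(['-i']).infer_inputs)         # ''    -> infer, names from model
print(parser.parse_args(['-i', 'x,y']).infer_inputs)  # 'x,y' -> infer with given names
```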
......@@ -60,19 +60,28 @@ def main(**kwargs):
# validate
passed = True
golden_data_filename = kwargs.pop('test_data', '')
if golden_data_filename:
infer_inputs = kwargs.pop('infer_inputs', None)
if golden_data_filename or infer_inputs:
from .validation import validate
save_inference_model = infer_inputs is not None
inference_input_names = infer_inputs.split(
',') if infer_inputs else None
logger.info('starting validation on desc ...')
passed &= validate(shutil.os.path.join(save_dir, '__model__'),
golden_data_filename, **kwargs)
golden_data_filename=golden_data_filename,
save_inference_model=save_inference_model,
inference_input_names=inference_input_names,
**kwargs)
logger.info('starting validation on code ...')
# this re-generates the desc proto from Python code when debug is on
passed &= validate(shutil.os.path.join(save_dir, model_basename),
golden_data_filename,
golden_data_filename=golden_data_filename,
model_func_name=model_func_name,
save_inference_model=debug,
save_inference_model=save_inference_model,
inference_input_names=inference_input_names,
**kwargs)
if not passed:
......
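How those three flag states map to the `validate()` arguments above (this just traces the code in the hunk):

```python
# bare `-i` (const=''):
infer_inputs = ''
save_inference_model = infer_inputs is not None                            # True
inference_input_names = infer_inputs.split(',') if infer_inputs else None  # None

# `-i data,label`:
infer_inputs = 'data,label'
inference_input_names = infer_inputs.split(',')  # ['data', 'label']
```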
......@@ -27,8 +27,8 @@ def convert(onnx_model_filename,
debug=False,
**kwargs):
"""
convert an ONNX model to Paddle fluid Python code and desc pb
"""
import onnx
......@@ -141,23 +141,22 @@ def convert(onnx_model_filename,
logger.info('%d ops in, %d ops out', len(onnx_graph.node),
len(fluid_program.op_descs))
# shape-inference
# type-shape inference
for name, value_info in graph_value_infos.items():
var_name = make_var_name(name)
fluid_program.VarTypeInfo(var_name, value_info,
remove_batch=False) # shape-infer only
fluid_program.VarTypeShapeInfo(var_name, value_info,
remove_batch=False) # shape-infer only
bad_var_names = []
for var_name, var_desc in fluid_program.var_descs.items():
if not var_desc.type.lod_tensor.HasField('tensor'):
bad_var_names.append(var_name)
if len(bad_var_names) > 0:
logger.warning('type info not inferred for var %s ...',
logger.warning('type-shape not inferred for var %s ...',
', '.join(bad_var_names[:5]))
logger.warning('this is seldom a problem for PaddlePaddle, '
'but Paddle Mobile may not infer correctly')
logger.warning(
'please consider adding option -d to invoke PaddlePaddle shape-inference'
)
logger.warning('please consider running onnx2fluid.validation with -i '
'to invoke PaddlePaddle type-shape inference')
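The detection above relies on a protobuf field check: an un-inferred LOD_TENSOR var simply lacks the `tensor` sub-message. A sketch, assuming the vendored `framework_pb2` module used elsewhere in this package:

```python
from onnx2fluid import framework_pb2  # assumed import path for the vendored proto

var_desc = framework_pb2.VarDesc(name='x')
var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
print(var_desc.type.lod_tensor.HasField('tensor'))  # False: no dtype/shape yet

var_desc.type.lod_tensor.tensor.data_type = framework_pb2.VarType.FP32
var_desc.type.lod_tensor.tensor.dims.extend([-1, 3, 224, 224])
print(var_desc.type.lod_tensor.HasField('tensor'))  # True: type-shape present
```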
# weight writer
for name, weight in graph_weights(onnx_graph):
......@@ -233,13 +232,9 @@ def convert(onnx_model_filename,
logger.info('conversion finished')
if __name__ == '__main__':
del convert
def main():
import argparse
from onnx2fluid.conversion import convert
parser = argparse.ArgumentParser(
description='onnx2fluid.convert',
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
......@@ -310,3 +305,11 @@ if __name__ == '__main__':
onnx_opset_pedantic=pedantic,
onnx_skip_version_conversion=skip_version_conversion,
debug=debug)
if __name__ == '__main__':
del convert
from onnx2fluid.conversion import convert
main()
......@@ -44,8 +44,8 @@ DEFAULT_OP_DOMAIN = 'ai.onnx'
def print_pb_structure(message, loop_iterative=False, depth=0):
"""
print pb fields in its structure
"""
if hasattr(message, 'DESCRIPTOR') and hasattr(message.DESCRIPTOR, 'fields'):
for field in message.DESCRIPTOR.fields:
......@@ -65,8 +65,8 @@ def print_pb_structure(message, loop_iterative=False, depth=0):
def build_value_refs(nodes):
"""
build op reference of inputs and outputs
"""
input_refs = Dict()
output_refs = Dict()
......@@ -80,8 +80,8 @@ def build_value_refs(nodes):
def get_attribute_value2(attr):
"""
get_attribute_value enhanced
"""
if attr.type == onnx.AttributeProto.TENSOR:
dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type])
......@@ -99,24 +99,24 @@ def get_attribute_value2(attr):
def tensor_dtype(tensor):
"""
get ONNX tensor in np.dtype
"""
return TENSOR_TYPE_TO_NP_TYPE[tensor.type.tensor_type.elem_type]
def tensor_shape(tensor):
"""
get ONNX tensor shape
"""
return [dim.dim_value for dim in tensor.type.tensor_type.shape.dim]
def node_attrs(node):
"""
convert ONNX node attributes to dict
"""
return {attr.name: get_attribute_value2(attr)
for attr in node.attribute} # dict
......@@ -124,8 +124,8 @@ def node_attrs(node):
def node_topo(nodes, topo='default'):
"""
build indices with given topology to an ONNX node graph
"""
if topo == 'default':
return list(range(len(nodes)))
......@@ -192,8 +192,8 @@ def node_topo(nodes, topo='default'):
def node_iter(nodes, indices=None):
"""
generator for ONNX node graph with given indices
"""
if indices is None:
indices = range(len(nodes))
......@@ -210,7 +210,7 @@ def node_iter(nodes, indices=None):
if name == '':
name = 'op_' + str(index)
else: # make_op_name
for s in ' \\|/:': #
for s in ' \\|/:-': #
name = name.replace(s, '_')
if domain == '':
domain = DEFAULT_OP_DOMAIN
......@@ -220,8 +220,8 @@ def node_iter(nodes, indices=None):
def graph_ops(graph, topo='default'):
"""
generator for ONNX node graph with given topology
"""
if not isinstance(graph, onnx.GraphProto):
logger.error('graph is not a GraphProto instance')
......@@ -232,8 +232,8 @@ def graph_ops(graph, topo='default'):
def graph_weights(graph):
"""
generator for weights of an ONNX model
"""
if not isinstance(graph, onnx.GraphProto):
logger.error('graph is not a GraphProto instance')
......@@ -247,8 +247,8 @@ def graph_weights(graph):
def inferred_model_value_info(model):
"""
collect value/type info for an ONNX model
"""
model = infer_shapes(model)
graph = model.graph
......@@ -278,8 +278,8 @@ def inferred_model_value_info(model):
def skip_node_forward(nodes, src_output_name, dst_input_name, input_refs):
"""
skip nodes between src_output_name -> dst_input_name and connect this pair
"""
processed = 0
for next_idx in input_refs[src_output_name]:
......@@ -293,8 +293,8 @@ def skip_node_forward(nodes, src_output_name, dst_input_name, input_refs):
def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
"""
skip nodes between dst_output_name -> src_input_name and connect this pair
"""
processed = 0
for prev_idx in output_refs[src_input_name]:
......@@ -308,8 +308,8 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
def optimize_model_skip_op_for_inference(model, op_list=None):
"""
skip ops that can be bypassed for inference
"""
if op_list is None:
op_list = ('Dropout', 'Identity')
......@@ -369,8 +369,8 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
def optimize_model_strip_initializer(model, keep_input_only=True):
"""
strip weights for inference
"""
nodes = model.graph.node
input_refs, output_refs = build_value_refs(nodes)
......@@ -410,8 +410,8 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
def optimize_model_cast(model):
"""
strip cascaded and unnecessary onnx::Cast-9 ops
"""
nodes = model.graph.node
input_refs, output_refs = build_value_refs(nodes)
......@@ -467,8 +467,8 @@ def optimize_model_cast(model):
def optimize_model_slice(model):
"""
strip cascaded and unnecessary onnx::Slice-1:9 ops
"""
nodes = model.graph.node
input_refs, output_refs = build_value_refs(nodes)
......
......@@ -41,74 +41,74 @@ DEFAULT_OP_MAPPING_FIELD_VALUES['FILL_NAME_FIELD'] = True
DEFAULT_OP_MAPPING_VALUES = list(DEFAULT_OP_MAPPING_FIELD_VALUES.values())
DEFAULT_OP_MAPPING = {
## nil ops ##
'RandomUniform':
['uniform_random', [], ['Out'], dict(high='max', low='min'),
dict(), None, None, False],
'RandomNormal':
['gaussian_random', [], ['Out'], dict(scale='std'),
dict(), None, None, False],
## unary ops ##
'Abs': ['abs', ['X'], ['Out']],
'ArgMax': ['argmax', ['X'], ['Out'], dict(keepdims='')],
'ArgMin': ['argmin', ['X'], ['Out'], dict(keepdims='')],
'Ceil': ['ceil', ['X'], ['Out']],
'Clip': ['clip', ['X'], ['Out']], # attrs bypassed
'Cos': ['cos', ['X'], ['Out']],
'Elu': ['elu', ['X'], ['Out']],
'Exp': ['exp', ['X'], ['Out']],
'Flatten': ['flatten', ['X'], ['Out']], # attrs bypassed, FIXME: emit flatten2
'Floor': ['floor', ['X'], ['Out']],
'Gather': ['gather', ['X'], ['Out'], dict(axis='')],
'LeakyRelu': ['leaky_relu', ['X'], ['Out']],
'Log': ['log', ['X'], ['Out']],
'LRN': ['lrn', ['X'], ['Out', 'MidOut'], dict(size='n', bias='k')], #
'Reciprocal': ['reciprocal', ['X'], ['Out']],
'Relu': ['relu', ['X'], ['Out']],
'Selu': ['selu', ['X'], ['Out'], dict(gamma='scale')],
'Shape': ['shape', ['X'], ['Out']], # FIXME: out is int64 vs int32
'Shrink': ['softshrink', ['X'], ['Out'], dict(bias='', lambd='')],
'Sigmoid': ['sigmoid', ['X'], ['Out']],
'Sin': ['sin', ['X'], ['Out']],
'Squeeze': ['squeeze', ['X'], ['Out']], # attrs bypassed, FIXME: emit squeeze2
'Softplus': ['softplus', ['X'], ['Out']],
# FIXME: default axis = -1, reshape required before and after
'Softmax': ['softmax', ['X'], ['Out'], dict(axis='')],
'Softsign': ['softsign', ['X'], ['Out']],
'Sqrt': ['sqrt', ['X'], ['Out']],
'Tanh': ['tanh', ['X'], ['Out']],
'ThresholdedRelu': ['thresholded_relu', ['X'], ['Out'], dict(alpha='threshold')],
#'Transpose': ['transpose', ['X'], ['Out']],
'Unsqueeze': ['unsqueeze', ['X'], ['Out']], # attrs bypassed, FIXME: emit unsqueeze2
## binary ops ##
'Add': ['elementwise_add', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
#'AffineGrid': ['affine_grid', ['Theta'], ['Output'], dict(size='out_shape')],
'And': ['logical_and', ['X', 'Y'], ['Out']],
'Div': ['elementwise_div', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Equal': ['equal', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False],
'Greater': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), [1, 0], None, False],
'Less': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False],
'MatMul': ['matmul', ['X', 'Y'], ['Out']], # defaults excluded for transpose_x vs transpose_X
'Max': ['elementwise_max', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Min': ['elementwise_min', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Mul': ['elementwise_mul', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Not': ['logical_not', ['X'], ['Out']],
'OneHot': # assuming values=[0, 1], axis=-1 and drop them
['one_hot', ['Input', 'Depth'], ['Out'], dict(axis=''), dict(),
[0, 1], None, False],
'Or': ['logical_or', ['X', 'Y'], ['Out']],
'Pow': ['elementwise_pow', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], # TODO: pow for scalar exponent
'Sub': ['elementwise_sub', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Xor': ['logical_xor', ['X', 'Y'], ['Out']],
# reduce ops
'ReduceMax': ['reduce_max', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')],
'ReduceMean': ['reduce_mean', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')],
'ReduceMin': ['reduce_min', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')],
'ReduceProd': ['reduce_prod', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')],
'ReduceSum': ['reduce_sum', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')],
# other ops
'Scatter': ['scatter', ['X', 'Index', 'Updates'], ['Out']],
'TopK': ['topk', ['X', 'K'], ['Out', 'Indices']],
}
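Each row above is positional and is padded to full length with `DEFAULT_OP_MAPPING_VALUES` before unpacking (see `_default` below). A hedged reading of the `Greater` row, with field names inferred from `DEFAULT_OP_MAPPING_FIELD_VALUES`:

```python
info = list(DEFAULT_OP_MAPPING['Greater'])
info.extend(DEFAULT_OP_MAPPING_VALUES[len(info):])  # pad short rows with defaults
(fluid_op, fluid_inputs, fluid_outputs, attr_mapping,
 default_attrs, input_perm, output_perm, fill_name_field) = info
# -> 'less_than', ['X', 'Y'], ['Out'], {}, {}, [1, 0], None, False
# input_perm [1, 0] swaps the operands, so Greater(a, b) maps to less_than(b, a)
```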
DEFAULT_IOA_CONSTRAINTS = {
......@@ -146,14 +146,14 @@ DEFAULT_IOA_CONSTRAINTS = {
def _make_var_name(name):
"""
make a valid variable name in Python code and in filesystem
"""
if name == '':
return '_'
if name[0].isdigit():
return 'var_' + name
for s in ' \\|/:': #
for s in ' \\|/:-': #
name = name.replace(s, '_')
if name.startswith('_'):
name = 'var' + name
......@@ -191,14 +191,24 @@ def _const_weight_or_none(value_infos, val_name):
return None
value_info = value_infos[val_name]
const_value = value_info.get('const_value', None)
if const_value:
if const_value is not None:
return const_value
get_weight_func = value_info.get('get_weight', None)
if get_weight_func:
if get_weight_func is not None:
return get_weight_func()
return None
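The switch from truthiness to `is not None` matters because constants here are numpy arrays: a zero-valued scalar is falsy, and a multi-element array cannot be coerced to bool at all.

```python
import numpy as np

zero = np.array(0, dtype=np.int64)  # a genuine constant that happens to be 0
print(bool(zero))                   # False -> `if const_value:` would drop it
print(zero is not None)             # True  -> the fixed check keeps it
# bool(np.zeros([4])) would even raise ValueError: truth value is ambiguous
```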
def _check_embeddable(value_infos, *val_names):
keyword = 'get_weight'
for val_name in val_names:
if keyword not in value_infos[val_name]:
_logger.warning('parameter %s not embeddable for some ops',
val_name)
return False
return True
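A sketch of how this gate behaves; per `_const_weight_or_none` above, weight entries in `value_infos` carry a `'get_weight'` callback (the entries here are hypothetical):

```python
value_infos = {
    'conv1_w': {'get_weight': lambda: None},  # backed by an initializer: embeddable
    'runtime_scale': {},                      # produced by another op: not embeddable
}
print(_check_embeddable(value_infos, 'conv1_w'))                   # True
print(_check_embeddable(value_infos, 'conv1_w', 'runtime_scale'))  # False, with warning
```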
def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
info = DEFAULT_OP_MAPPING[op_type]
info.extend(DEFAULT_OP_MAPPING_VALUES[len(info):])
......@@ -391,9 +401,9 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
input_shape = _shape_or_none(value_infos, val_x)
output_shape = _shape_or_none(value_infos, val_y)
assert input_shape is not None or output_shape is not None, 'poolnd not inferred' # NC...
if input_shape:
if input_shape is not None:
poolnd = len(input_shape) - 2 # NC...
elif output_shape:
elif output_shape is not None:
poolnd = len(output_shape) - 2 # NC...
assert 2 <= poolnd <= 3, 'only pool2d and pool3d are supported'
......@@ -568,7 +578,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
1] == 1, 'only scale on (NC)HW supported'
assert scales[2] == scales[
3], 'only aspect-ratio-invariant scale supported'
scale = scales[2] if scales else None
scale = None if scales is None else scales[2]
# try input shape
if scale is None:
assert out_shape_, 'neither scales nor output shape is available'
......@@ -613,24 +623,24 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
def AdaptiveAveragePool(prog, inputs, outputs, attrs, *args, name='', **kwargs):
"""
aten::adaptive_avg_poolnd
"""
return _adaptive_pool(prog, 'avg', inputs, outputs, attrs, name=name)
def AdaptiveMaxPool(prog, inputs, outputs, attrs, *args, name='', **kwargs):
"""
aten::adaptive_max_poolnd
"""
return _adaptive_pool(prog, 'max', inputs, outputs, attrs, name=name)
def AffineGrid(prog, inputs, outputs, attrs, *args, name='', **kwargs):
"""
aten::affine_grid
"""
# I/O
val_theta, = inputs
......@@ -672,8 +682,8 @@ def AveragePool(prog,
*args,
**kwargs):
"""
onnx::AveragePool-10:
"""
return _pool(prog, 'avg', inputs, outputs, attrs, value_infos, name=name)
......@@ -688,8 +698,8 @@ def BatchNormalization(prog,
*args,
**kwargs):
"""
onnx::BatchNormalization-9:
"""
# I/O
val_x, val_scale, val_b, val_mean, val_var = inputs
......@@ -704,16 +714,19 @@ def BatchNormalization(prog,
momentum = attrs.get('momentum', .9) # optional
epsilon = attrs.get('epsilon', 1e-5) # optional
name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params:
embed_params = _check_embeddable(value_infos, val_scale, val_b,
val_mean, val_var)
if embed_params:
assert name != ''
var_scale = name + '.w_0'
var_b = name + '.b_0'
var_mean = name + '.w_1'
var_var = name + '.w_2'
value_infos[val_scale].setdefault('embeded_as', []).append(var_scale)
value_infos[val_b].setdefault('embeded_as', []).append(var_b)
value_infos[val_mean].setdefault('embeded_as', []).append(var_mean)
value_infos[val_var].setdefault('embeded_as', []).append(var_var)
value_infos[val_scale]['embeded_as'].append(var_scale)
value_infos[val_b]['embeded_as'].append(var_b)
value_infos[val_mean]['embeded_as'].append(var_mean)
value_infos[val_var]['embeded_as'].append(var_var)
param_attr = ''
else:
var_scale = _make_var_name(val_scale)
......@@ -760,8 +773,8 @@ def BatchNormalization(prog,
def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
"""
onnx::Cast-9:
"""
# I/O
val_input, = inputs
......@@ -774,7 +787,7 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
if not isinstance(dtype, _np.dtype): # additional: possible np.dtype
dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]
output_dtype = _dtype_or_none(value_infos, val_output)
if output_dtype:
if output_dtype is not None:
assert dtype == output_dtype, 'dtype of "to" does not match output dtype'
fluid_op = 'cast'
......@@ -804,8 +817,8 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs):
"""
onnx::Concat-4:
"""
# I/O
val_concat_result, = outputs
......@@ -839,11 +852,11 @@ def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs):
def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
"""
onnx::Constant-9:
"""
# I/O
assert len(inputs) == 0
assert len(inputs) == 0, 'constant op accepts no inputs'
val_output, = outputs
var_output = _make_var_name(val_output)
......@@ -851,7 +864,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
value = attrs['value'] # required
dtype = _np.dtype(value.dtype)
output_dtype = _dtype_or_none(value_infos, val_output)
if output_dtype:
if output_dtype is not None:
assert dtype == output_dtype, 'tensor dtype does not match storage dtype'
......@@ -900,8 +913,8 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
"""
onnx::ConstantOfShape-9:
"""
# I/O
val_shape, = inputs
......@@ -939,8 +952,8 @@ def Conv(prog,
*args,
**kwargs):
"""
onnx::Conv-1:
"""
# I/O
val_x, val_w = inputs[:2]
......@@ -970,13 +983,16 @@ def Conv(prog,
paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
var_x = _make_var_name(val_x)
name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params:
embed_params = (_check_embeddable(value_infos, val_w)
                and (not has_bias or _check_embeddable(value_infos, val_b)))
if embed_params:
assert name != ''
var_w = name + '.w_0'
value_infos[val_w].setdefault('embeded_as', []).append(var_w)
value_infos[val_w]['embeded_as'].append(var_w)
if has_bias:
var_b = name + '.b_0'
value_infos[val_b].setdefault('embeded_as', []).append(var_b)
value_infos[val_b]['embeded_as'].append(var_b)
param_attr = ''
else:
param_attr = ', bias_attr=False'
......@@ -1046,8 +1062,8 @@ def ConvTranspose(prog,
*args,
**kwargs):
"""
onnx::ConvTranspose-1:
"""
# I/O
val_x, val_w = inputs[:2]
......@@ -1080,13 +1096,16 @@ def ConvTranspose(prog,
paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
var_x = _make_var_name(val_x)
name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params:
embed_params = (_check_embeddable(value_infos, val_w)
                and (not has_bias or _check_embeddable(value_infos, val_b)))
if embed_params:
assert name != ''
var_w = name + '.w_0'
value_infos[val_w].setdefault('embeded_as', []).append(var_w)
value_infos[val_w]['embeded_as'].append(var_w)
if has_bias:
var_b = name + '.b_0'
value_infos[val_b].setdefault('embeded_as', []).append(var_b)
value_infos[val_b]['embeded_as'].append(var_b)
param_attr = ''
else:
param_attr = ', bias_attr=False'
......@@ -1167,8 +1186,8 @@ def ConvTranspose(prog,
def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
"""
onnx::Gemm-9:
"""
# due to fluid fc don't support transposed weight, we use matmul + ew_add
val_a, val_b, val_c = inputs
......@@ -1259,8 +1278,8 @@ def GlobalAveragePool(prog,
*args,
**kwargs):
"""
onnx::GlobalAveragePool-1:
"""
return _global_pool(prog,
'avg',
......@@ -1280,8 +1299,8 @@ def GlobalMaxPool(prog,
*args,
**kwargs):
"""
onnx::GlobalMaxPool-1:
"""
return _global_pool(prog,
'max',
......@@ -1295,8 +1314,8 @@ def GlobalMaxPool(prog,
def MaxPool(prog, inputs, outputs, attrs, value_infos, name='', *args,
**kwargs):
"""
onnx::MaxPool-10:
"""
return _pool(prog, 'max', inputs, outputs, attrs, value_infos, name=name)
......@@ -1304,16 +1323,16 @@ def MaxPool(prog, inputs, outputs, attrs, value_infos, name='', *args,
def MaxRoiPool(prog, inputs, outputs, attrs, value_infos, name, *args,
**kwargs):
"""
onnx::MaxRoiPool-1:
"""
_roi_pool(prog, 'roi_pool', inputs, outputs, attrs, value_infos, name)
def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
"""
onnx::Pad-2:
"""
# I/O
val_data, = inputs
......@@ -1330,9 +1349,9 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
assume_pad2d = False
if len(pads) == 4:
assume_pad2d |= mode != 'constant'
if data_shape:
if data_shape is not None:
assume_pad2d |= data_shape and len(data_shape) == 4 # NCHW
if output_shape:
if output_shape is not None:
assume_pad2d |= output_shape and len(output_shape) == 4 # NCHW
od_attrs = {'pad_value': value}
if assume_pad2d:
......@@ -1383,8 +1402,8 @@ def PRelu(prog,
*args,
**kwargs):
"""
onnx::PRelu-9:
"""
# I/O
val_x, val_slope = inputs
......@@ -1404,10 +1423,12 @@ def PRelu(prog,
mode = 'element'
fluid_op = 'prelu'
name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params:
embed_params = _check_embeddable(value_infos, val_slope)
if embed_params:
assert name != ''
var_slope = name + '.w_0'
value_infos[val_slope].setdefault('embeded_as', []).append(var_slope)
value_infos[val_slope]['embeded_as'].append(var_slope)
param_attr = ''
else:
var_slope = _make_var_name(val_slope)
......@@ -1436,16 +1457,16 @@ def PRelu(prog,
def PsRoiPool(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
"""
caffe2::PsRoiPool
"""
_roi_pool(prog, 'psroi_pool', inputs, outputs, attrs, value_infos, name)
def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
"""
onnx::Reshape-5:
"""
# I/O
val_data, val_shape = inputs
......@@ -1474,6 +1495,8 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
name_attr = ', name={}'.format(repr(name)) if name else ''
# generation
val_shape_int32 = val_shape + '_int32' # explicit variable
var_shape_int32 = _make_var_name(val_shape_int32)
prog.Code('# shape:{}={} # const as literal'.format(var_shape, shape))
if is_const_shape:
prog.Code('{} = layers.{}({}'
......@@ -1487,8 +1510,6 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
name_attr,
))
else:
val_shape_int32 = val_shape + '_int32' # explicit variable
var_shape_int32 = _make_var_name(val_shape_int32)
prog.Op(
'',
'Cast',
......@@ -1514,34 +1535,26 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
var_xshape = name + '.xshape' # dummy output
prog.VarDesc(var_reshaped)
prog.VarDesc(var_xshape)
if is_const_shape:
prog.OpDesc(
fluid_op,
([var_data], 'X'),
([var_reshaped, var_xshape], 'Out', 'XShape'),
{'shape': shape},
)
else:
prog.OpDesc(
fluid_op,
([var_data, var_shape_int32], 'X', 'Shape'),
([var_reshaped, var_xshape], 'Out', 'XShape'),
{'shape': shape},
)
prog.OpDesc(
fluid_op,
([var_data, var_shape_int32], 'X', 'Shape'),
([var_reshaped, var_xshape], 'Out', 'XShape'),
{'shape': shape},
)
def Resize(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
"""
onnx::Resize-10:
"""
return _interpolate(prog, inputs, outputs, attrs, value_infos, name=name)
def RoiAlign(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
"""
caffe2::RoiAlign
"""
_roi_pool(prog, 'roi_align', inputs, outputs, attrs, value_infos, name)
......@@ -1580,8 +1593,8 @@ def RoiAlign(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
"""
onnx::Slice-1:9
"""
# I/O
val_data, = inputs
......@@ -1595,7 +1608,7 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
starts = attrs['starts'] # required
ends = attrs['ends'] # required
shape = _shape_or_none(value_infos, val_data)
if shape:
if shape is not None:
# ndims = len(shape)
# for idx, value in enumerate(axes):
# if value > ONNX_INT_MAX // 2:
......@@ -1639,8 +1652,8 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
"""
onnx::Split-2:
"""
# I/O
val_input, = inputs
......@@ -1680,8 +1693,8 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
def Sum(prog, inputs, outputs, *args, **kwargs):
"""
onnx::Sum-8:
"""
# I/O
val_sum, = outputs
......@@ -1710,8 +1723,8 @@ def Sum(prog, inputs, outputs, *args, **kwargs):
def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
"""
onnx::Tile-1:
"""
# I/O
val_input, val_repeats = inputs
......@@ -1749,8 +1762,8 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
def Transpose(prog, inputs, outputs, attrs, *args, name='', **kwargs):
"""
onnx::Transpose-1:
"""
# I/O
val_data, = inputs
......@@ -1795,8 +1808,8 @@ def Upsample(prog,
*args,
**kwargs):
"""
onnx::Upsample-9:9
"""
return _interpolate(prog, inputs, outputs, attrs, value_infos, name=name)
......
......@@ -25,7 +25,8 @@ def ensure_tuple(obj):
def flatten_list(obj, out=None):
assert isinstance(obj, list)
assert isinstance(obj, list), 'list type required'
if out is None:
out = type(obj)()
for item in obj:
......@@ -38,11 +39,11 @@ def flatten_list(obj, out=None):
def export_data(state_dict, prefix=''):
"""
export binary data with meta text for raw C++ inference engines
"""
def str_(obj):
if isinstance(obj, (tuple, list)):
if isinstance(obj, (tuple, list, set)):
return str(obj)[1:-1].replace(' ', '')
return str(obj)
......@@ -72,8 +73,8 @@ def export_onnx_with_validation(model,
*args,
**kwargs):
"""
export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file
"""
is_tuple_or_list = lambda x: isinstance(x, (tuple, list))
......
......@@ -10,14 +10,15 @@ import importlib, logging, os, sys
def flatten_dict(obj, out=None):
assert isinstance(obj, dict)
assert isinstance(obj, dict), 'dict type required'
if out is None:
out = type(obj)()
for key, value in obj.items():
if isinstance(value, dict):
flatten_dict(value, out)
else:
assert key not in out
assert key not in out, 'key conflicted'
out[key] = value
return out
......@@ -29,15 +30,16 @@ def ensure_list(obj):
def validate(fluid_model_filename,
golden_data_filename,
model_func_name='inference',
golden_data_filename='',
atol=1e-3,
rtol=1e-3,
model_func_name='inference',
save_inference_model=False,
inference_input_names=None,
**kwargs):
"""
run inference on the converted Paddle fluid model, validating against golden data when given
"""
import numpy as np
import paddle.fluid as fluid
......@@ -86,24 +88,50 @@ def validate(fluid_model_filename,
raise ValueError('unsupported Paddle fluid model filename')
# load data
logger.info('using golden data %s', golden_data_filename)
if golden_data_filename.endswith('.npz'):
test_data = np.load(golden_data_filename, encoding='bytes')
input_data = test_data['inputs'].tolist()
output_data = test_data['outputs'].tolist()
if golden_data_filename:
logger.info('using golden data %s', golden_data_filename)
if golden_data_filename.endswith('.npz'):
test_data = np.load(golden_data_filename, encoding='bytes')
input_data = test_data['inputs'].tolist()
output_data = test_data['outputs'].tolist()
else:
test_data = np.load(golden_data_filename, encoding='bytes').tolist()
input_data = test_data['inputs']
output_data = test_data['outputs']
input_data = flatten_dict(input_data)
output_data = flatten_dict(output_data)
input_names = input_data.keys()
logger.info('found %d I/O golden data, starting test ...',
len(input_data) + len(output_data))
else:
test_data = np.load(golden_data_filename, encoding='bytes').tolist()
input_data = test_data['inputs']
output_data = test_data['outputs']
input_data = flatten_dict(input_data)
output_data = flatten_dict(output_data)
logger.info('found %d I/O golden data, starting test ...',
len(input_data) + len(output_data))
# DEBUG: reload test for Python code
if basename.endswith('.py') and save_inference_model:
assert inference_input_names, 'input names required for type-shape inference'
input_names = inference_input_names
logger.info('using input names: %s', ', '.join(input_names))
# type-shape inference and re-save
if save_inference_model:
for block in prog.blocks:
block_desc = block.desc
for idx_op in range(block_desc.op_size()):
op_desc = block_desc.op(idx_op)
if op_desc.type() in ('feed', 'fetch'):
continue
op_desc.infer_var_type(block_desc)
op_desc.infer_shape(block_desc)
for var_name, var in block.vars.items():
var_desc = var.desc
if var_desc.type() != fluid.core.VarDesc.VarType.LOD_TENSOR:
continue
# WORKAROUND: a dirty way to give a dtype to partially-inferred vars,
# which cannot be cleared otherwise!
try:
var.to_string(True)
except ValueError:
var_desc.set_dtype(fluid.core.VarDesc.VarType.FP32)
fluid.io.save_inference_model(fluid_model_dir,
input_data.keys(),
input_names,
var_outs,
exe,
main_program=prog,
......@@ -112,8 +140,12 @@ def validate(fluid_model_filename,
fluid.io.load_inference_model(fluid_model_dir, exe)
logger.info('model re-load passed')
if not golden_data_filename:
return True
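For context, the same type-shape pass as above applied to a minimal hand-built program; a sketch against the paddle.fluid 1.x API this commit targets:

```python
import paddle.fluid as fluid

prog = fluid.Program()
with fluid.program_guard(prog):
    x = fluid.layers.data(name='x', shape=[3, 8], dtype='float32')
    y = fluid.layers.relu(x)

block_desc = prog.blocks[0].desc
for idx_op in range(block_desc.op_size()):
    op_desc = block_desc.op(idx_op)
    if op_desc.type() in ('feed', 'fetch'):
        continue
    op_desc.infer_var_type(block_desc)  # propagate dtype through the op
    op_desc.infer_shape(block_desc)     # propagate shape through the op
```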
# execute
outputs = exe.run(prog, feed=input_data, fetch_list=out_names)
outputs = exe.run(prog, feed=input_data,
fetch_list=out_names) # out_names can be vars
logger.info('execution passed')
# validate
......@@ -134,11 +166,10 @@ def validate(fluid_model_filename,
logger.info('accuracy passed')
else:
logger.info('accuracy not passed')
return passed
if __name__ == '__main__':
def main():
import argparse
parser = argparse.ArgumentParser(
......@@ -160,6 +191,7 @@ if __name__ == '__main__':
'--test_data',
'-t',
type=str,
default='',
help='I/O golden data for validation, e.g. test.npy, test.npz',
)
parser.add_argument(
......@@ -175,19 +207,36 @@ if __name__ == '__main__':
default=1e-2,
help='assertion relative tolerance for validation',
)
parser.add_argument(
'--infer_inputs',
'-i',
nargs='?',
default=None,
const='',
help=
'perform type-shape inference with given input names and re-save model',
)
args = parser.parse_args()
logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
logging_level = logging.DEBUG if args.debug else logging.INFO
logging.basicConfig(format=logging_format, level=logging_level)
debug = args.debug
# debug = args.debug
fluid_model_filename = args.model[0]
golden_data_filename = args.test_data
atol, rtol = args.atol, args.rtol
save_inference_model = args.infer_inputs is not None
inference_input_names = args.infer_inputs.split(
',') if args.infer_inputs else None
validate(fluid_model_filename,
golden_data_filename,
golden_data_filename=golden_data_filename,
atol=atol,
rtol=rtol,
save_inference_model=debug)
save_inference_model=save_inference_model,
inference_input_names=inference_input_names)
if __name__ == '__main__':
main()
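With these changes the module doubles as a model re-saver. A hedged programmatic equivalent of `python -m onnx2fluid.validation model.py -i x` (the path and input name are hypothetical):

```python
from onnx2fluid.validation import validate

validate('/tmp/export/model.py',     # converted fluid model, Python-code form
         golden_data_filename='',    # no golden data: skip the accuracy check
         save_inference_model=True,  # run type-shape inference and re-save
         inference_input_names=['x'])
```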
......@@ -44,6 +44,8 @@ def irepr(obj, to='_'):
def flatten_list(obj, out=None):
assert isinstance(obj, list), 'list type required'
if out is None:
out = type(obj)()
for item in obj:
......@@ -56,12 +58,12 @@ def flatten_list(obj, out=None):
def make_attr_name(name):
"""
make a valid code name for ParamAttr
"""
if name == '':
raise ValueError('name should not be empty')
for s in ' \\|/:': #
assert name != '', 'name should not be empty'
for s in ' \\|/:-': #
name = name.replace(s, '_')
if not name.startswith('_'):
name = '_' + name
......@@ -70,8 +72,8 @@ def make_attr_name(name):
class Program(object):
"""
fluid Python code and ProgramDesc wrapper
"""
DTYPE_TO_FRAMEWORK_DTYPE = {
'bool': framework_pb2.VarType.BOOL,
......@@ -88,8 +90,8 @@ class Program(object):
@staticmethod
def Dtype(dtype):
"""
convert dtype to fluid framework dtype
"""
dtype = np.dtype(dtype).name
return Program.DTYPE_TO_FRAMEWORK_DTYPE[dtype]
......@@ -97,8 +99,8 @@ class Program(object):
@staticmethod
def OpDescVars(vals, *keys):
"""
make (OpDesc.Var)s
"""
od_vars = []
for idx, key in enumerate(keys):
......@@ -112,8 +114,8 @@ class Program(object):
@staticmethod
def OpDescAttrs(attrs):
"""
make (OpDesc.Attr)s
"""
od_attrs = []
for key, value in attrs.items():
......@@ -178,8 +180,8 @@ class Program(object):
def Code(self, code):
"""
add Python code
"""
if self.code_mutable:
self.codes.append(code)
......@@ -190,16 +192,16 @@ class Program(object):
output_val_keys=None,
attrs=None):
"""
add OpDesc
"""
desc = framework_pb2.OpDesc()
desc.type = op_type
if input_val_keys is not None:
if input_val_keys:
desc.inputs.extend(self.OpDescVars(*input_val_keys))
if output_val_keys is not None:
if output_val_keys:
desc.outputs.extend(self.OpDescVars(*output_val_keys))
if attrs is not None:
if attrs:
desc.attrs.extend(self.OpDescAttrs(attrs))
self.op_descs.append(desc)
return desc
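The switch from `is not None` to plain truthiness also skips empty containers, so callers can now pass `{}` or `[]` freely:

```python
# attrs == {} -> no OpDesc.Attr entries are emitted, same as passing None
desc = prog.OpDesc('relu',
                   (['x'], 'X'),
                   (['y'], 'Out'),
                   {})
```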
......@@ -210,8 +212,8 @@ class Program(object):
value_info=None,
remove_batch=None):
"""
add VarDesc
"""
assert var_name not in self.var_descs, 'var naming conflicted'
......@@ -220,13 +222,16 @@ class Program(object):
var_desc.persistable = persistable
var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
self.var_descs[var_name] = var_desc
if value_info:
self.VarTypeInfo(var_name, value_info, remove_batch=remove_batch)
self.VarTypeShapeInfo(var_name,
value_info,
remove_batch=remove_batch)
def Op(self, domain, op_type, *args, **kwargs):
"""
convert an ONNX op and add it to program
"""
if domain != '': # TODO: symbolic file routing by domain
raise ValueError('only default domain supported')
......@@ -242,8 +247,8 @@ class Program(object):
def IntermediateOp(self, domain, op_type, *args, **kwargs):
"""
convert an intermediate ONNX op, declared in the desc program only
"""
code_mutable = self.code_mutable
self.code_mutable = False
......@@ -255,10 +260,10 @@ class Program(object):
else:
self.code_mutable = code_mutable
def VarTypeInfo(self, var_name, value_info, remove_batch=None):
def VarTypeShapeInfo(self, var_name, value_info, remove_batch=None):
"""
set value_info for var
"""
if var_name not in self.var_descs:
return
......@@ -284,8 +289,8 @@ class Program(object):
class Writer(object):
"""
fluid code and desc writer
"""
# CODE_INDENT = ' ' * 4
CODE_INDENT = '\t'
......@@ -293,8 +298,8 @@ class Writer(object):
@staticmethod
def header_code(func_name, info=''):
"""
Python header codes
"""
codes = []
codes.append('"""')
......@@ -315,8 +320,8 @@ class Writer(object):
def emit_op(prog, name, domain, op_type, inputs, outputs, attrs,
value_infos, *args, **kwargs):
"""
emit an ONNX op into program
"""
prog.Code('# {}, {}::{}: {} -> {}, {}'.format(name, domain, op_type,
inputs, outputs,
......@@ -334,8 +339,8 @@ class Writer(object):
@staticmethod
def emit_param(prog, name, value_info):
"""
emit an ONNX weight into program
"""
if value_info.get('embeded_as', []):
var_names = value_info['embeded_as']
......@@ -359,8 +364,8 @@ class Writer(object):
@staticmethod
def emit_inputs(prog, names, value_infos, remove_batch=None):
"""
emit ONNX inputs into program
"""
for idx, name in enumerate(names):
var_name = make_var_name(name)
......@@ -396,8 +401,8 @@ class Writer(object):
@staticmethod
def emit_outputs(prog, names): #, value_infos
"""
emit ONNX outputs into program
"""
code = 'return '
for idx, name in enumerate(names):
......@@ -416,8 +421,8 @@ class Writer(object):
@staticmethod
def add_codes(codes, others, indent):
"""
flatten codes in program
"""
for code in flatten_list(others):
codes.append(Writer.CODE_INDENT * indent + code)
......@@ -426,11 +431,10 @@ class Writer(object):
@staticmethod
def write_weight(weight, filename):
"""
write single weight in fluid desc
"""
if not isinstance(weight, np.ndarray):
raise TypeError('weight is not an ndarray')
assert isinstance(weight, np.ndarray), 'weight is not an ndarray'
tensor_desc = framework_pb2.VarType.TensorDesc()
tensor_desc.data_type = Program.Dtype(weight.dtype)
......@@ -448,12 +452,11 @@ class Writer(object):
@staticmethod
def write_weights(weights, save_dir):
"""
write multiple weights in each fluid desc
"""
for name, weight in weights.items():
if not isinstance(weights, dict):
raise TypeError('dict type weights required')
assert isinstance(weights, dict), 'dict type weights required'
var_name = make_var_name(name)
filename = os.path.join(save_dir, var_name)
......@@ -463,8 +466,8 @@ class Writer(object):
@staticmethod
def write_code_file(filename, header_code, *body_codes):
"""
write Python code to file
"""
codes = []
Writer.add_codes(codes, header_code, 0)
......@@ -481,8 +484,8 @@ class Writer(object):
@staticmethod
def write_desc_file(filename, op_descs, var_descs):
"""
write desc program to file
"""
prog_desc = framework_pb2.ProgramDesc()
block_desc = prog_desc.blocks.add()
......
......@@ -54,8 +54,8 @@ zip_safe = True
[options.entry_points]
console_scripts =
onnx2fluid = onnx2fluid.__main__
onnx2fluid_convert = onnx2fluid.conversion
onnx2fluid_validate = onnx2fluid.validation
onnx2fluid_convert = onnx2fluid.conversion:main
onnx2fluid_validate = onnx2fluid.validation:main
# Non-Python files such as conf or data can be added to the package via the settings below; they are installed into site-packages together with the package
# Only files are supported, not directories, though wildcards can be used
......