提交 ba40d265 编写于 作者: M Macrobull

optimize symbolic

上级 52a502df
...@@ -69,6 +69,22 @@ parser.add_argument( ...@@ -69,6 +69,22 @@ parser.add_argument(
dest='pedantic', dest='pedantic',
help='process non-standard ONNX ops, this may lead to fails', help='process non-standard ONNX ops, this may lead to fails',
) )
parser.add_argument(
'--skip-version-conversion',
'-y',
action='store_true',
default=False,
help='skip ONNX op version conversion, workaround for RumtimeErrors',
)
parser.add_argument(
'--archive',
'-z',
nargs='?',
type=str,
default=None,
const='',
help='compress outputs to ZIP file if conversion successed',
)
parser.add_argument( parser.add_argument(
'--precision', '--precision',
'-p', '-p',
......
...@@ -16,10 +16,10 @@ from __future__ import division ...@@ -16,10 +16,10 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
# import logging, shutil, zipfile import logging, shutil, zipfile
import logging #import logging
import shutil #import shutil
import zipfile #import zipfile
__all__ = [ __all__ = [
'main', 'main',
...@@ -49,12 +49,14 @@ def main(**kwargs): ...@@ -49,12 +49,14 @@ def main(**kwargs):
basepath, _ = shutil.os.path.splitext(filename) basepath, _ = shutil.os.path.splitext(filename)
save_dir = kwargs.get('output_dir', '') save_dir = kwargs.get('output_dir', '')
# model.onnx -> model/ # model.onnx -> model/
save_dir = shutil.os.path.dirname(save_dir) if save_dir else basepath save_dir = (save_dir.rstrip('/') if save_dir else basepath) + '/'
model_basename = DEFAULT_MODEL_MODULE + '.py' model_basename = DEFAULT_MODEL_MODULE + '.py'
model_func_name = DEFAULT_MODEL_FUNC model_func_name = DEFAULT_MODEL_FUNC
embed_params = kwargs.get('embed_params', False) embed_params = kwargs.get('embed_params', False)
onnx_opset_version = DEFAULT_ONNX_OPSET_VERSION onnx_opset_version = DEFAULT_ONNX_OPSET_VERSION
onnx_opset_pedantic = kwargs.get('pedantic', True) onnx_opset_pedantic = kwargs.get('pedantic', True)
onnx_skip_version_conversion = kwargs.get('skip_version_conversion', False)
archive = kwargs.get('archive', None)
# convert # convert
convert( convert(
...@@ -65,6 +67,7 @@ def main(**kwargs): ...@@ -65,6 +67,7 @@ def main(**kwargs):
embed_params=embed_params, embed_params=embed_params,
onnx_opset_version=onnx_opset_version, onnx_opset_version=onnx_opset_version,
onnx_opset_pedantic=onnx_opset_pedantic, onnx_opset_pedantic=onnx_opset_pedantic,
onnx_skip_version_conversion=onnx_skip_version_conversion,
debug=debug) debug=debug)
# validate # validate
...@@ -104,10 +107,18 @@ def main(**kwargs): ...@@ -104,10 +107,18 @@ def main(**kwargs):
return return
# create zip file # create zip file
fn_zip = save_dir.rstrip('/') + '.zip' if archive is not None:
logger.info('compressing file to %s ...', fn_zip) if archive == '':
fz = zipfile.ZipFile(fn_zip, 'w', compression=zipfile.ZIP_LZMA) archive = save_dir.rstrip('/') + '.zip'
for fn in shutil.os.listdir(save_dir): logger.info('compressing file to %s ...', archive)
shutil.sys.stderr.write('\n')
shutil.sys.stderr.flush()
file_list = shutil.os.listdir(save_dir)
fz = zipfile.ZipFile(archive, 'w', compression=zipfile.ZIP_LZMA)
for idx, fn in enumerate(file_list):
shutil.sys.stderr.write('\033[F\033[2K')
logger.info('file {}/{}: {}'.format(idx + 1, len(file_list), fn))
shutil.sys.stderr.flush()
fz.write(shutil.os.path.join(save_dir, fn), arcname=fn) fz.write(shutil.os.path.join(save_dir, fn), arcname=fn)
fz.close() fz.close()
logger.info('compressing done') logger.info('compressing done')
...@@ -132,5 +143,6 @@ if __name__ == '__main__': ...@@ -132,5 +143,6 @@ if __name__ == '__main__':
output_dir='/tmp/export/', output_dir='/tmp/export/',
embed_params=True, embed_params=True,
pedantic=False, pedantic=False,
skip_version_conversion=False,
test_data='../examples/inception_v2/test_data_set_2.npz', test_data='../examples/inception_v2/test_data_set_2.npz',
debug=True) debug=True)
...@@ -8,9 +8,9 @@ Created on Mon Feb 25 09:50:35 2019 ...@@ -8,9 +8,9 @@ Created on Mon Feb 25 09:50:35 2019
from __future__ import division from __future__ import division
# import logging, shutil import logging, shutil
import logging #import logging
import shutil #import shutil
__all__ = [ __all__ = [
'convert', 'convert',
...@@ -24,6 +24,7 @@ def convert(onnx_model_filename, ...@@ -24,6 +24,7 @@ def convert(onnx_model_filename,
embed_params=False, embed_params=False,
onnx_opset_version=9, onnx_opset_version=9,
onnx_opset_pedantic=True, onnx_opset_pedantic=True,
onnx_skip_version_conversion=False,
debug=False): debug=False):
""" """
convert an ONNX model to Paddle fluid Python code and desc pb convert an ONNX model to Paddle fluid Python code and desc pb
...@@ -60,12 +61,13 @@ def convert(onnx_model_filename, ...@@ -60,12 +61,13 @@ def convert(onnx_model_filename,
try: try:
logger.info('checking model ...') logger.info('checking model ...')
check_model(onnx_model) check_model(onnx_model)
logger.debug('using opset version: %d', onnx_opset_version) if onnx_skip_version_conversion: # WORKAROUND: RuntimeError: No Adapter For OP
if onnx_opset_pedantic: # WORKAROUND: RuntimeError: No Adapter For OP logger.debug('assumed opset version: %d', onnx_opset_version)
onnx_model = convert_version(onnx_model, onnx_opset_version)
else: # TODO: add new argument for this option
logger.warning( logger.warning(
'opset conversion skipped for onnx_opset_pedantic is OFF') 'opset conversion skipped for onnx_opset_pedantic is OFF')
else:
logger.debug('using opset version: %d', onnx_opset_version)
onnx_model = convert_version(onnx_model, onnx_opset_version)
onnx_model = polish_model(onnx_model) onnx_model = polish_model(onnx_model)
except ValidationError as e: except ValidationError as e:
if onnx_opset_pedantic: if onnx_opset_pedantic:
...@@ -152,16 +154,15 @@ def convert(onnx_model_filename, ...@@ -152,16 +154,15 @@ def convert(onnx_model_filename,
logger.info( logger.info(
'weight %s is shared between ops, more disk space will be consumed', 'weight %s is shared between ops, more disk space will be consumed',
name) name)
logger.debug( logger.debug('saving weight %s(%s[%d], %dB) as %s ...', name,
'saving weight %s with size of %d, in %d bytes, as %s ...', weight.dtype, weight.size, weight.nbytes, var_names)
name, weight.size, weight.nbytes, var_names)
for var_name in var_names: # multiple references for var_name in var_names: # multiple references
fluid_writer.write_weight( fluid_writer.write_weight(
weight, shutil.os.path.join(save_dir, var_name)) weight, shutil.os.path.join(save_dir, var_name))
else: else:
logger.debug( logger.debug('saving weight %s(%s[%d], %dB) to %s ...', name,
'saving weight %s with size of %d, in %d bytes, to %s ...', weight.dtype, weight.size, weight.nbytes,
name, weight.size, weight.nbytes, make_var_name(name)) make_var_name(name))
fluid_writer.write_weight( fluid_writer.write_weight(
weight, shutil.os.path.join(save_dir, make_var_name(name))) weight, shutil.os.path.join(save_dir, make_var_name(name)))
fluid_writer.emit_param(fluid_program, name, value_info) fluid_writer.emit_param(fluid_program, name, value_info)
...@@ -262,6 +263,13 @@ if __name__ == '__main__': ...@@ -262,6 +263,13 @@ if __name__ == '__main__':
dest='pedantic', dest='pedantic',
help='process non-standard ONNX ops, this may lead to fails', help='process non-standard ONNX ops, this may lead to fails',
) )
parser.add_argument(
'--skip-version-conversion',
'-y',
action='store_true',
default=False,
help='skip ONNX op version conversion, workaround for RumtimeErrors',
)
args = parser.parse_args() args = parser.parse_args()
logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s' logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
...@@ -273,10 +281,12 @@ if __name__ == '__main__': ...@@ -273,10 +281,12 @@ if __name__ == '__main__':
save_dir = args.output_dir save_dir = args.output_dir
embed_params = args.embed_params embed_params = args.embed_params
pedantic = args.pedantic pedantic = args.pedantic
skip_version_conversion = args.skip_version_conversion
convert( convert(
model_filename, model_filename,
save_dir, save_dir,
embed_params=embed_params, embed_params=embed_params,
onnx_opset_pedantic=pedantic, onnx_opset_pedantic=pedantic,
onnx_skip_version_conversion=skip_version_conversion,
debug=debug) debug=debug)
...@@ -26,6 +26,7 @@ __all__ = [ ...@@ -26,6 +26,7 @@ __all__ = [
'node_attrs', 'node_attrs',
'node_topo', 'node_topo',
'node_iter', 'node_iter',
'tensor_dtype',
'tensor_shape', 'tensor_shape',
'graph_ops', 'graph_ops',
'graph_weights', 'graph_weights',
...@@ -92,13 +93,12 @@ def get_attribute_value2(attr): ...@@ -92,13 +93,12 @@ def get_attribute_value2(attr):
return value return value
def node_attrs(node): def tensor_dtype(tensor):
""" """
convert ONNX node attributes to dict get ONNX tensor in np.dtype
""" """
return {attr.name: get_attribute_value2(attr) return TENSOR_TYPE_TO_NP_TYPE[tensor.type.tensor_type.elem_type]
for attr in node.attribute} # dict
def tensor_shape(tensor): def tensor_shape(tensor):
...@@ -109,6 +109,15 @@ def tensor_shape(tensor): ...@@ -109,6 +109,15 @@ def tensor_shape(tensor):
return [dim.dim_value for dim in tensor.type.tensor_type.shape.dim] return [dim.dim_value for dim in tensor.type.tensor_type.shape.dim]
def node_attrs(node):
"""
convert ONNX node attributes to dict
"""
return {attr.name: get_attribute_value2(attr)
for attr in node.attribute} # dict
def node_topo(nodes, topo='default'): def node_topo(nodes, topo='default'):
""" """
build indices with given topology to an ONNX node graph build indices with given topology to an ONNX node graph
...@@ -237,21 +246,21 @@ def inferred_model_value_info(model): ...@@ -237,21 +246,21 @@ def inferred_model_value_info(model):
value_info = Dict() value_info = Dict()
for item in graph.value_info: for item in graph.value_info:
value_info[item.name] = dict( value_info[item.name] = dict(
dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type], dtype=tensor_dtype(item),
shape=tensor_shape(item), shape=tensor_shape(item),
external=False, external=False,
) )
for item in graph.input: for item in graph.input:
assert item.name not in value_info assert item.name not in value_info
value_info[item.name] = dict( value_info[item.name] = dict(
dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type], dtype=tensor_dtype(item),
shape=tensor_shape(item), shape=tensor_shape(item),
external=True, external=True,
) )
for item in graph.output: for item in graph.output:
# assert item.name not in value_info, 'bypass-model not supported' # assert item.name not in value_info, 'bypass-model not supported'
value_info[item.name] = dict( value_info[item.name] = dict(
dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type], dtype=tensor_dtype(item),
shape=tensor_shape(item), shape=tensor_shape(item),
external=True, external=True,
) )
...@@ -373,9 +382,9 @@ def optimize_model_strip_initializer(model, keep_input_only=True): ...@@ -373,9 +382,9 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
elif not keep_input_only and name in output_refs: elif not keep_input_only and name in output_refs:
ret_initializers.add().CopyFrom(initializer) ret_initializers.add().CopyFrom(initializer)
else: else:
logger.debug('initializer %s(%s[%d]) stripped', name, dtype = TENSOR_TYPE_TO_NP_TYPE[initializer.data_type]
TENSOR_TYPE_TO_NP_TYPE[initializer.data_type], logger.debug('initializer %s(%s[%d]) stripped', name, dtype,
len(initializer.raw_data)) len(initializer.raw_data) // dtype.itemsize)
# strip inputs # strip inputs
ret.graph.ClearField('input') ret.graph.ClearField('input')
...@@ -385,9 +394,7 @@ def optimize_model_strip_initializer(model, keep_input_only=True): ...@@ -385,9 +394,7 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
if name in input_refs or name in out_names: if name in input_refs or name in out_names:
ret_inputs.add().CopyFrom(item) ret_inputs.add().CopyFrom(item)
else: else:
logger.debug( logger.debug('input %s(%s%s) stripped', name, tensor_dtype(item),
'input %s(%s%s) stripped', name,
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
tensor_shape(item)) tensor_shape(item))
return ret return ret
......
...@@ -19,7 +19,7 @@ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE ...@@ -19,7 +19,7 @@ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
_logger = _logging.getLogger(__name__) _logger = _logging.getLogger(__name__)
ONNX_INT_MAX = 2**63 - 1 ONNX_INT_MAX = 2**63 - 1
FLUID_INT_MAX = 2**31 - 1 FLUID_INT_MAX = 2**31 - 1 #
DEFAULT_ONNX_OP_DOMAIN = '' DEFAULT_ONNX_OP_DOMAIN = ''
DEFAULT_FLUID_OP_NAMESCOPE = '/' DEFAULT_FLUID_OP_NAMESCOPE = '/'
...@@ -186,13 +186,17 @@ def _shape_or_none(value_infos, val_name): ...@@ -186,13 +186,17 @@ def _shape_or_none(value_infos, val_name):
return list(value_info['shape']) return list(value_info['shape'])
#def _maybe_const_value(value_infos, val_name): def _const_weight_or_none(value_infos, val_name):
# var_name = _make_var_name(val_name) if val_name not in value_infos:
# if val_name not in value_infos: return None
# return var_name value_info = value_infos[val_name]
# value_info = value_infos[val_name] const_value = value_info.get('const_value', None)
# assert value_info.get('remove_batch', False) == False, 'const value should not have batch dim' if const_value:
# return value_info.get('const_value', var_name) return const_value
get_weight_func = value_info.get('get_weight', None)
if get_weight_func:
return get_weight_func()
return None
def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs): def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
...@@ -253,7 +257,7 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs): ...@@ -253,7 +257,7 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
num_vars = len(var_outs) num_vars = len(var_outs)
num_args = len(fluid_output_args) num_args = len(fluid_output_args)
if num_vars < num_args: if num_vars < num_args:
assert fill_name_field, 'name required to naming dummy output variable' assert fill_name_field, 'name required to name dummy output variables'
for idx_out in range(num_vars, num_args): for idx_out in range(num_vars, num_args):
var_out = _make_var_name(name + '.' + var_out = _make_var_name(name + '.' +
fluid_output_args[idx_out].lower()) fluid_output_args[idx_out].lower())
...@@ -294,9 +298,8 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE ...@@ -294,9 +298,8 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
if pads[idx_dim] != pads[ndims + idx_dim]: if pads[idx_dim] != pads[ndims + idx_dim]:
symmetric = False symmetric = False
break break
if symmetric: if symmetric:
return pads[:ndims], None return pads[:ndims], val_name
val_padded = val_name + '_padded' val_padded = val_name + '_padded'
prog.Op( prog.Op(
...@@ -315,13 +318,7 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE ...@@ -315,13 +318,7 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
return [0] * ndims, val_padded return [0] * ndims, val_padded
def _adaptive_pool(prog, def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''):
pool_type,
inputs,
outputs,
attrs,
value_infos,
name=''):
# I/O # I/O
val_x, = inputs val_x, = inputs
val_y, = outputs[:1] val_y, = outputs[:1]
...@@ -335,10 +332,6 @@ def _adaptive_pool(prog, ...@@ -335,10 +332,6 @@ def _adaptive_pool(prog,
# interpretation # interpretation
pool_size = attrs['output_size'] # required pool_size = attrs['output_size'] # required
output_shape = _shape_or_none(value_infos, val_y)
if output_shape is not None:
assert pool_size == output_shape[
2:], 'pool_size unmatches shape of Y' # NC...
poolnd = len(pool_size) poolnd = len(pool_size)
assert 2 <= poolnd <= 3, 'only pool2d and pool3d supported' assert 2 <= poolnd <= 3, 'only pool2d and pool3d supported'
...@@ -445,11 +438,9 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): ...@@ -445,11 +438,9 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
fluid_op = 'pool{}d'.format(poolnd) fluid_op = 'pool{}d'.format(poolnd)
strides = attrs.get('strides', [1] * poolnd) # optional strides = attrs.get('strides', [1] * poolnd) # optional
pads = attrs.get('pads', [0] * len(pool_size * 2)) # optional
paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
if val_x_padded:
val_x = val_x_padded
ceil_mode = bool(attrs.get('ceil_mode', 0)) # optional ceil_mode = bool(attrs.get('ceil_mode', 0)) # optional
pads = attrs.get('pads', [0] * (poolnd * 2)) # optional
paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
var_x = _make_var_name(val_x) var_x = _make_var_name(val_x)
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
...@@ -506,17 +497,17 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name): ...@@ -506,17 +497,17 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
spatial_scale = attrs['spatial_scale'] # required spatial_scale = attrs['spatial_scale'] # required
pooled_height, pooled_width = attrs['pooled_shape'] # required pooled_height, pooled_width = attrs['pooled_shape'] # required
od_attrs = dict( od_attrs = dict(
spatial_scale=spatial_scale,
pooled_height=pooled_height, pooled_height=pooled_height,
pooled_width=pooled_width, pooled_width=pooled_width,
spatial_scale=spatial_scale,
) )
feature_attr = '' feature_attr = ''
is_max_pool = fluid_op == 'roi_pool' is_max_pool = fluid_op == 'roi_pool'
if 'sampling_ratio' in attrs: if 'sampling_ratio' in attrs: #
sampling_ratio = attrs['sampling_ratio'] sampling_ratio = attrs['sampling_ratio']
od_attrs['sampling_ratio'] = sampling_ratio od_attrs['sampling_ratio'] = sampling_ratio
feature_attr += ', sampling_ratio={}'.format(sampling_ratio) feature_attr += ', sampling_ratio={}'.format(sampling_ratio)
if 'output_channels' in attrs: if 'output_channels' in attrs: #
output_channels = attrs['output_channels'] output_channels = attrs['output_channels']
od_attrs['output_channels'] = output_channels od_attrs['output_channels'] = output_channels
feature_attr += ', output_channels={}'.format(output_channels) feature_attr += ', output_channels={}'.format(output_channels)
...@@ -560,36 +551,20 @@ def _zeros_like(prog, val_ref, val_out, value_infos): ...@@ -560,36 +551,20 @@ def _zeros_like(prog, val_ref, val_out, value_infos):
) )
def AdaptiveAveragePool(prog, def AdaptiveAveragePool(prog, inputs, outputs, attrs, *args, name='', **kwargs):
inputs,
outputs,
attrs,
value_infos,
name='',
*args,
**kwargs):
""" """
aten::adaptive_avg_poolnd aten::adaptive_avg_poolnd
""" """
return _adaptive_pool( return _adaptive_pool(prog, 'avg', inputs, outputs, attrs, name=name)
prog, 'avg', inputs, outputs, attrs, value_infos, name=name)
def AdaptiveMaxPool(prog, def AdaptiveMaxPool(prog, inputs, outputs, attrs, *args, name='', **kwargs):
inputs,
outputs,
attrs,
value_infos,
name='',
*args,
**kwargs):
""" """
aten::adaptive_max_poolnd aten::adaptive_max_poolnd
""" """
return _adaptive_pool( return _adaptive_pool(prog, 'max', inputs, outputs, attrs, name=name)
prog, 'max', inputs, outputs, attrs, value_infos, name=name)
def AveragePool(prog, def AveragePool(prog,
...@@ -734,9 +709,9 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -734,9 +709,9 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
var_output = _make_var_name(val_output) var_output = _make_var_name(val_output)
# interpretation # interpretation
dtype = attrs['to'] dtype = attrs['to'] # required
if not isinstance(dtype, np.dtype): if not isinstance(dtype, np.dtype): # additional: possible np.dtype
dtype = TENSOR_TYPE_TO_NP_TYPE[dtype] # required dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]
output_dtype = _dtype_or_none(value_infos, val_output) output_dtype = _dtype_or_none(value_infos, val_output)
if output_dtype: if output_dtype:
assert dtype == output_dtype, 'dtype of to unmatches output' assert dtype == output_dtype, 'dtype of to unmatches output'
...@@ -818,15 +793,16 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -818,15 +793,16 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
assert dtype == output_dtype, 'tensor dtype unmatches storage dtype' assert dtype == output_dtype, 'tensor dtype unmatches storage dtype'
# dtype = np.dtype('float32') # force to float32 # dtype = np.dtype('float32') # HINT: force to float32
shape = attrs.get('shape', None) # additional, maybe var_name shape = attrs.get('shape', None) # additional, maybe var_name
if shape is None: if shape is None:
shape = _shape_or_none(value_infos, val_output) shape = _shape_or_none(value_infos, val_output)
if shape is None: if shape is None:
shape = list(value.shape) shape = list(value.shape)
_logger.warning( _logger.warning(
'shape of %s not inferred, using value as 1-D tensor may lead to fails', 'in (Constant -> %s): '
val_output) 'shape of %s not inferred, '
'using value as 1-D tensor may lead to fails', outputs, val_output)
# generation # generation
if value.size == 1: # scalar if value.size == 1: # scalar
...@@ -855,18 +831,27 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -855,18 +831,27 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
""" """
# I/O # I/O
val_input, = inputs val_shape, = inputs
val_output, = outputs
is_const_shape = 'const_value' in value_infos[val_input] shape = _const_weight_or_none(value_infos, val_shape)
if is_const_shape: if shape is None:
shape = _make_var_name(val_input) shape = _shape_or_none(value_infos, val_output)
else: assert shape is not None, (
shape = value_infos[val_input]['get_weight']() 'given shape is neither const value nor deductible from output, '
'this is not supported')
dtype = attrs['value'].dtype dtype = attrs['value'].dtype
attrs = attrs.copy() attrs = attrs.copy()
attrs.update(dict(shape=shape, dtype=dtype)) # pass var_name attrs.update(dict(shape=shape, dtype=dtype)) # pass var_name
Constant(prog, [], outputs, attrs, value_infos) prog.Op(
'',
'Constant',
[],
outputs, # val
attrs,
value_infos,
)
def Conv(prog, def Conv(prog,
...@@ -903,13 +888,11 @@ def Conv(prog, ...@@ -903,13 +888,11 @@ def Conv(prog,
num_out_channels = _shape(value_infos, val_w)[0] # OI... num_out_channels = _shape(value_infos, val_w)[0] # OI...
fluid_op = 'conv{}d'.format(convnd) fluid_op = 'conv{}d'.format(convnd)
num_groups = attrs.get('group', 1) # optional
strides = attrs.get('strides', [1] * convnd) # optional strides = attrs.get('strides', [1] * convnd) # optional
pads = attrs.get('pads', [0] * convnd * 2) # optional
paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
if val_x_padded:
val_x = val_x_padded
dilations = attrs.get('dilations', [1] * convnd) # optional dilations = attrs.get('dilations', [1] * convnd) # optional
num_groups = attrs.get('group', 1) # optional pads = attrs.get('pads', [0] * (convnd * 2)) # optional
paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
var_x = _make_var_name(val_x) var_x = _make_var_name(val_x)
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params: if embed_params:
...@@ -1014,13 +997,11 @@ def ConvTranspose(prog, ...@@ -1014,13 +997,11 @@ def ConvTranspose(prog,
num_out_channels = _shape(value_infos, val_w)[1] # IO... num_out_channels = _shape(value_infos, val_w)[1] # IO...
fluid_op = 'conv{}d_transpose'.format(convnd) fluid_op = 'conv{}d_transpose'.format(convnd)
num_groups = attrs.get('group', 1) # optional
strides = attrs.get('strides', [1] * convnd) # optional strides = attrs.get('strides', [1] * convnd) # optional
pads = attrs.get('pads', [0] * convnd * 2) # optional
paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
if val_x_padded:
val_x = val_x_padded
dilations = attrs.get('dilations', [1] * convnd) # optional dilations = attrs.get('dilations', [1] * convnd) # optional
num_groups = attrs.get('group', 1) # optional pads = attrs.get('pads', [0] * (convnd * 2)) # optional
paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
var_x = _make_var_name(val_x) var_x = _make_var_name(val_x)
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params: if embed_params:
...@@ -1090,7 +1071,7 @@ def ConvTranspose(prog, ...@@ -1090,7 +1071,7 @@ def ConvTranspose(prog,
prog.VarDesc(var_y) prog.VarDesc(var_y)
# should not appears # should not appear
#def Dropout( #def Dropout(
# prog, inputs, outputs, value_infos, # prog, inputs, outputs, value_infos,
# *args, **kwargs): # *args, **kwargs):
...@@ -1154,9 +1135,15 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1154,9 +1135,15 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
else: else:
val_beta = name + '_beta' # explicit variable val_beta = name + '_beta' # explicit variable
val_vm = name + '_vm' # explicit variable val_vm = name + '_vm' # explicit variable
if beta.is_integer():
vm_dtype = _dtype_or_none(value_infos, val_c) vm_dtype = _dtype_or_none(value_infos, val_c)
if vm_dtype is None: if vm_dtype is None:
vm_dtype = np.dtype('float32') vm_dtype = np.dtype('float32')
_logger.warning(
'in %s(%s -> Gemm -> %s): '
'beta seems to be an interger, '
'however dtype can not be inferred, '
'still use float32', name, inputs, outputs)
beta = np.dtype(vm_dtype).type(beta) beta = np.dtype(vm_dtype).type(beta)
prog.Op( prog.Op(
'', '',
...@@ -1429,13 +1416,15 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1429,13 +1416,15 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
var_reshaped = _make_var_name(val_reshaped) var_reshaped = _make_var_name(val_reshaped)
# interpretation # interpretation
fluid_op = 'reshape'
is_const_shape = 'const_value' in value_infos[val_shape]
var_shape = _make_var_name(val_shape) # for code var_shape = _make_var_name(val_shape) # for code
if is_const_shape: shape = _const_weight_or_none(value_infos, val_shape)
shape = value_infos[val_shape]['const_value'] # for desc is_const_shape = shape and 'const_value' in value_infos[val_shape]
else: if shape is None:
shape = value_infos[val_shape]['get_weight']() # for desc shape = _shape_or_none(value_infos, var_reshaped)
assert shape is not None, (
'given shape is neither const value nor deductible from output, '
'this is not supported')
fluid_op = 'reshape'
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
# generation # generation
...@@ -1457,7 +1446,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1457,7 +1446,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'Cast', 'Cast',
[var_shape], [var_shape],
[var_shape_int32], # var [var_shape_int32], # var
dict(to=np.dtype('int32')), dict(to=np.dtype('int32')), # use np.dtype
value_infos=value_infos, value_infos=value_infos,
name=(name + '_cast'), name=(name + '_cast'),
) )
...@@ -1593,26 +1582,25 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): ...@@ -1593,26 +1582,25 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
var_output = _make_var_name(val_output) var_output = _make_var_name(val_output)
# interpretation # interpretation
repeats = _const_weight_or_none(value_infos, val_repeats)
assert repeats is not None, 'only const repeats is supported'
fluid_op = 'expand' fluid_op = 'expand'
is_const_repeats = 'const_value' in value_infos[val_repeats]
if is_const_repeats:
code_repeats = _make_var_name(val_repeats) # for code
repeats = value_infos[val_repeats]['const_value'] # for desc
else:
repeats = value_infos[val_input]['get_weight']() # for desc
code_repeats = repeats # for code
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
# generation # generation
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
', expand_times={}' ', expand_times={}'
'{})'.format( '{})'
' # {} = {}'.format(
var_output, var_output,
fluid_op, fluid_op,
var_input, var_input,
# attrs # attrs
code_repeats, repeats,
name_attr, name_attr,
# comment
_make_var_name(val_repeats),
repeats,
)) ))
prog.VarDesc(var_output) prog.VarDesc(var_output)
prog.OpDesc( prog.OpDesc(
......
...@@ -6,11 +6,12 @@ Created on Fri Mar 22 12:17:19 2019 ...@@ -6,11 +6,12 @@ Created on Fri Mar 22 12:17:19 2019
@author: Macrobull @author: Macrobull
""" """
# import importlib, logging, os, sys import importlib, logging, os, sys
import importlib
import logging #import importlib
import os #import logging
import sys #import os
#import sys
def _flatten_dict(obj, out=None): def _flatten_dict(obj, out=None):
......
...@@ -8,9 +8,9 @@ Created on Sun Feb 24 20:44:43 2019 ...@@ -8,9 +8,9 @@ Created on Sun Feb 24 20:44:43 2019
from __future__ import division from __future__ import division
# import logging, os import logging, os
import logging #import logging
import os #import os
import numpy as np import numpy as np
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -215,10 +215,6 @@ class Program(object): ...@@ -215,10 +215,6 @@ class Program(object):
var_desc.persistable = persistable var_desc.persistable = persistable
var_desc.type.type = framework_pb2.VarType.LOD_TENSOR var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
# REMOVEIT: WORKAROUND: Netron: null.tensor error
tensor_desc = var_desc.type.lod_tensor.tensor
tensor_desc.data_type = self.Dtype(dummy_dtype) # required
if value_info and 'dtype' in value_info: if value_info and 'dtype' in value_info:
tensor_desc = var_desc.type.lod_tensor.tensor tensor_desc = var_desc.type.lod_tensor.tensor
tensor_desc.data_type = self.Dtype(value_info['dtype']) # required tensor_desc.data_type = self.Dtype(value_info['dtype']) # required
...@@ -230,6 +226,9 @@ class Program(object): ...@@ -230,6 +226,9 @@ class Program(object):
not persistable) not persistable)
if remove_batch: if remove_batch:
tensor_desc.dims[0] = -1 tensor_desc.dims[0] = -1
else: # REMOVEIT: WORKAROUND: Netron: null.tensor error
tensor_desc = var_desc.type.lod_tensor.tensor
tensor_desc.data_type = self.Dtype(dummy_dtype) # required
self.var_descs.append(var_desc) self.var_descs.append(var_desc)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册