提交 d6e4a4ba 编写于 作者: M Macrobull

add naive option

上级 9d147284
...@@ -19,12 +19,12 @@ def make_var_name(name): ...@@ -19,12 +19,12 @@ def make_var_name(name):
assert name assert name
if name[0].isdigit(): for s in ' \\|/:.-':
return 'var_' + name
for s in ' \\|/:-': #
name = name.replace(s, '_') name = name.replace(s, '_')
if name.startswith('_'): if name.startswith('_'):
name = 'var' + name name = 'var' + name
elif name[0].isdigit():
name = 'var_' + name
return name return name
......
...@@ -22,12 +22,12 @@ def make_var_name(name): ...@@ -22,12 +22,12 @@ def make_var_name(name):
assert name assert name
if name[0].isdigit(): for s in ' \\|/:.-':
return 'var_' + name
for s in ' \\|/:-': #
name = name.replace(s, '_') name = name.replace(s, '_')
if name.startswith('_'): if name.startswith('_'):
name = 'var' + name name = 'var' + name
elif name[0].isdigit():
name = 'var_' + name
return name return name
......
...@@ -64,7 +64,14 @@ parser.add_argument( ...@@ -64,7 +64,14 @@ parser.add_argument(
'-x', '-x',
action='store_false', action='store_false',
dest='pedantic', dest='pedantic',
help='process non-standard ONNX ops, this may lead to fails', help='process non-standard ONNX ops, this may lead to failures',
)
parser.add_argument(
'--naive',
'-n',
action='store_true',
default=False,
help='bypass ONNX op optimizations, especially for training purpose',
) )
parser.add_argument( parser.add_argument(
'--skip-version-conversion', '--skip-version-conversion',
......
...@@ -18,6 +18,8 @@ from __future__ import unicode_literals ...@@ -18,6 +18,8 @@ from __future__ import unicode_literals
import logging, shutil, zipfile import logging, shutil, zipfile
logger = logging.getLogger(__name__)
__all__ = [ __all__ = [
'main', 'main',
] ]
...@@ -45,6 +47,7 @@ def main(**kwargs): ...@@ -45,6 +47,7 @@ def main(**kwargs):
model_basename = DEFAULT_MODEL_MODULE + '.py' model_basename = DEFAULT_MODEL_MODULE + '.py'
model_func_name = DEFAULT_MODEL_FUNC model_func_name = DEFAULT_MODEL_FUNC
onnx_opset_pedantic = kwargs.pop('pedantic', True) onnx_opset_pedantic = kwargs.pop('pedantic', True)
onnx_skip_optimization = kwargs.pop('naive', False)
skip_version_conversion = kwargs.pop('skip_version_conversion', False) skip_version_conversion = kwargs.pop('skip_version_conversion', False)
onnx_opset_version = None if skip_version_conversion else DEFAULT_ONNX_OPSET_VERSION onnx_opset_version = None if skip_version_conversion else DEFAULT_ONNX_OPSET_VERSION
...@@ -55,6 +58,7 @@ def main(**kwargs): ...@@ -55,6 +58,7 @@ def main(**kwargs):
model_func_name=model_func_name, model_func_name=model_func_name,
onnx_opset_version=onnx_opset_version, onnx_opset_version=onnx_opset_version,
onnx_opset_pedantic=onnx_opset_pedantic, onnx_opset_pedantic=onnx_opset_pedantic,
onnx_skip_optimization=onnx_skip_optimization,
**kwargs) **kwargs)
# validate # validate
...@@ -65,7 +69,7 @@ def main(**kwargs): ...@@ -65,7 +69,7 @@ def main(**kwargs):
if golden_data_filename or save_inference_model: if golden_data_filename or save_inference_model:
from .validation import validate from .validation import validate
if save_inference_model: if infer_inputs:
inference_input_names = infer_inputs.split(',') inference_input_names = infer_inputs.split(',')
else: else:
inference_input_names = None inference_input_names = None
......
...@@ -24,12 +24,12 @@ def make_var_name(name): ...@@ -24,12 +24,12 @@ def make_var_name(name):
if name == '': if name == '':
return '' return ''
if name[0].isdigit():
return 'var_' + name
for s in ' \\|/:.-': for s in ' \\|/:.-':
name = name.replace(s, '_') name = name.replace(s, '_')
if name.startswith('_'): if name.startswith('_'):
name = 'var' + name name = 'var' + name
elif name[0].isdigit():
name = 'var_' + name
return name return name
...@@ -40,6 +40,7 @@ def convert(onnx_model_filename, ...@@ -40,6 +40,7 @@ def convert(onnx_model_filename,
embed_params=False, embed_params=False,
onnx_opset_version=None, onnx_opset_version=None,
onnx_opset_pedantic=True, onnx_opset_pedantic=True,
onnx_skip_optimization=False,
debug=False, debug=False,
**kwargs): **kwargs):
""" """
...@@ -61,10 +62,10 @@ def convert(onnx_model_filename, ...@@ -61,10 +62,10 @@ def convert(onnx_model_filename,
from .onnx_utils import DEFAULT_OP_DOMAIN from .onnx_utils import DEFAULT_OP_DOMAIN
from .onnx_utils import graph_ops, graph_weights from .onnx_utils import graph_ops, graph_weights
from .onnx_utils import inferred_model_value_info from .onnx_utils import inferred_model_value_info
from .onnx_utils import polish_model from .onnx_utils import polish_model, optimize_model_strip_initializer
from .writer import Program, Writer from .writer import Program, Writer
logger = logging.getLogger('convert') logger = logging.getLogger('onnx2fluid')
# prepare onnx model # prepare onnx model
logger.info('loading model: %s ...', onnx_model_filename) logger.info('loading model: %s ...', onnx_model_filename)
...@@ -90,6 +91,10 @@ def convert(onnx_model_filename, ...@@ -90,6 +91,10 @@ def convert(onnx_model_filename,
# onnx model optimization # onnx model optimization
logger.info('model has %d ops', len(onnx_model.graph.node)) logger.info('model has %d ops', len(onnx_model.graph.node))
if onnx_skip_optimization:
logger.info('stripping model ...')
onnx_model = optimize_model_strip_initializer(onnx_model)
else:
logger.info('optimizing model ...') logger.info('optimizing model ...')
onnx_model = polish_model(onnx_model, checking=onnx_opset_pedantic) onnx_model = polish_model(onnx_model, checking=onnx_opset_pedantic)
...@@ -123,7 +128,7 @@ def convert(onnx_model_filename, ...@@ -123,7 +128,7 @@ def convert(onnx_model_filename,
for name, weight in graph_weights(onnx_graph): for name, weight in graph_weights(onnx_graph):
var_name = make_var_name(name) var_name = make_var_name(name)
value_info = value_infos[var_name] value_info = value_infos[var_name]
value_info['lod'] = [0] value_info['lod'] = []
value_info['embedded_as'] = [] value_info['embedded_as'] = []
value_info['get_weight'] = (lambda w: lambda: w.tolist())( value_info['get_weight'] = (lambda w: lambda: w.tolist())(
weight) # lazy getter weight) # lazy getter
...@@ -306,7 +311,14 @@ def main(): ...@@ -306,7 +311,14 @@ def main():
'-x', '-x',
action='store_false', action='store_false',
dest='pedantic', dest='pedantic',
help='process non-standard ONNX ops, this may lead to fails', help='process non-standard ONNX ops, this may lead to failures',
)
parser.add_argument(
'--naive',
'-n',
action='store_true',
default=False,
help='bypass ONNX op optimizations, especially for training purpose',
) )
parser.add_argument( parser.add_argument(
'--skip-version-conversion', '--skip-version-conversion',
...@@ -329,13 +341,15 @@ def main(): ...@@ -329,13 +341,15 @@ def main():
if save_dir else basepath) + shutil.os.sep if save_dir else basepath) + shutil.os.sep
embed_params = args.embed_params embed_params = args.embed_params
pedantic = args.pedantic pedantic = args.pedantic
skip_version_conversion = args.skip_version_conversion skip_optimization = args.naive
onnx_opset_version = None if args.skip_version_conversion else DEFAULT_ONNX_OPSET_VERSION
convert(model_filename, convert(model_filename,
save_dir, save_dir,
embed_params=embed_params, embed_params=embed_params,
onnx_opset_version=onnx_opset_version,
onnx_opset_pedantic=pedantic, onnx_opset_pedantic=pedantic,
onnx_skip_version_conversion=skip_version_conversion, onnx_skip_optimization=skip_optimization,
debug=debug) debug=debug)
......
...@@ -356,16 +356,16 @@ def polish_model(model, internals=True, extras=True, checking=True): ...@@ -356,16 +356,16 @@ def polish_model(model, internals=True, extras=True, checking=True):
def polish_and_save(model_filename, def polish_and_save(model_filename,
save_filename='',
suffix='.polished', suffix='.polished',
save_filename=None,
*args, *args,
**kwargs): **kwargs):
""" """
run polish_model and save run polish_model and save
""" """
if save_filename is None: save_filename = save_filename or model_filename.replace(
save_filename = model_filename.replace('.onnx', suffix + '.onnx') '.onnx', suffix + '.onnx')
model = onnx.load(model_filename) model = onnx.load(model_filename)
model = polish_model(model, *args, **kwargs) model = polish_model(model, *args, **kwargs)
......
...@@ -18,7 +18,8 @@ import numpy as _np ...@@ -18,7 +18,8 @@ import numpy as _np
from collections import OrderedDict as _dict from collections import OrderedDict as _dict
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
_logger = _logging.getLogger(__name__) # _logger = _logging.getLogger(__name__)
_logger = _logging.getLogger('onnx2fluid')
ONNX_INT_MAX = 2**63 - 1 ONNX_INT_MAX = 2**63 - 1
FLUID_INT_MAX = 2**31 - 1 # FLUID_INT_MAX = 2**31 - 1 #
...@@ -58,8 +59,8 @@ DEFAULT_OP_MAPPING = { ...@@ -58,8 +59,8 @@ DEFAULT_OP_MAPPING = {
'Ceil': ['ceil', ['X'], ['Out']], 'Ceil': ['ceil', ['X'], ['Out']],
'Clip': 'Clip':
['clip', ['X'], ['Out'], dict(), dict( ['clip', ['X'], ['Out'], dict(), dict(
min=(_np.array([255, 255, 127, 255], dtype=_np.uint8).view(_np.float32)), min=(_np.asarray([255, 255, 127, 255], dtype=_np.uint8).view(_np.float32)),
max=(_np.array([255, 255, 127, 127], dtype=_np.uint8).view(_np.float32)), max=(_np.asarray([255, 255, 127, 127], dtype=_np.uint8).view(_np.float32)),
)], )],
'Cos': ['cos', ['X'], ['Out']], 'Cos': ['cos', ['X'], ['Out']],
'Elu': ['elu', ['X'], ['Out'], dict(), dict(alpha=1.)], 'Elu': ['elu', ['X'], ['Out'], dict(), dict(alpha=1.)],
...@@ -449,7 +450,7 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name): ...@@ -449,7 +450,7 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name):
# I/O # I/O
var_x, = inputs var_x, = inputs
var_y, var_indices, = (outputs + [''] * 1)[:2] var_y, var_indices, = (outputs + [''] * 1)[:2]
assert name and var_x and var_y assert name and all(inputs) and var_y
# interpretation # interpretation
assert attrs.get( assert attrs.get(
...@@ -512,7 +513,7 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, name): ...@@ -512,7 +513,7 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, name):
# I/O # I/O
var_x, var_rois, = inputs var_x, var_rois, = inputs
var_y, = outputs var_y, = outputs
assert name and var_x and var_rois and var_y assert name and all(inputs) and all(outputs)
# interpretation # interpretation
spatial_scale = attrs['spatial_scale'] # required spatial_scale = attrs['spatial_scale'] # required
...@@ -565,7 +566,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''): ...@@ -565,7 +566,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
# I/O # I/O
var_x, var_scales, = inputs var_x, var_scales, = inputs
var_y, = outputs var_y, = outputs
assert var_x and var_scales and var_y assert all(inputs) and all(outputs)
# interpretation # interpretation
# output shape # output shape
...@@ -701,7 +702,7 @@ def BatchNormalization(prog, ...@@ -701,7 +702,7 @@ def BatchNormalization(prog,
var_x, var_scale, var_b, var_mean, var_var, = inputs var_x, var_scale, var_b, var_mean, var_var, = inputs
var_y, var_mean_, var_var_, var_saved_mean, var_saved_variance, = ( var_y, var_mean_, var_var_, var_saved_mean, var_saved_variance, = (
outputs + [''] * 4)[:5] outputs + [''] * 4)[:5]
assert var_x and var_scale and var_b and var_mean and var_var and var_y assert all(inputs) and var_y
assert var_saved_mean or name assert var_saved_mean or name
assert var_saved_variance or name assert var_saved_variance or name
var_saved_mean = var_saved_mean or (name + '.saved_mean') # dummy output var_saved_mean = var_saved_mean or (name + '.saved_mean') # dummy output
...@@ -879,7 +880,8 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -879,7 +880,8 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
_logger.warning( _logger.warning(
'in op (Constant -> %s): ' 'in op (Constant -> %s): '
'attribute "shape" of %s not inferred, ' 'attribute "shape" of %s not inferred, '
'using value as 1-D tensor may lead to fails', outputs, var_output) 'using value as 1-D tensor may lead to failures', outputs,
var_output)
# generation # generation
if not shape or value.size == 1: # scalar or 1-size if not shape or value.size == 1: # scalar or 1-size
...@@ -929,7 +931,7 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -929,7 +931,7 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
'given shape is neither const value nor deductible from output, ' 'given shape is neither const value nor deductible from output, '
'this is not supported') 'this is not supported')
attrs = attrs.copy() attrs = attrs.copy()
attrs.setdefault('value', _np.array(0, dtype=_np.float32)) attrs.setdefault('value', _np.asarray(0, dtype=_np.float32))
attrs.update({'shape': shape}) # pass const attrs.update({'shape': shape}) # pass const
prog.Code('# shape: {} = {} # const as literal'.format(var_shape, shape)) prog.Code('# shape: {} = {} # const as literal'.format(var_shape, shape))
...@@ -959,7 +961,7 @@ def Conv(prog, ...@@ -959,7 +961,7 @@ def Conv(prog,
# I/O # I/O
var_x, var_w, var_b, = (inputs + [''] * 1)[:3] var_x, var_w, var_b, = (inputs + [''] * 1)[:3]
var_y, = outputs var_y, = outputs
assert name and var_x and var_w and var_y assert name and var_x and var_w and all(outputs)
# interpretation # interpretation
assert attrs.get( assert attrs.get(
...@@ -1066,7 +1068,7 @@ def ConvTranspose(prog, ...@@ -1066,7 +1068,7 @@ def ConvTranspose(prog,
# I/O # I/O
var_x, var_w, var_b, = (inputs + [''] * 1)[:3] var_x, var_w, var_b, = (inputs + [''] * 1)[:3]
var_y, = outputs var_y, = outputs
assert name and var_x and var_w and var_y assert name and var_x and var_w and all(outputs)
# interpretation # interpretation
assert attrs.get( assert attrs.get(
...@@ -1174,7 +1176,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1174,7 +1176,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
# due to fluid fc don't support transposed weight, we use matmul + ew_add # due to fluid fc don't support transposed weight, we use matmul + ew_add
var_a, var_b, var_c, = inputs var_a, var_b, var_c, = inputs
var_y, = outputs var_y, = outputs
assert name and var_a and var_b and var_c and var_y assert name and all(inputs) and all(outputs)
alpha = attrs.get('alpha', 1.) # optional alpha = attrs.get('alpha', 1.) # optional
beta = attrs.get('beta', 1.) # optional beta = attrs.get('beta', 1.) # optional
...@@ -1794,7 +1796,7 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): ...@@ -1794,7 +1796,7 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
mode) mode)
fluid_op = 'pad' fluid_op = 'pad'
pad2d_attr = '' pad2d_attr = ''
paddings = _np.array(pads).reshape( paddings = _np.asarray(pads).reshape(
(-1, 2)).transpose().flatten().tolist() # SSEE -> SESE (-1, 2)).transpose().flatten().tolist() # SSEE -> SESE
od_attrs['paddings'] = paddings od_attrs['paddings'] = paddings
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
...@@ -1838,7 +1840,7 @@ def PRelu(prog, ...@@ -1838,7 +1840,7 @@ def PRelu(prog,
# I/O # I/O
var_x, var_slope, = inputs var_x, var_slope, = inputs
var_y, = outputs var_y, = outputs
assert name and var_x and var_slope and var_y assert name and all(inputs) and all(outputs)
# interpretation # interpretation
mode = 'channel' mode = 'channel'
...@@ -1904,7 +1906,7 @@ def Reshape(prog, inputs, outputs, attrs_, value_infos, name, *args, **kwargs): ...@@ -1904,7 +1906,7 @@ def Reshape(prog, inputs, outputs, attrs_, value_infos, name, *args, **kwargs):
# I/O # I/O
var_data, var_shape, = inputs var_data, var_shape, = inputs
var_reshaped, = outputs var_reshaped, = outputs
assert name and var_data and var_shape and var_reshaped assert name and all(inputs) and all(outputs)
# interpretation # interpretation
shape = _const_weight_or_none(value_infos, var_shape) shape = _const_weight_or_none(value_infos, var_shape)
...@@ -2015,7 +2017,7 @@ def Shape(prog, inputs, outputs, attrs_, name, **kwargs): ...@@ -2015,7 +2017,7 @@ def Shape(prog, inputs, outputs, attrs_, name, **kwargs):
# I/O # I/O
var_data, = inputs var_data, = inputs
var_shape, = outputs var_shape, = outputs
assert name and var_data and var_shape assert name and all(inputs) and all(outputs)
# interpretation # interpretation
fluid_op = 'shape' fluid_op = 'shape'
...@@ -2189,7 +2191,7 @@ def Tile(prog, inputs, outputs, attrs_, value_infos, name='', *args, **kwargs): ...@@ -2189,7 +2191,7 @@ def Tile(prog, inputs, outputs, attrs_, value_infos, name='', *args, **kwargs):
# I/O # I/O
var_input, var_repeats, = inputs var_input, var_repeats, = inputs
var_output, = outputs var_output, = outputs
assert var_input and var_repeats and var_output assert all(inputs) and all(outputs)
# interpretation # interpretation
repeats = _const_weight_or_none(value_infos, var_repeats) repeats = _const_weight_or_none(value_infos, var_repeats)
...@@ -2227,7 +2229,7 @@ def Transpose(prog, inputs, outputs, attrs, name, *args, **kwargs): ...@@ -2227,7 +2229,7 @@ def Transpose(prog, inputs, outputs, attrs, name, *args, **kwargs):
# I/O # I/O
var_data, = inputs var_data, = inputs
var_transposed, = outputs var_transposed, = outputs
assert name and var_data and var_transposed assert name and all(inputs) and all(outputs)
# interpretation # interpretation
fluid_op = 'transpose' fluid_op = 'transpose'
......
...@@ -138,10 +138,10 @@ def export_onnx_with_validation( ...@@ -138,10 +138,10 @@ def export_onnx_with_validation(
outputs = torch.onnx.export(model, outputs = torch.onnx.export(model,
torch_inputs, torch_inputs,
export_basepath + '.onnx', export_basepath + '.onnx',
input_names=(None if input_names is None else input_names=(input_names
flatten_list(input_names)), and flatten_list(input_names)),
output_names=(None if output_names is None else output_names=(output_names
flatten_list(output_names)), and flatten_list(output_names)),
*args, *args,
**kwargs) **kwargs)
if outputs is None: # WORKAROUND: for torch.onnx if outputs is None: # WORKAROUND: for torch.onnx
......
...@@ -90,7 +90,7 @@ def validate(fluid_model_filename, ...@@ -90,7 +90,7 @@ def validate(fluid_model_filename,
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
logger = logging.getLogger('validate') logger = logging.getLogger('onnx2fluid')
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
...@@ -126,6 +126,7 @@ def validate(fluid_model_filename, ...@@ -126,6 +126,7 @@ def validate(fluid_model_filename,
logger.info('import passed') logger.info('import passed')
prog = fluid.default_main_program() prog = fluid.default_main_program()
prog = prog.clone(for_test=True) # force inference mode
fluid.io.load_persistables(executor=exe, fluid.io.load_persistables(executor=exe,
dirname=fluid_model_dir, dirname=fluid_model_dir,
main_program=prog) main_program=prog)
...@@ -160,8 +161,7 @@ def validate(fluid_model_filename, ...@@ -160,8 +161,7 @@ def validate(fluid_model_filename,
logger.info('with %d inputs and %d outputs', len(input_data), logger.info('with %d inputs and %d outputs', len(input_data),
len(output_data)) len(output_data))
elif save_inference_model: elif save_inference_model:
assert inference_input_names is not None, ( assert inference_input_names, 'input names required for type-shape inference'
'input names required for type-shape inference')
input_names = inference_input_names input_names = inference_input_names
logger.info('using input names: %s', ', '.join(input_names)) logger.info('using input names: %s', ', '.join(input_names))
...@@ -185,6 +185,7 @@ def validate(fluid_model_filename, ...@@ -185,6 +185,7 @@ def validate(fluid_model_filename,
# execute # execute
outputs = exe.run(prog, feed=input_data, outputs = exe.run(prog, feed=input_data,
fetch_list=out_names) # out_names can be vars fetch_list=out_names) # out_names can be vars
exe.close()
logger.info('execution passed') logger.info('execution passed')
# validate # validate
...@@ -264,7 +265,7 @@ def main(): ...@@ -264,7 +265,7 @@ def main():
atol, rtol = args.atol, args.rtol atol, rtol = args.atol, args.rtol
save_inference_model = args.infer_inputs is not None save_inference_model = args.infer_inputs is not None
inference_input_names = args.infer_inputs.split( inference_input_names = args.infer_inputs.split(
',') if args.infer_inputs else None ',') if save_inference_model else None
validate(fluid_model_filename, validate(fluid_model_filename,
golden_data_filename=golden_data_filename, golden_data_filename=golden_data_filename,
......
...@@ -372,7 +372,7 @@ class Writer(object): ...@@ -372,7 +372,7 @@ class Writer(object):
prog.Code('# input {}'.format(name)) prog.Code('# input {}'.format(name))
prog.Code(( prog.Code((
'{} = layers.data(name={}, shape={}, dtype={}, ' '{} = layers.data(name={}, shape={}, dtype={}, '
'append_batch_size={})' # , stop_gradient=True 'append_batch_size={}, lod_level=1)' # , stop_gradient=True
).format( ).format(
name, name,
repr(name), repr(name),
...@@ -427,20 +427,28 @@ class Writer(object): ...@@ -427,20 +427,28 @@ class Writer(object):
assert lod is None or isinstance(lod, assert lod is None or isinstance(lod,
list), 'lod should be None or list' list), 'lod should be None or list'
if lod is None: lod = lod or []
lod = [0]
tensor_desc = framework_pb2.VarType.TensorDesc() tensor_desc = framework_pb2.VarType.TensorDesc()
tensor_desc.data_type = Program.Dtype(weight.dtype) tensor_desc.data_type = Program.Dtype(weight.dtype)
tensor_desc.dims.extend(weight.shape) tensor_desc.dims.extend(weight.shape)
fp = open(filename, 'wb') fp = open(filename, 'wb')
np.array([0], dtype=np.int32).tofile(fp) # version
np.array(lod, dtype=np.int64).tofile(fp) # LOD level # lod_tensor.cc: SerializeToStream
np.array([0], dtype=np.int32).tofile(fp) # tensor version np.asarray([0], dtype=np.uint32).tofile(fp) # version
np.array([tensor_desc.ByteSize()], dtype=np.int32).tofile(fp) np.asarray([len(lod)], dtype=np.int64).tofile(fp) # LOD levels
for level in lod:
np.asarray([len(level)], dtype=np.int64).tofile(fp) # level size
np.asarray(level, dtype=np.uint64).tofile(fp) # LOD: size_t
# tensor_util.cc: TensorToStream
np.asarray([0], dtype=np.uint32).tofile(fp) # tensor version
np.asarray([tensor_desc.ByteSize()], dtype=np.int32).tofile(fp)
fp.write(tensor_desc.SerializeToString()) fp.write(tensor_desc.SerializeToString())
weight.tofile(fp) weight.tofile(fp)
fp.close() fp.close()
@staticmethod @staticmethod
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册