Commit 021e765a authored by 李寅

Merge branch 'tuning-default' into 'master'

Fix a converter bug and add winograd parameter selection.

See merge request !604
@@ -112,7 +112,7 @@ def main(unused_args):
         option = cvt.ConverterOption(FLAGS.transformers.split(','))
     else:
         option = cvt.ConverterOption()
-    option.winograd_enabled = bool(FLAGS.winograd)
+    option.winograd = FLAGS.winograd

     input_node_names = FLAGS.input_node.split(',')
     input_node_shapes = FLAGS.input_shape.split(':')
@@ -146,6 +146,17 @@ def main(unused_args):
     print("Transform model to one that can better run on device")
     if FLAGS.runtime == 'cpu+gpu':
         cpu_graph_def = copy.deepcopy(output_graph_def)
+
+        option.device = cvt.DeviceType.GPU.value
+        option.data_type = parse_data_type(
+            FLAGS.data_type, cvt.DeviceType.GPU.value)
+        mace_gpu_transformer = transformer.Transformer(
+            option, output_graph_def)
+        output_graph_def = mace_gpu_transformer.run()
+        print "start optimize gpu memory."
+        memory_optimizer.optimize_gpu_memory(output_graph_def)
+        print "GPU memory optimization done."
+
         option.device = cvt.DeviceType.CPU.value
         option.data_type = parse_data_type(
             FLAGS.data_type, cvt.DeviceType.CPU.value)
@@ -157,17 +168,6 @@ def main(unused_args):
         memory_optimizer.optimize_cpu_memory(cpu_graph_def)
         print "CPU memory optimization done."

-        option.device = cvt.DeviceType.GPU.value
-        option.data_type = parse_data_type(
-            FLAGS.data_type, cvt.DeviceType.GPU.value)
-        option.enable_transpose_filters()
-        mace_gpu_transformer = transformer.Transformer(
-            option, output_graph_def)
-        output_gpu_graph_def = mace_gpu_transformer.run()
-        print "start optimize gpu memory."
-        memory_optimizer.optimize_gpu_memory(output_gpu_graph_def)
-        print "GPU memory optimization done."
-
         print "Merge cpu and gpu ops together"
         output_graph_def.op.extend(cpu_graph_def.op)
         output_graph_def.mem_arena.mem_block.extend(
@@ -261,11 +261,9 @@ def parse_args():
         help="model tag for generated function and namespace")
     parser.add_argument(
         "--winograd",
-        type=str2bool,
-        nargs='?',
-        const=False,
-        default=False,
-        help="open winograd convolution or not")
+        type=int,
+        default=0,
+        help="Which version of winograd convolution to use. [2 | 4]")
     parser.add_argument(
         "--dsp_mode", type=int, default=0, help="dsp run mode, defalut=0")
     parser.add_argument(
...
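For context, the --winograd flag changes here from a boolean switch to an integer block-size selector. A minimal sketch of the intended semantics, assuming only what the hunks above show (0 disables winograd; 2 and 4 select the block size); the parser below is illustrative, not the project's actual code:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--winograd",
    type=int,
    default=0,
    help="Which version of winograd convolution to use. [2 | 4]")
flags = parser.parse_args(["--winograd", "4"])

# The converter now stores the raw integer on the option object ...
option_winograd = flags.winograd
# ... and downstream code treats 0 as "disabled" via truthiness, which is
# why `if not self._option.winograd` replaces the old winograd_enabled check.
print(option_winograd, bool(option_winograd))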
@@ -158,6 +158,7 @@ class MaceKeyword(object):
     mace_shrink_axis_mask_str = 'shrink_axis_mask'
     mace_transpose_a_str = 'transpose_a'
     mace_transpose_b_str = 'transpose_b'
+    mace_op_data_type_str = 'T'


 class TransformerRule(Enum):
@@ -182,6 +183,7 @@ class TransformerRule(Enum):
     SORT_BY_EXECUTION = 19
     ADD_IN_OUT_TENSOR_INFO = 20
     ADD_MACE_INPUT_AND_OUTPUT_NODES = 21
+    UPDATE_FLOAT_OP_DATA_TYPE = 22


 class ConverterInterface(object):
@@ -226,7 +228,7 @@ class ConverterOption(object):
         self._output_nodes = {}
         self._data_type = mace_pb2.DT_FLOAT
         self._device = DeviceType.CPU.value
-        self._winograd_enabled = False
+        self._winograd = 0
         if transformers:
             self._transformer_option = [TransformerRule[transformer]
                                         for transformer in transformers]
@@ -251,6 +253,7 @@ class ConverterOption(object):
                 TransformerRule.RESHAPE_FC_WEIGHT,
                 TransformerRule.TRANSFORM_BUFFER_IMAGE,
                 TransformerRule.ADD_DEVICE,
+                TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE,
                 TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES,
                 TransformerRule.SORT_BY_EXECUTION,
             ]
@@ -272,8 +275,8 @@ class ConverterOption(object):
         return self._device

     @property
-    def winograd_enabled(self):
-        return self._winograd_enabled
+    def winograd(self):
+        return self._winograd

     @property
     def transformer_option(self):
@@ -303,9 +306,9 @@ class ConverterOption(object):
     def device(self, device):
         self._device = device

-    @winograd_enabled.setter
-    def winograd_enabled(self, winograd_enabled):
-        self._winograd_enabled = winograd_enabled
+    @winograd.setter
+    def winograd(self, winograd):
+        self._winograd = winograd

     def disable_transpose_filters(self):
         if TransformerRule.TRANSPOSE_FILTERS in self._transformer_option:
...
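The winograd option on ConverterOption is now an integer block size rather than a boolean. A minimal sketch of how the renamed property behaves, using a stripped-down stand-in for the real class (which carries many more fields):

class _OptionSketch(object):
    """Stand-in for ConverterOption, keeping only the winograd field."""

    def __init__(self):
        self._winograd = 0  # 0 = winograd disabled, 2/4 = block size

    @property
    def winograd(self):
        return self._winograd

    @winograd.setter
    def winograd(self, winograd):
        self._winograd = winograd


opt = _OptionSketch()
opt.winograd = 4
assert opt.winograd == 4
assert bool(opt.winograd)      # truthy, so winograd transforms run
opt.winograd = 0
assert not opt.winograd        # falsy, so they are skipped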
@@ -31,7 +31,6 @@ from mace.python.tools.converter_tool.base_converter import TransformerRule
 from mace.python.tools.convert_util import mace_check

 OPENCL_IMAGE_MAX_SIZE = 16384
-DEFAULT_GPU_WINO_BLK_SIZE = 4


 class OpenCLBufferType(enum.Enum):
@@ -53,6 +52,8 @@ class Transformer(base_converter.ConverterInterface):
     """

     def __init__(self, option, model):
+        # Dependencies
+        # (TRANSFORM_MATMUL_TO_FC, TRANSFORM_GLOBAL_CONV_TO_FC) -> RESHAPE_FC_WEIGHT  # noqa
         self._registered_transformers = {
             TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op,
             TransformerRule.TRANSFORM_GLOBAL_POOLING:
@@ -83,6 +84,8 @@ class Transformer(base_converter.ConverterInterface):
                 self.transform_buffer_image,
             TransformerRule.ADD_DEVICE:
                 self.add_device,
+            TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE:
+                self.update_float_op_data_type,
             TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES:
                 self.add_mace_input_and_output_nodes,
             TransformerRule.SORT_BY_EXECUTION: self.sort_by_execution,
@@ -90,7 +93,7 @@ class Transformer(base_converter.ConverterInterface):
         self._option = option
         self._model = model
-        self._gpu_wino_blk = DEFAULT_GPU_WINO_BLK_SIZE
+        self._gpu_wino_blk = self._option.winograd
         self._ops = {}
         self._consts = {}
@@ -442,7 +445,7 @@ class Transformer(base_converter.ConverterInterface):
         return filter_height, filter_width, in_channels, out_channels

     def check_if_gpu_use_winograd_conv(self, op):
-        if not self._option.winograd_enabled:
+        if not self._option.winograd:
             return False
         if op.type != MaceOp.Conv2D.name:
             return False
@@ -464,7 +467,6 @@ class Transformer(base_converter.ConverterInterface):
         if filter_height != 3 or filter_width != 3 or strides[0] > 1 \
                 or strides[1] > 1 or dilations[0] > 1 or dilations[1] > 1:
             return False
-        self._gpu_wino_blk = DEFAULT_GPU_WINO_BLK_SIZE
         block_size = self._gpu_wino_blk
         blk_sqr = (block_size + 2) * (block_size + 2)
         width =\
@@ -479,9 +481,9 @@ class Transformer(base_converter.ConverterInterface):
         width = \
             batch * ((out_height + block_size - 1) / block_size) * \
             ((out_width + block_size - 1) / block_size)
-        return (blk_sqr * in_channels <= OPENCL_IMAGE_MAX_SIZE) and \
-            (blk_sqr * out_channels <= OPENCL_IMAGE_MAX_SIZE) and \
-            (width <= OPENCL_IMAGE_MAX_SIZE)
+        return (blk_sqr * in_channels < OPENCL_IMAGE_MAX_SIZE) and \
+            (blk_sqr * out_channels < OPENCL_IMAGE_MAX_SIZE) and \
+            (width < OPENCL_IMAGE_MAX_SIZE)

     def transform_gpu_winograd(self):
         """Only gpu needs winograd transform."""
@@ -577,17 +579,6 @@ class Transformer(base_converter.ConverterInterface):
             blk_size_arg.i = block_size
             ConverterUtil.add_data_format_arg(iwt_op, data_format)

-            filter_data = np.array(filter.float_data).reshape(
-                filter.dims)
-            weight_tensor_value = filter_data
-            if filter_format == FilterFormat.HWIO:
-                weight_tensor_value = filter_data.transpose(3, 2, 0, 1)
-            elif filter_format == FilterFormat.HWOI:
-                weight_tensor_value = filter_data.transpose(2, 3, 0, 1)
-            filter.float_data[:] = weight_tensor_value.flat[:]
-            filter.dims[:] = weight_tensor_value.shape[:]
-
             self.safe_remove_node(op, iwt_op)

         return False
@@ -608,12 +599,13 @@ class Transformer(base_converter.ConverterInterface):
     def fold_biasadd(self):
         net = self._model
         for op in net.op:
-            if ((op.type == MaceOp.Conv2D.name
-                 or op.type == MaceOp.Deconv2D.name
-                 or op.type == MaceOp.DepthwiseConv2d.name
-                 or op.type == MaceOp.FullyConnected.name
-                 or op.type == MaceOp.WinogradInverseTransform.name)
-                and len(op.input) == 2) \
+            if (((op.type == MaceOp.Conv2D.name
+                  or op.type == MaceOp.Deconv2D.name
+                  or op.type == MaceOp.DepthwiseConv2d.name
+                  or op.type == MaceOp.FullyConnected.name)
+                 and len(op.input) == 2)
+                or (op.type == MaceOp.WinogradInverseTransform.name
+                    and len(op.input) == 1)) \
                     and len(self._consumers.get(op.output[0], [])) == 1:
                 consumer_op = self._consumers[op.output[0]][0]
                 if consumer_op.type == MaceOp.BiasAdd.name:
@@ -893,25 +885,24 @@ class Transformer(base_converter.ConverterInterface):
             if op.type == MaceOp.Conv2D.name \
                     or op.type == MaceOp.Deconv2D.name \
                     or op.type == MaceOp.DepthwiseConv2d.name:
-                if ConverterUtil.get_arg(
-                        op, MaceKeyword.mace_winograd_filter_transformed) \
-                        is None:
-                    filter = self._consts[op.input[1]]
-                    filter_data = np.array(filter.float_data).reshape(
-                        filter.dims)
-                    if op.type == MaceOp.Deconv2D.name:
-                        filter_data = filter_data.transpose(2, 3, 0, 1)
-                    else:
-                        filter_data = filter_data.transpose(3, 2, 0, 1)
-                    filter.float_data[:] = filter_data.flat
-                    filter.dims[:] = filter_data.shape
-            if op.type == MaceOp.FullyConnected.name:
-                weight = self._consts[op.input[1]]
-                weight_data = np.array(weight.float_data).reshape(
-                    weight.dims)
-                weight_data = weight_data.transpose(1, 0)
-                weight.float_data[:] = weight_data.flat
-                weight.dims[:] = weight_data.shape
+                filter = self._consts[op.input[1]]
+                filter_data = np.array(filter.float_data).reshape(
+                    filter.dims)
+                if op.type == MaceOp.Deconv2D.name:
+                    filter_data = filter_data.transpose(2, 3, 0, 1)
+                else:
+                    filter_data = filter_data.transpose(3, 2, 0, 1)
+                filter.float_data[:] = filter_data.flat
+                filter.dims[:] = filter_data.shape
+            if (op.type == MaceOp.MatMul.name and
+                    ConverterUtil.get_arg(op, MaceKeyword.mace_winograd_filter_transformed) is not None):  # noqa
+                filter = self._consts[op.input[0]]
+                filter_data = np.array(filter.float_data).reshape(
+                    filter.dims)
+                filter_data = filter_data.transpose(3, 2, 0, 1)
+                filter.float_data[:] = filter_data.flat
+                filter.dims[:] = filter_data.shape

         self.set_filter_format(FilterFormat.OIHW)
         return False
@@ -1104,6 +1095,11 @@ class Transformer(base_converter.ConverterInterface):
                 weight = self._consts[op.input[1]]
                 if len(weight.dims) == 2:
                     op.type = MaceOp.FullyConnected.name
+                    weight_data = np.array(weight.float_data).reshape(
+                        weight.dims)
+                    weight_data = weight_data.transpose(1, 0)
+                    weight.float_data[:] = weight_data.flat
+                    weight.dims[:] = weight_data.shape

         return False
@@ -1156,6 +1152,22 @@ class Transformer(base_converter.ConverterInterface):

         return False

+    def update_float_op_data_type(self):
+        print("update op with float data type")
+        net = self._model
+        for op in net.op:
+            data_type_arg = ConverterUtil.get_arg(
+                op, MaceKeyword.mace_op_data_type_str)
+            if not data_type_arg:
+                data_type_arg = op.arg.add()
+                data_type_arg.name = MaceKeyword.mace_op_data_type_str
+                data_type_arg.i = self._option.data_type
+            elif data_type_arg.i != self._option.data_type \
+                    and data_type_arg.i == mace_pb2.DT_FLOAT:
+                data_type_arg.i = self._option.data_type
+
+        return False
+
     def sort_dfs(self, op, visited, sorted_nodes):
         visited.update([op.name])
         if len(op.input) > 0:
...
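For reference, the winograd feasibility check above, reduced to a standalone function: only the arithmetic and the OPENCL_IMAGE_MAX_SIZE constant come from the diff, and the shape values in the calls are illustrative. The block size now comes from the option instead of the fixed DEFAULT_GPU_WINO_BLK_SIZE.

OPENCL_IMAGE_MAX_SIZE = 16384

def fits_winograd(batch, out_height, out_width, in_channels, out_channels,
                  block_size):
    # Each winograd tile covers (block_size + 2) x (block_size + 2) elements.
    blk_sqr = (block_size + 2) * (block_size + 2)
    # Number of tiles laid out along one OpenCL image dimension.
    width = batch * ((out_height + block_size - 1) // block_size) * \
        ((out_width + block_size - 1) // block_size)
    # Every image dimension of the transformed tensors must stay below the
    # OpenCL limit; this merge request tightens the bound from <= to <.
    return (blk_sqr * in_channels < OPENCL_IMAGE_MAX_SIZE and
            blk_sqr * out_channels < OPENCL_IMAGE_MAX_SIZE and
            width < OPENCL_IMAGE_MAX_SIZE)

print(fits_winograd(1, 112, 112, 64, 64, block_size=2))       # True
print(fits_winograd(1, 112, 112, 64, 64, block_size=4))       # True
print(fits_winograd(1, 1024, 1024, 512, 512, block_size=4))   # False: over the limit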
@@ -112,6 +112,8 @@ DSPDataTypeStrs = [
 DSPDataType = Enum('DSPDataType', [(ele, ele) for ele in DSPDataTypeStrs],
                    type=str)

+WinogradParameters = [0, 2, 4]
+

 class DefaultValues(object):
     omp_num_threads = -1,
@@ -408,6 +410,12 @@ def format_model_config(flags):
            else:
                subgraph[YAMLKeyword.validation_inputs_data] = \
                    validation_inputs_data
+
+            input_ranges = subgraph.get(
+                YAMLKeyword.input_ranges, [])
+            if not isinstance(input_ranges, list):
+                subgraph[YAMLKeyword.input_ranges] = [input_ranges]
+            else:
+                subgraph[YAMLKeyword.input_ranges] = input_ranges

         for key in [YAMLKeyword.limit_opencl_kernel_time,
                     YAMLKeyword.nnlib_graph_mode,
@@ -417,6 +425,12 @@ def format_model_config(flags):
            if value == "":
                model_config[key] = 0

+        mace_check(model_config[YAMLKeyword.winograd] in WinogradParameters,
+                   ModuleName.YAML_CONFIG,
+                   "'winograd' parameters must be in "
+                   + str(WinogradParameters) +
+                   ". 0 for disable winograd convolution")
+
         weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "")
         model_config[YAMLKeyword.weight_file_path] = weight_file_path
@@ -511,7 +525,7 @@ def print_configuration(flags, configs):
                  configs[YAMLKeyword.embed_model_data]])
     data.append([YAMLKeyword.linkshared,
                  configs[YAMLKeyword.linkshared]])
-    data.append(["Tuning", flags.tuning])
+    data.append(["Tuning", flags.disable_tuning])
     MaceLogger.summary(StringFormatter.table(header, data, title))
@@ -736,7 +750,7 @@ def build_specific_lib(target_abi, target_soc, serial_num,
                subgraphs[0][YAMLKeyword.input_tensors],
                subgraphs[0][YAMLKeyword.input_shapes],
                subgraphs[0][YAMLKeyword.validation_inputs_data],
-               input_ranges=subgraphs[0].get(YAMLKeyword.input_ranges, None))
+               input_ranges=subgraphs[0][YAMLKeyword.input_ranges])
            device_type = parse_device_type(RuntimeType.gpu)
            sh_commands.tuning_run(
@@ -869,8 +883,8 @@ def build_library(flags):
     convert_model(configs)

-    generate_library(configs, flags.tuning,
-                     flags.enable_openmp, flags.address_sanitizer)
+    generate_library(configs, flags.disable_tuning,
+                     flags.disable_openmp, flags.address_sanitizer)

     print_library_summary(configs)
@@ -980,7 +994,7 @@ def run_specific_target(flags, configs, target_abi,
                subgraphs[0][YAMLKeyword.input_tensors],
                subgraphs[0][YAMLKeyword.input_shapes],
                subgraphs[0][YAMLKeyword.validation_inputs_data],
-               input_ranges=subgraphs[0].get(YAMLKeyword.input_ranges, None))
+               input_ranges=subgraphs[0][YAMLKeyword.input_ranges])
     runtime_list = []
     if target_abi == ABIType.host:
         runtime_list.extend([RuntimeType.cpu])
@@ -1129,7 +1143,7 @@ def bm_specific_target(flags, configs, target_abi, target_soc, serial_num):
                subgraphs[0][YAMLKeyword.input_tensors],
                subgraphs[0][YAMLKeyword.input_shapes],
                subgraphs[0][YAMLKeyword.validation_inputs_data],
-               input_ranges=subgraphs[0].get(YAMLKeyword.input_ranges, None))
+               input_ranges=subgraphs[0][YAMLKeyword.input_ranges])
     runtime_list = []
     if target_abi == ABIType.host:
         runtime_list.extend([RuntimeType.cpu])
@@ -1262,13 +1276,13 @@ def parse_args():
         help='build model library and test tools')
     build.set_defaults(func=build_library)
     build.add_argument(
-        '--tuning',
-        action="store_true",
-        help="whether tuning the parameters for the GPU of specified SoC.")
+        '--disable_tuning',
+        action="store_false",
+        help="Disable tuning the parameters for the GPU of specified SoC.")
     build.add_argument(
-        "--enable_openmp",
+        "--disable_openmp",
         action="store_false",
-        help="Enable openmp for multiple thread.")
+        help="Disable openmp for multiple thread.")
     run = subparsers.add_parser(
         'run',
         parents=[all_type_parent_parser, run_bm_parent_parser,
...
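With action="store_false", the new build flags default to True and flip to False when passed, so flags.disable_tuning is handed straight to generate_library and printed in the configuration table. A minimal sketch of that behaviour (plain argparse, nothing project-specific):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--disable_tuning',
    action="store_false",
    help="Disable tuning the parameters for the GPU of specified SoC.")
parser.add_argument(
    "--disable_openmp",
    action="store_false",
    help="Disable openmp for multiple thread.")

print(parser.parse_args([]))
# Namespace(disable_openmp=True, disable_tuning=True)   -> tuning and OpenMP stay on
print(parser.parse_args(['--disable_tuning']))
# Namespace(disable_openmp=True, disable_tuning=False)  -> tuning turned off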
@@ -486,7 +486,7 @@ def gen_model_code(model_codegen_dir,
                    input_shapes,
                    dsp_mode,
                    embed_model_data,
-                   fast_conv,
+                   winograd,
                    obfuscate,
                    model_build_type,
                    data_type,
@@ -512,7 +512,7 @@ def gen_model_code(model_codegen_dir,
            "--input_shape=%s" % input_shapes,
            "--dsp_mode=%s" % dsp_mode,
            "--embed_model_data=%s" % embed_model_data,
-           "--winograd=%s" % fast_conv,
+           "--winograd=%s" % winograd,
            "--obfuscate=%s" % obfuscate,
            "--output_dir=%s" % model_codegen_dir,
            "--model_build_type=%s" % model_build_type,
@@ -525,8 +525,8 @@ def gen_random_input(model_output_dir,
                      input_nodes,
                      input_shapes,
                      input_files,
-                     input_file_name="model_input",
-                     input_ranges=None):
+                     input_ranges,
+                     input_file_name="model_input"):
     for input_name in input_nodes:
         formatted_name = common.formatted_file_name(
             input_file_name, input_name)
@@ -534,10 +534,7 @@ def gen_random_input(model_output_dir,
            sh.rm("%s/%s" % (model_output_dir, formatted_name))
     input_nodes_str = ",".join(input_nodes)
     input_shapes_str = ":".join(input_shapes)
-    if input_ranges:
-        input_ranges_str = ":".join(input_ranges)
-    else:
-        input_ranges_str = None
+    input_ranges_str = ":".join(input_ranges)
     generate_input_data("%s/%s" % (model_output_dir, input_file_name),
                         input_nodes_str,
                         input_shapes_str,
...
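These changes move in step with the YAML handling earlier in this merge request: format_model_config now always stores input_ranges as a list, so gen_random_input can take it as a required argument and join it directly instead of special-casing None. A small sketch of that normalization, with made-up range strings:

def normalize_input_ranges(value):
    # Mirrors the new format_model_config logic: always return a list,
    # wrapping a single scalar value if needed.
    if not isinstance(value, list):
        return [value]
    return value

print(":".join(normalize_input_ranges("-1,1")))             # "-1,1"
print(":".join(normalize_input_ranges(["-1,1", "0,255"])))  # "-1,1:0,255"
print(":".join(normalize_input_ranges([])))                 # ""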