提交 1617b83f 编写于 作者: L liuqi

Support build once then run on cpu or gpu.

上级 9716a876
...@@ -40,13 +40,19 @@ SerialNet::SerialNet(const std::shared_ptr<const OperatorRegistry> op_registry, ...@@ -40,13 +40,19 @@ SerialNet::SerialNet(const std::shared_ptr<const OperatorRegistry> op_registry,
MACE_LATENCY_LOGGER(1, "Constructing SerialNet ", net_def->name()); MACE_LATENCY_LOGGER(1, "Constructing SerialNet ", net_def->name());
for (int idx = 0; idx < net_def->op_size(); ++idx) { for (int idx = 0; idx < net_def->op_size(); ++idx) {
const auto &operator_def = net_def->op(idx); const auto &operator_def = net_def->op(idx);
VLOG(3) << "Creating operator " << operator_def.name() << "(" // TODO(liuqi): refactor based on PB
<< operator_def.type() << ")"; const int op_device =
OperatorDef temp_def(operator_def); ArgumentHelper::GetSingleArgument<OperatorDef, int>(
std::unique_ptr<OperatorBase> op( operator_def, "device", -1);
op_registry->CreateOperator(temp_def, ws, type, mode)); if (op_device == type) {
if (op) { VLOG(3) << "Creating operator " << operator_def.name() << "("
operators_.emplace_back(std::move(op)); << operator_def.type() << ")";
OperatorDef temp_def(operator_def);
std::unique_ptr<OperatorBase> op(
op_registry->CreateOperator(temp_def, ws, type, mode));
if (op) {
operators_.emplace_back(std::move(op));
}
} }
} }
} }
......
...@@ -136,7 +136,11 @@ void Workspace::CreateOutputTensorBuffer(const NetDef &net_def, ...@@ -136,7 +136,11 @@ void Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
// As DSP may have different data output type for each op, // As DSP may have different data output type for each op,
// we stick to the same concept. // we stick to the same concept.
for (auto &op : net_def.op()) { for (auto &op : net_def.op()) {
if (!op.mem_id().empty()) { // TODO(liuqi): refactor based on PB
const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>(
op, "device", -1);
if (op_device == device_type && !op.mem_id().empty()) {
const DataType op_dtype = static_cast<DataType>( const DataType op_dtype = static_cast<DataType>(
ArgumentHelper::GetSingleArgument<OperatorDef, int>( ArgumentHelper::GetSingleArgument<OperatorDef, int>(
op, "T", static_cast<int>(DT_FLOAT))); op, "T", static_cast<int>(DT_FLOAT)));
...@@ -150,20 +154,29 @@ void Workspace::CreateOutputTensorBuffer(const NetDef &net_def, ...@@ -150,20 +154,29 @@ void Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
MACE_CHECK(dtype != DataType::DT_INVALID, "data type is invalid."); MACE_CHECK(dtype != DataType::DT_INVALID, "data type is invalid.");
for (auto &mem_block : net_def.mem_arena().mem_block()) { for (auto &mem_block : net_def.mem_arena().mem_block()) {
if (device_type == DeviceType::GPU) { if (device_type == DeviceType::GPU) {
std::unique_ptr<BufferBase> image_buf( // TODO(liuqi): refactor based on PB
new Image({mem_block.x(), mem_block.y()}, dtype)); if (mem_block.mem_id() >= 20000) {
preallocated_allocator_.SetBuffer(mem_block.mem_id(), std::unique_ptr<BufferBase> image_buf(
std::move(image_buf)); new Image({mem_block.x(), mem_block.y()}, dtype));
preallocated_allocator_.SetBuffer(mem_block.mem_id(),
std::move(image_buf));
}
} else { } else {
std::unique_ptr<BufferBase> tensor_buf( if (mem_block.mem_id() < 20000) {
new Buffer(GetDeviceAllocator(device_type), mem_block.x())); std::unique_ptr<BufferBase> tensor_buf(
preallocated_allocator_.SetBuffer(mem_block.mem_id(), new Buffer(GetDeviceAllocator(device_type), mem_block.x()));
std::move(tensor_buf)); preallocated_allocator_.SetBuffer(mem_block.mem_id(),
std::move(tensor_buf));
}
} }
} }
VLOG(3) << "Preallocate buffer to tensors"; VLOG(3) << "Preallocate buffer to tensors";
for (auto &op : net_def.op()) { for (auto &op : net_def.op()) {
if (!op.mem_id().empty()) { // TODO(liuqi): refactor based on PB
const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>(
op, "device", -1);
if (op_device == device_type && !op.mem_id().empty()) {
auto mem_ids = op.mem_id(); auto mem_ids = op.mem_id();
int count = mem_ids.size(); int count = mem_ids.size();
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
......
...@@ -16,6 +16,7 @@ import argparse ...@@ -16,6 +16,7 @@ import argparse
import sys import sys
import hashlib import hashlib
import os.path import os.path
import copy
from mace.proto import mace_pb2 from mace.proto import mace_pb2
from mace.python.tools import tf_dsp_converter_lib from mace.python.tools import tf_dsp_converter_lib
...@@ -25,6 +26,7 @@ from mace.python.tools.converter_tool import base_converter as cvt ...@@ -25,6 +26,7 @@ from mace.python.tools.converter_tool import base_converter as cvt
from mace.python.tools.converter_tool import tensorflow_converter from mace.python.tools.converter_tool import tensorflow_converter
from mace.python.tools.converter_tool import caffe_converter from mace.python.tools.converter_tool import caffe_converter
from mace.python.tools.converter_tool import transformer from mace.python.tools.converter_tool import transformer
from mace.python.tools.convert_util import mace_check
# ./bazel-bin/mace/python/tools/tf_converter --model_file quantized_test.pb \ # ./bazel-bin/mace/python/tools/tf_converter --model_file quantized_test.pb \
...@@ -34,11 +36,14 @@ from mace.python.tools.converter_tool import transformer ...@@ -34,11 +36,14 @@ from mace.python.tools.converter_tool import transformer
FLAGS = None FLAGS = None
data_type_map = {'DT_HALF': mace_pb2.DT_HALF,
'DT_FLOAT': mace_pb2.DT_FLOAT}
device_type_map = {'cpu': mace_pb2.CPU, device_type_map = {'cpu': mace_pb2.CPU,
'gpu': mace_pb2.GPU, 'gpu': mace_pb2.GPU,
'dsp': mace_pb2.HEXAGON} 'dsp': mace_pb2.HEXAGON}
device_data_type_map = {
mace_pb2.CPU: mace_pb2.DT_FLOAT,
mace_pb2.GPU: mace_pb2.DT_HALF,
mace_pb2.HEXAGON: mace_pb2.DT_UINT8
}
def file_checksum(fname): def file_checksum(fname):
...@@ -81,7 +86,7 @@ def main(unused_args): ...@@ -81,7 +86,7 @@ def main(unused_args):
if FLAGS.platform not in ['tensorflow', 'caffe']: if FLAGS.platform not in ['tensorflow', 'caffe']:
print ("platform %s is not supported." % FLAGS.platform) print ("platform %s is not supported." % FLAGS.platform)
sys.exit(-1) sys.exit(-1)
if FLAGS.runtime not in ['cpu', 'gpu', 'dsp']: if FLAGS.runtime not in ['cpu', 'gpu', 'dsp', '']:
print ("runtime %s is not supported." % FLAGS.runtime) print ("runtime %s is not supported." % FLAGS.runtime)
sys.exit(-1) sys.exit(-1)
...@@ -95,8 +100,6 @@ def main(unused_args): ...@@ -95,8 +100,6 @@ def main(unused_args):
sys.exit(-1) sys.exit(-1)
else: else:
option = cvt.ConverterOption() option = cvt.ConverterOption()
option.data_type = data_type_map[FLAGS.data_type]
option.device = device_type_map[FLAGS.runtime]
option.winograd_enabled = bool(FLAGS.winograd) option.winograd_enabled = bool(FLAGS.winograd)
input_node_names = FLAGS.input_node.split(',') input_node_names = FLAGS.input_node.split(',')
...@@ -117,8 +120,8 @@ def main(unused_args): ...@@ -117,8 +120,8 @@ def main(unused_args):
print("Convert model to mace model.") print("Convert model to mace model.")
if FLAGS.platform == 'tensorflow': if FLAGS.platform == 'tensorflow':
converter = tensorflow_converter.TensorflowConverter(option, converter = tensorflow_converter.TensorflowConverter(
FLAGS.model_file) # noqa option, FLAGS.model_file)
elif FLAGS.platform == 'caffe': elif FLAGS.platform == 'caffe':
converter = caffe_converter.CaffeConverter(option, converter = caffe_converter.CaffeConverter(option,
FLAGS.model_file, FLAGS.model_file,
...@@ -126,16 +129,49 @@ def main(unused_args): ...@@ -126,16 +129,49 @@ def main(unused_args):
output_graph_def = converter.run() output_graph_def = converter.run()
print("Transform model to one that can better run on device.") print("Transform model to one that can better run on device.")
# TODO(liuqi/liyin): transform gpu/cpu and merge their ops if not FLAGS.runtime:
mace_transformer = transformer.Transformer(option, output_graph_def) cpu_graph_def = copy.deepcopy(output_graph_def)
output_graph_def = mace_transformer.run() option.device = mace_pb2.CPU
option.data_type = device_data_type_map[mace_pb2.CPU]
option.disable_transpose_filters()
mace_cpu_transformer = transformer.Transformer(
option, cpu_graph_def)
cpu_graph_def = mace_cpu_transformer.run()
print "start optimize cpu memory."
memory_optimizer.optimize_cpu_memory(cpu_graph_def)
print "CPU memory optimization done."
print "start optimize memory." option.device = mace_pb2.GPU
if FLAGS.runtime == 'gpu': option.data_type = device_data_type_map[mace_pb2.GPU]
memory_optimizer.optimize_gpu_memory(output_graph_def) option.enable_transpose_filters()
elif FLAGS.runtime == 'cpu': mace_gpu_transformer = transformer.Transformer(
memory_optimizer.optimize_cpu_memory(output_graph_def) option, output_graph_def)
print "Memory optimization done." output_gpu_graph_def = mace_gpu_transformer.run()
print "start optimize gpu memory."
memory_optimizer.optimize_gpu_memory(output_gpu_graph_def)
print "GPU memory optimization done."
print "Merge cpu and gpu ops together"
output_graph_def.op.extend(cpu_graph_def.op)
output_graph_def.mem_arena.mem_block.extend(
cpu_graph_def.mem_arena.mem_block)
print "Merge done"
else:
option.device = device_type_map[FLAGS.runtime]
option.data_type = device_data_type_map[option.device]
mace_transformer = transformer.Transformer(
option, output_graph_def)
output_graph_def = mace_transformer.run()
print "start optimize memory."
if FLAGS.runtime == 'gpu':
memory_optimizer.optimize_gpu_memory(output_graph_def)
elif FLAGS.runtime == 'cpu':
memory_optimizer.optimize_cpu_memory(output_graph_def)
else:
mace_check(False, "runtime only support [gpu|cpu|dsp]")
print "Memory optimization done."
if FLAGS.output_type == 'source': if FLAGS.output_type == 'source':
source_converter_lib.convert_to_source( source_converter_lib.convert_to_source(
...@@ -188,7 +224,7 @@ def parse_args(): ...@@ -188,7 +224,7 @@ def parse_args():
default="", default="",
help="File to save the output graph to.") help="File to save the output graph to.")
parser.add_argument( parser.add_argument(
"--runtime", type=str, default="cpu", help="Runtime: cpu/gpu/dsp") "--runtime", type=str, default="", help="Runtime: cpu/gpu/dsp")
parser.add_argument( parser.add_argument(
"--input_node", "--input_node",
type=str, type=str,
...@@ -196,11 +232,6 @@ def parse_args(): ...@@ -196,11 +232,6 @@ def parse_args():
help="e.g., input_node") help="e.g., input_node")
parser.add_argument( parser.add_argument(
"--output_node", type=str, default="softmax", help="e.g., softmax") "--output_node", type=str, default="softmax", help="e.g., softmax")
parser.add_argument(
"--data_type",
type=str,
default='DT_FLOAT',
help="e.g., DT_HALF/DT_FLOAT")
parser.add_argument( parser.add_argument(
"--output_type", type=str, default="pb", help="output type: source/pb") "--output_type", type=str, default="pb", help="output type: source/pb")
parser.add_argument( parser.add_argument(
......
# Copyright 2018 Xiaomi, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum from enum import Enum
from mace.proto import mace_pb2 from mace.proto import mace_pb2
...@@ -117,6 +132,27 @@ class MaceKeyword(object): ...@@ -117,6 +132,27 @@ class MaceKeyword(object):
mace_axis_str = 'axis' mace_axis_str = 'axis'
mace_shape_str = 'shape' mace_shape_str = 'shape'
mace_winograd_filter_transformed = 'is_filter_transformed' mace_winograd_filter_transformed = 'is_filter_transformed'
mace_device = 'device'
class TransformerRule(Enum):
REMOVE_IDENTITY_OP = 0
TRANSFORM_GLOBAL_POOLING = 1
FOLD_SOFTMAX = 2
FOLD_BATCHNORM = 3,
FOLD_CONV_AND_BN = 4,
FOLD_DEPTHWISE_CONV_AND_BN = 5,
TRANSFORM_GPU_WINOGRAD = 6,
TRANSFORM_ADD_TO_BIASADD = 7,
FOLD_BIASADD = 8,
FOLD_ACTIVATION = 9,
TRANSPOSE_FILTERS = 10,
RESHAPE_FC_WEIGHT = 11,
TRANSPOSE_DATA_FORMAT = 12,
TRANSFORM_GLOBAL_CONV_TO_FC = 13,
TRANSFORM_BUFFER_IMAGE = 14,
ADD_DEVICE_AND_DATA_TYPE = 15,
SORT_BY_EXECUTION = 16
class ConverterInterface(object): class ConverterInterface(object):
...@@ -162,6 +198,25 @@ class ConverterOption(object): ...@@ -162,6 +198,25 @@ class ConverterOption(object):
self._data_type = mace_pb2.DT_FLOAT self._data_type = mace_pb2.DT_FLOAT
self._device = mace_pb2.CPU self._device = mace_pb2.CPU
self._winograd_enabled = False self._winograd_enabled = False
self._transformer_option = [
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.FOLD_SOFTMAX,
TransformerRule.FOLD_BATCHNORM,
TransformerRule.FOLD_CONV_AND_BN,
TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
TransformerRule.TRANSFORM_GPU_WINOGRAD,
TransformerRule.TRANSFORM_ADD_TO_BIASADD,
TransformerRule.FOLD_BIASADD,
TransformerRule.FOLD_ACTIVATION,
TransformerRule.TRANSPOSE_FILTERS,
TransformerRule.RESHAPE_FC_WEIGHT,
TransformerRule.TRANSPOSE_DATA_FORMAT,
TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
TransformerRule.TRANSFORM_BUFFER_IMAGE,
TransformerRule.ADD_DEVICE_AND_DATA_TYPE,
TransformerRule.SORT_BY_EXECUTION,
]
@property @property
def input_nodes(self): def input_nodes(self):
...@@ -183,6 +238,10 @@ class ConverterOption(object): ...@@ -183,6 +238,10 @@ class ConverterOption(object):
def winograd_enabled(self): def winograd_enabled(self):
return self._winograd_enabled return self._winograd_enabled
@property
def transformer_option(self):
return self._transformer_option
@input_nodes.setter @input_nodes.setter
def input_nodes(self, input_nodes): def input_nodes(self, input_nodes):
for node in input_nodes: for node in input_nodes:
...@@ -211,6 +270,14 @@ class ConverterOption(object): ...@@ -211,6 +270,14 @@ class ConverterOption(object):
def winograd_enabled(self, winograd_enabled): def winograd_enabled(self, winograd_enabled):
self._winograd_enabled = winograd_enabled self._winograd_enabled = winograd_enabled
def disable_transpose_filters(self):
if TransformerRule.TRANSPOSE_FILTERS in self._transformer_option:
self._transformer_option.remove(TransformerRule.TRANSPOSE_FILTERS)
def enable_transpose_filters(self):
if TransformerRule.TRANSPOSE_FILTERS not in self._transformer_option:
self._transformer_option.append(TransformerRule.TRANSPOSE_FILTERS)
class ConverterUtil(object): class ConverterUtil(object):
@staticmethod @staticmethod
......
# Copyright 2018 Xiaomi, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math import math
import numpy as np import numpy as np
import google.protobuf.text_format import google.protobuf.text_format
...@@ -325,10 +340,6 @@ class CaffeConverter(base_converter.ConverterInterface): ...@@ -325,10 +340,6 @@ class CaffeConverter(base_converter.ConverterInterface):
op.input.extend(caffe_op.layer.bottom) op.input.extend(caffe_op.layer.bottom)
op.output.extend(caffe_op.layer.top) op.output.extend(caffe_op.layer.top)
data_type_arg = op.arg.add()
data_type_arg.name = 'T'
data_type_arg.i = self._option.data_type
ConverterUtil.add_data_format_arg(op, DataFormat.NCHW) ConverterUtil.add_data_format_arg(op, DataFormat.NCHW)
return op return op
......
# Copyright 2018 Xiaomi, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math import math
import numpy as np import numpy as np
......
# Copyright 2018 Xiaomi, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math import math
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -197,11 +212,6 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -197,11 +212,6 @@ class TensorflowConverter(base_converter.ConverterInterface):
for tf_output in tf_op.outputs: for tf_output in tf_op.outputs:
output_shape = op.output_shape.add() output_shape = op.output_shape.add()
output_shape.dims.extend(tf_output.shape.as_list()) output_shape.dims.extend(tf_output.shape.as_list())
op.output_type.append(self._option.data_type)
data_type_arg = op.arg.add()
data_type_arg.name = 'T'
data_type_arg.i = self._option.data_type
ConverterUtil.add_data_format_arg(op, DataFormat.NHWC) ConverterUtil.add_data_format_arg(op, DataFormat.NHWC)
...@@ -289,7 +299,6 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -289,7 +299,6 @@ class TensorflowConverter(base_converter.ConverterInterface):
op.input.extend([scale_name, offset_name]) op.input.extend([scale_name, offset_name])
del op.output[1:] del op.output[1:]
del op.output_shape[1:] del op.output_shape[1:]
del op.output_type[1:]
def convert_pooling(self, tf_op): def convert_pooling(self, tf_op):
op = self.convert_general_op(tf_op) op = self.convert_general_op(tf_op)
......
# Copyright 2018 Xiaomi, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum import enum
import numpy as np import numpy as np
...@@ -11,6 +26,7 @@ from mace.python.tools.converter_tool.base_converter import FilterFormat ...@@ -11,6 +26,7 @@ from mace.python.tools.converter_tool.base_converter import FilterFormat
from mace.python.tools.converter_tool.base_converter import MaceOp from mace.python.tools.converter_tool.base_converter import MaceOp
from mace.python.tools.converter_tool.base_converter import MaceKeyword from mace.python.tools.converter_tool.base_converter import MaceKeyword
from mace.python.tools.converter_tool.base_converter import ConverterUtil from mace.python.tools.converter_tool.base_converter import ConverterUtil
from mace.python.tools.converter_tool.base_converter import TransformerRule
from mace.python.tools.convert_util import mace_check from mace.python.tools.convert_util import mace_check
OPENCL_IMAGE_MAX_SIZE = 16384 OPENCL_IMAGE_MAX_SIZE = 16384
...@@ -36,23 +52,52 @@ class Transformer(base_converter.ConverterInterface): ...@@ -36,23 +52,52 @@ class Transformer(base_converter.ConverterInterface):
def __init__(self, option, model): def __init__(self, option, model):
# DO NOT reorder the following transformers # DO NOT reorder the following transformers
self._registered_transformers = [ self._registered_transformers_order = [
self.remove_identity_op, TransformerRule.REMOVE_IDENTITY_OP,
self.transform_global_pooling, TransformerRule.TRANSFORM_GLOBAL_POOLING,
self.fold_softmax, TransformerRule.FOLD_SOFTMAX,
self.fold_batchnorm, TransformerRule.FOLD_BATCHNORM,
self.fold_conv_and_bn, # data_format related TransformerRule.FOLD_CONV_AND_BN,
self.fold_depthwise_conv_and_bn, # data_format related TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
self.transform_gpu_winograd, # data_format related TransformerRule.TRANSFORM_GPU_WINOGRAD,
self.transform_add_to_biasadd, TransformerRule.TRANSFORM_ADD_TO_BIASADD,
self.fold_biasadd, TransformerRule.FOLD_BIASADD,
self.fold_activation, TransformerRule.FOLD_ACTIVATION,
self.transpose_filters, TransformerRule.TRANSPOSE_FILTERS,
self.transpose_data_format, TransformerRule.RESHAPE_FC_WEIGHT,
self.transform_global_conv_to_fc, TransformerRule.TRANSPOSE_DATA_FORMAT,
self.transform_buffer_image, TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
self.sort_by_execution, TransformerRule.TRANSFORM_BUFFER_IMAGE,
TransformerRule.ADD_DEVICE_AND_DATA_TYPE,
TransformerRule.SORT_BY_EXECUTION,
] ]
self._registered_transformers = {
TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op,
TransformerRule.TRANSFORM_GLOBAL_POOLING:
self.transform_global_pooling,
TransformerRule.FOLD_SOFTMAX: self.fold_softmax,
TransformerRule.FOLD_BATCHNORM: self.fold_batchnorm,
TransformerRule.FOLD_CONV_AND_BN:
self.fold_conv_and_bn, # data_format related
TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN:
self.fold_depthwise_conv_and_bn, # data_format related
TransformerRule.TRANSFORM_GPU_WINOGRAD:
self.transform_gpu_winograd, # data_format related
TransformerRule.TRANSFORM_ADD_TO_BIASADD:
self.transform_add_to_biasadd,
TransformerRule.FOLD_BIASADD: self.fold_biasadd,
TransformerRule.FOLD_ACTIVATION: self.fold_activation,
TransformerRule.TRANSPOSE_FILTERS: self.transpose_filters,
TransformerRule.RESHAPE_FC_WEIGHT: self.reshape_fc_weight,
TransformerRule.TRANSPOSE_DATA_FORMAT: self.transpose_data_format,
TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC:
self.transform_global_conv_to_fc,
TransformerRule.TRANSFORM_BUFFER_IMAGE:
self.transform_buffer_image,
TransformerRule.ADD_DEVICE_AND_DATA_TYPE:
self.add_device_and_data_type,
TransformerRule.SORT_BY_EXECUTION: self.sort_by_execution,
}
self._option = option self._option = option
self._model = model self._model = model
...@@ -67,12 +112,14 @@ class Transformer(base_converter.ConverterInterface): ...@@ -67,12 +112,14 @@ class Transformer(base_converter.ConverterInterface):
self._target_data_format = DataFormat.NCHW self._target_data_format = DataFormat.NCHW
def run(self): def run(self):
for transformer in self._registered_transformers: for key in self._registered_transformers_order:
while True: if key in self._option.transformer_option:
self.construct_ops_and_consumers() transformer = self._registered_transformers[key]
changed = transformer() while True:
if not changed: self.construct_ops_and_consumers()
break changed = transformer()
if not changed:
break
return self._model return self._model
...@@ -404,19 +451,16 @@ class Transformer(base_converter.ConverterInterface): ...@@ -404,19 +451,16 @@ class Transformer(base_converter.ConverterInterface):
wt_output_shape.dims.extend( wt_output_shape.dims.extend(
[16, in_channels, wt_output_width, 1]) [16, in_channels, wt_output_width, 1])
arg = wt_op.arg.add()
arg.name = 'T'
arg.i = self._option.data_type
if ConverterUtil.get_arg(op, if ConverterUtil.get_arg(op,
MaceKeyword.mace_padding_str) \ MaceKeyword.mace_padding_str) \
is not None: is not None:
padding_arg = wt_op.arg.add() padding_arg = wt_op.arg.add()
padding_arg.name = MaceKeyword.mace_padding_str padding_arg.name = MaceKeyword.mace_padding_str
padding_arg.i = ConverterUtil.get_arg(op, padding_arg.i = ConverterUtil.get_arg(
MaceKeyword.mace_padding_str).i # noqa op, MaceKeyword.mace_padding_str).i
elif ConverterUtil.get_arg(op, elif ConverterUtil.get_arg(
MaceKeyword.mace_padding_values_str) is not None: # noqa op, MaceKeyword.mace_padding_values_str)\
is not None:
padding_arg = wt_op.arg.add() padding_arg = wt_op.arg.add()
padding_arg.name = MaceKeyword.mace_padding_values_str padding_arg.name = MaceKeyword.mace_padding_values_str
padding_arg.ints.extend(ConverterUtil.get_arg( padding_arg.ints.extend(ConverterUtil.get_arg(
...@@ -432,9 +476,6 @@ class Transformer(base_converter.ConverterInterface): ...@@ -432,9 +476,6 @@ class Transformer(base_converter.ConverterInterface):
matmul_output_shape.dims.extend( matmul_output_shape.dims.extend(
[16, out_channels, wt_output_width, 1]) [16, out_channels, wt_output_width, 1])
arg = matmul_op.arg.add()
arg.name = 'T'
arg.i = self._option.data_type
arg = matmul_op.arg.add() arg = matmul_op.arg.add()
arg.name = MaceKeyword.mace_winograd_filter_transformed arg.name = MaceKeyword.mace_winograd_filter_transformed
arg.i = 1 arg.i = 1
...@@ -451,9 +492,6 @@ class Transformer(base_converter.ConverterInterface): ...@@ -451,9 +492,6 @@ class Transformer(base_converter.ConverterInterface):
iwt_output_shape = iwt_op.output_shape.add() iwt_output_shape = iwt_op.output_shape.add()
iwt_output_shape.dims.extend(op.output_shape[0].dims) iwt_output_shape.dims.extend(op.output_shape[0].dims)
arg = iwt_op.arg.add()
arg.name = 'T'
arg.i = self._option.data_type
batch_arg = iwt_op.arg.add() batch_arg = iwt_op.arg.add()
batch_arg.name = 'batch' batch_arg.name = 'batch'
batch_arg.i = batch batch_arg.i = batch
...@@ -618,10 +656,6 @@ class Transformer(base_converter.ConverterInterface): ...@@ -618,10 +656,6 @@ class Transformer(base_converter.ConverterInterface):
dims_arg.name = MaceKeyword.mace_dims_str dims_arg.name = MaceKeyword.mace_dims_str
dims_arg.ints.extend([0, 3, 1, 2]) dims_arg.ints.extend([0, 3, 1, 2])
arg = op.arg.add()
arg.name = 'T'
arg.i = self._option.data_type
for output_node in self._option.output_nodes.values(): for output_node in self._option.output_nodes.values():
output_name = MaceKeyword.mace_output_node_name \ output_name = MaceKeyword.mace_output_node_name \
+ '_' + output_node.name + '_' + output_node.name
...@@ -639,75 +673,43 @@ class Transformer(base_converter.ConverterInterface): ...@@ -639,75 +673,43 @@ class Transformer(base_converter.ConverterInterface):
dims_arg.name = MaceKeyword.mace_dims_str dims_arg.name = MaceKeyword.mace_dims_str
dims_arg.ints.extend([0, 2, 3, 1]) dims_arg.ints.extend([0, 2, 3, 1])
arg = op.arg.add()
arg.name = 'T'
arg.i = self._option.data_type
return False return False
def transpose_filters(self): def transpose_filters(self):
net = self._model net = self._model
filter_format = self.filter_format() filter_format = self.filter_format()
# TODO(liyin/liuqi): remove this if-condition after combine cpu/gpu print("Transpose filters to OIHW")
if self._option.device == mace_pb2.CPU: # transpose filter to OIHW/MIHW for tensorflow (HWIO/HWIM)
print("Transpose filters to OIHW") if filter_format == FilterFormat.HWIO:
# transpose filter to OIHW/MIHW for tensorflow (HWIO/HWIM)
if filter_format == FilterFormat.HWIO:
for op in net.op:
if op.type == MaceOp.Conv2D.name \
or op.type == MaceOp.Deconv2D.name \
or op.type == MaceOp.DepthwiseConv2d.name:
if ConverterUtil.get_arg(op,
MaceKeyword.mace_winograd_filter_transformed) is None: # noqa
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
filter_data = filter_data.transpose(3, 2, 0, 1)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
self.set_filter_format(FilterFormat.OIHW)
elif self._option.device == mace_pb2.GPU:
# TODO(liyin/liuqi): remove this whole logic after combine cpu/gpu
print("Transpose filters to HWOI/HWIM")
for op in net.op: for op in net.op:
if op.type == MaceOp.Conv2D.name \ if op.type == MaceOp.Conv2D.name \
or op.type == MaceOp.Deconv2D.name \ or op.type == MaceOp.Deconv2D.name \
or op.type == MaceOp.DepthwiseConv2d.name: or op.type == MaceOp.DepthwiseConv2d.name:
filter = self._consts[op.input[1]] if ConverterUtil.get_arg(
filter_data = np.array(filter.float_data).reshape( op, MaceKeyword.mace_winograd_filter_transformed)\
filter.dims) is None:
# transpose filter to HWOI/HWIM for filter = self._consts[op.input[1]]
# tensorflow and caffe (OIHW/MIHW) filter_data = np.array(filter.float_data).reshape(
if filter_format == FilterFormat.HWIO \ filter.dims)
and (op.type == MaceOp.Conv2D.name filter_data = filter_data.transpose(3, 2, 0, 1)
or op.type == MaceOp.Deconv2D.name):
filter_data = filter_data.transpose(0, 1, 3, 2)
filter.float_data[:] = filter_data.flat filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape filter.dims[:] = filter_data.shape
elif filter_format == FilterFormat.OIHW: self.set_filter_format(FilterFormat.OIHW)
if op.type == MaceOp.Conv2D.name \
or op.type == MaceOp.Deconv2D.name: return False
filter_data = filter_data.transpose(2, 3, 0, 1)
filter.float_data[:] = filter_data.flat def reshape_fc_weight(self):
filter.dims[:] = filter_data.shape net = self._model
elif op.type == MaceOp.DepthwiseConv2d.name: for op in net.op:
filter_data = filter_data.transpose(2, 3, 1, 0) if op.type == MaceOp.FullyConnected.name:
filter.float_data[:] = filter_data.flat weight = self._consts[op.input[1]]
filter.dims[:] = filter_data.shape # NCHW
input_shape = list(self._producer[op.input[0]]
if op.type == MaceOp.FullyConnected.name: .output_shape[0].dims)
weight = self._consts[op.input[1]] weight_shape = [weight.dims[0]] + input_shape[1:]
input_shape = list(self._producer[op.input[0]] del weight.dims[:]
.output_shape[0].dims) weight.dims.extend(weight_shape)
weight_shape = [weight.dims[0]] + input_shape[1:]
# OCHW -> OHWC
weight_data = np.array(weight.float_data).reshape(
weight_shape)
weight_data = weight_data.transpose(0, 2, 3, 1)
weight.float_data[:] = weight_data.flat
self.set_filter_format(FilterFormat.HWOI)
return False return False
...@@ -727,9 +729,6 @@ class Transformer(base_converter.ConverterInterface): ...@@ -727,9 +729,6 @@ class Transformer(base_converter.ConverterInterface):
arg = op_def.arg.add() arg = op_def.arg.add()
arg.name = MaceKeyword.mace_mode arg.name = MaceKeyword.mace_mode
arg.i = 0 arg.i = 0
arg = op_def.arg.add()
arg.name = 'T'
arg.i = self._option.data_type
op.input[input_idx] = output_name op.input[input_idx] = output_name
...@@ -788,9 +787,6 @@ class Transformer(base_converter.ConverterInterface): ...@@ -788,9 +787,6 @@ class Transformer(base_converter.ConverterInterface):
arg = op_def.arg.add() arg = op_def.arg.add()
arg.name = MaceKeyword.mace_buffer_type arg.name = MaceKeyword.mace_buffer_type
arg.i = OpenCLBufferType.IN_OUT_CHANNEL.value arg.i = OpenCLBufferType.IN_OUT_CHANNEL.value
arg = op_def.arg.add()
arg.name = 'T'
arg.i = self._option.data_type
for output_node in self._option.output_nodes.values(): for output_node in self._option.output_nodes.values():
output_name = MaceKeyword.mace_output_node_name \ output_name = MaceKeyword.mace_output_node_name \
...@@ -806,9 +802,6 @@ class Transformer(base_converter.ConverterInterface): ...@@ -806,9 +802,6 @@ class Transformer(base_converter.ConverterInterface):
arg = op_def.arg.add() arg = op_def.arg.add()
arg.name = MaceKeyword.mace_buffer_type arg.name = MaceKeyword.mace_buffer_type
arg.i = OpenCLBufferType.IN_OUT_CHANNEL.value arg.i = OpenCLBufferType.IN_OUT_CHANNEL.value
arg = op_def.arg.add()
arg.name = 'T'
arg.i = self._option.data_type
return False return False
...@@ -885,6 +878,19 @@ class Transformer(base_converter.ConverterInterface): ...@@ -885,6 +878,19 @@ class Transformer(base_converter.ConverterInterface):
in_channels * filter_width in_channels * filter_width
* filter_height][:] * filter_height][:]
def add_device_and_data_type(self):
# TODO(liuqi) add device definition in OperatorDef
net = self._model
for op in net.op:
arg = op.arg.add()
arg.name = MaceKeyword.mace_device
arg.i = self._option.device
data_type_arg = op.arg.add()
data_type_arg.name = 'T'
data_type_arg.i = self._option.data_type
return False
def sort_dfs(self, op, visited, sorted_nodes): def sort_dfs(self, op, visited, sorted_nodes):
visited.update([op.name]) visited.update([op.name])
if len(op.input) > 0: if len(op.input) > 0:
......
...@@ -167,7 +167,6 @@ def convert_to_source(net_def, model_checksum, weight_checksum, template_dir, ...@@ -167,7 +167,6 @@ def convert_to_source(net_def, model_checksum, weight_checksum, template_dir,
tensor_info=tensor_info, tensor_info=tensor_info,
tensor=t, tensor=t,
tag=model_tag, tag=model_tag,
runtime=runtime,
offset=offset, offset=offset,
) )
model_data.extend(tensor_info.data) model_data.extend(tensor_info.data)
......
...@@ -55,6 +55,7 @@ void BufferToImage(const std::string &input_name, ...@@ -55,6 +55,7 @@ void BufferToImage(const std::string &input_name,
const std::string &output_name, const std::string &output_name,
const int buffer_type, const int buffer_type,
const std::vector<int> &mem_ids, const std::vector<int> &mem_ids,
const DeviceType device_type,
NetDef *net_def, NetDef *net_def,
const int mode = NetMode::NORMAL) { const int mode = NetMode::NORMAL) {
OperatorDef operator_def; OperatorDef operator_def;
...@@ -64,6 +65,7 @@ void BufferToImage(const std::string &input_name, ...@@ -64,6 +65,7 @@ void BufferToImage(const std::string &input_name,
.Output(output_name) .Output(output_name)
.AddIntArg("buffer_type", buffer_type) .AddIntArg("buffer_type", buffer_type)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type))
.AddIntArg("mode", mode) .AddIntArg("mode", mode)
.Finalize(&operator_def); .Finalize(&operator_def);
...@@ -76,6 +78,7 @@ template <typename T> ...@@ -76,6 +78,7 @@ template <typename T>
void ImageToBuffer(const std::string &input_name, void ImageToBuffer(const std::string &input_name,
const std::string &output_name, const std::string &output_name,
const int buffer_type, const int buffer_type,
const DeviceType device_type,
NetDef *net_def) { NetDef *net_def) {
OperatorDef operator_def; OperatorDef operator_def;
...@@ -84,6 +87,7 @@ void ImageToBuffer(const std::string &input_name, ...@@ -84,6 +87,7 @@ void ImageToBuffer(const std::string &input_name,
.Output(output_name) .Output(output_name)
.AddIntArg("buffer_type", buffer_type) .AddIntArg("buffer_type", buffer_type)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type))
.Finalize(&operator_def); .Finalize(&operator_def);
net_def->add_op()->CopyFrom(operator_def); net_def->add_op()->CopyFrom(operator_def);
...@@ -94,6 +98,7 @@ void Conv3x3(const std::string &input_name, ...@@ -94,6 +98,7 @@ void Conv3x3(const std::string &input_name,
const std::string &filter_name, const std::string &filter_name,
const std::string &output_name, const std::string &output_name,
const std::vector<int> &mem_ids, const std::vector<int> &mem_ids,
const DeviceType device_type,
NetDef *net_def) { NetDef *net_def) {
OperatorDef operator_def; OperatorDef operator_def;
ops::test::OpDefBuilder("Conv2D", "Conv2dOp") ops::test::OpDefBuilder("Conv2D", "Conv2dOp")
...@@ -104,6 +109,7 @@ void Conv3x3(const std::string &input_name, ...@@ -104,6 +109,7 @@ void Conv3x3(const std::string &input_name,
.AddIntArg("padding", Padding::SAME) .AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type))
.Finalize(&operator_def); .Finalize(&operator_def);
operator_def.set_mem_id(mem_ids); operator_def.set_mem_id(mem_ids);
...@@ -113,6 +119,7 @@ void Conv3x3(const std::string &input_name, ...@@ -113,6 +119,7 @@ void Conv3x3(const std::string &input_name,
template <typename T> template <typename T>
void Relu(const std::string &input_name, void Relu(const std::string &input_name,
const std::string &output_name, const std::string &output_name,
const DeviceType device_type,
NetDef *net_def) { NetDef *net_def) {
OperatorDef operator_def; OperatorDef operator_def;
ops::test::OpDefBuilder("Activation", "ReluTest") ops::test::OpDefBuilder("Activation", "ReluTest")
...@@ -120,6 +127,7 @@ void Relu(const std::string &input_name, ...@@ -120,6 +127,7 @@ void Relu(const std::string &input_name,
.Output(output_name) .Output(output_name)
.AddStringArg("activation", "RELU") .AddStringArg("activation", "RELU")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type))
.Finalize(&operator_def); .Finalize(&operator_def);
net_def->add_op()->CopyFrom(operator_def); net_def->add_op()->CopyFrom(operator_def);
...@@ -195,7 +203,8 @@ std::map<std::string, int> AddMemoryOptimization( ...@@ -195,7 +203,8 @@ std::map<std::string, int> AddMemoryOptimization(
const std::vector<std::vector<int64_t>> &output_shapes, const std::vector<std::vector<int64_t>> &output_shapes,
NetDef *net_def) { NetDef *net_def) {
std::map<std::string, int> res; std::map<std::string, int> res;
int mem_id = 0; // TODO(liuqi) refactor based on PB
int mem_id = 20000;
size_t input_shape_size = input_shapes.size(); size_t input_shape_size = input_shapes.size();
uint32_t in_mem_block_x = 0; uint32_t in_mem_block_x = 0;
uint32_t in_mem_block_y = 0; uint32_t in_mem_block_y = 0;
...@@ -269,21 +278,25 @@ void MaceRunFunc(const int in_out_size) { ...@@ -269,21 +278,25 @@ void MaceRunFunc(const int in_out_size) {
BufferToImage<half>(input_name, input_names[i], BufferToImage<half>(input_name, input_names[i],
mace::kernels::IN_OUT_CHANNEL, mace::kernels::IN_OUT_CHANNEL,
{mem_map[input_names[i]]}, {mem_map[input_names[i]]},
device,
&net_def); &net_def);
} }
BufferToImage<half>(filter_tensor_name, filter_tensor_img_name, BufferToImage<half>(filter_tensor_name, filter_tensor_img_name,
mace::kernels::CONV2D_FILTER, {}, mace::kernels::CONV2D_FILTER, {}, device,
&net_def, NetMode::INIT); &net_def, NetMode::INIT);
for (size_t i = 0; i < output_names.size(); ++i) { for (size_t i = 0; i < output_names.size(); ++i) {
Conv3x3<half>(input_names[i], filter_tensor_img_name, Conv3x3<half>(input_names[i], filter_tensor_img_name,
output_names[i], {mem_map[output_names[i]]}, output_names[i], {mem_map[output_names[i]]},
device,
&net_def); &net_def);
} }
for (size_t i = 0; i < output_names.size(); ++i) { for (size_t i = 0; i < output_names.size(); ++i) {
std::string output_name = MakeString("mace_output_node_", std::string output_name = MakeString("mace_output_node_",
output_names[i]); output_names[i]);
ImageToBuffer<float>(output_names[i], output_name, ImageToBuffer<float>(output_names[i], output_name,
mace::kernels::IN_OUT_CHANNEL, &net_def); mace::kernels::IN_OUT_CHANNEL,
device,
&net_def);
} }
const std::string file_path ="/data/local/tmp/mace"; const std::string file_path ="/data/local/tmp/mace";
......
...@@ -65,6 +65,7 @@ void BufferToImage(const std::string &input_name, ...@@ -65,6 +65,7 @@ void BufferToImage(const std::string &input_name,
const std::string &output_name, const std::string &output_name,
const int buffer_type, const int buffer_type,
const std::vector<int> &mem_ids, const std::vector<int> &mem_ids,
const DeviceType device_type,
NetDef *net_def, NetDef *net_def,
const int mode = NetMode::NORMAL) { const int mode = NetMode::NORMAL) {
OperatorDef operator_def; OperatorDef operator_def;
...@@ -74,6 +75,7 @@ void BufferToImage(const std::string &input_name, ...@@ -74,6 +75,7 @@ void BufferToImage(const std::string &input_name,
.Output(output_name) .Output(output_name)
.AddIntArg("buffer_type", buffer_type) .AddIntArg("buffer_type", buffer_type)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type))
.AddIntArg("mode", mode) .AddIntArg("mode", mode)
.Finalize(&operator_def); .Finalize(&operator_def);
...@@ -86,6 +88,7 @@ template <typename T> ...@@ -86,6 +88,7 @@ template <typename T>
void ImageToBuffer(const std::string &input_name, void ImageToBuffer(const std::string &input_name,
const std::string &output_name, const std::string &output_name,
const int buffer_type, const int buffer_type,
const DeviceType device_type,
NetDef *net_def) { NetDef *net_def) {
OperatorDef operator_def; OperatorDef operator_def;
...@@ -94,6 +97,7 @@ void ImageToBuffer(const std::string &input_name, ...@@ -94,6 +97,7 @@ void ImageToBuffer(const std::string &input_name,
.Output(output_name) .Output(output_name)
.AddIntArg("buffer_type", buffer_type) .AddIntArg("buffer_type", buffer_type)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type))
.Finalize(&operator_def); .Finalize(&operator_def);
net_def->add_op()->CopyFrom(operator_def); net_def->add_op()->CopyFrom(operator_def);
...@@ -104,6 +108,7 @@ void Conv3x3(const std::string &input_name, ...@@ -104,6 +108,7 @@ void Conv3x3(const std::string &input_name,
const std::string &filter_name, const std::string &filter_name,
const std::string &output_name, const std::string &output_name,
const std::vector<int> &mem_ids, const std::vector<int> &mem_ids,
const DeviceType device_type,
NetDef *net_def) { NetDef *net_def) {
OperatorDef operator_def; OperatorDef operator_def;
ops::test::OpDefBuilder("Conv2D", "Conv2dOp") ops::test::OpDefBuilder("Conv2D", "Conv2dOp")
...@@ -114,6 +119,7 @@ void Conv3x3(const std::string &input_name, ...@@ -114,6 +119,7 @@ void Conv3x3(const std::string &input_name,
.AddIntArg("padding", Padding::SAME) .AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type))
.Finalize(&operator_def); .Finalize(&operator_def);
operator_def.set_mem_id(mem_ids); operator_def.set_mem_id(mem_ids);
...@@ -123,6 +129,7 @@ void Conv3x3(const std::string &input_name, ...@@ -123,6 +129,7 @@ void Conv3x3(const std::string &input_name,
template <typename T> template <typename T>
void Relu(const std::string &input_name, void Relu(const std::string &input_name,
const std::string &output_name, const std::string &output_name,
const DeviceType device_type,
NetDef *net_def) { NetDef *net_def) {
OperatorDef operator_def; OperatorDef operator_def;
ops::test::OpDefBuilder("Activation", "ReluTest") ops::test::OpDefBuilder("Activation", "ReluTest")
...@@ -130,6 +137,7 @@ void Relu(const std::string &input_name, ...@@ -130,6 +137,7 @@ void Relu(const std::string &input_name,
.Output(output_name) .Output(output_name)
.AddStringArg("activation", "RELU") .AddStringArg("activation", "RELU")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type))
.Finalize(&operator_def); .Finalize(&operator_def);
net_def->add_op()->CopyFrom(operator_def); net_def->add_op()->CopyFrom(operator_def);
...@@ -205,7 +213,8 @@ std::map<std::string, int> AddMemoryOptimization( ...@@ -205,7 +213,8 @@ std::map<std::string, int> AddMemoryOptimization(
const std::vector<std::vector<int64_t>> &output_shapes, const std::vector<std::vector<int64_t>> &output_shapes,
NetDef *net_def) { NetDef *net_def) {
std::map<std::string, int> res; std::map<std::string, int> res;
int mem_id = 0; // TODO(liuqi) refactor based on PB
int mem_id = 20000;
size_t input_shape_size = input_shapes.size(); size_t input_shape_size = input_shapes.size();
uint32_t in_mem_block_x = 0; uint32_t in_mem_block_x = 0;
uint32_t in_mem_block_y = 0; uint32_t in_mem_block_y = 0;
...@@ -279,21 +288,24 @@ void MaceRun(const int in_out_size, ...@@ -279,21 +288,24 @@ void MaceRun(const int in_out_size,
BufferToImage<half>(input_name, input_names[i], BufferToImage<half>(input_name, input_names[i],
mace::kernels::IN_OUT_CHANNEL, mace::kernels::IN_OUT_CHANNEL,
{mem_map[input_names[i]]}, {mem_map[input_names[i]]},
device,
&net_def); &net_def);
} }
BufferToImage<half>(filter_tensor_name, filter_tensor_img_name, BufferToImage<half>(filter_tensor_name, filter_tensor_img_name,
mace::kernels::CONV2D_FILTER, {}, mace::kernels::CONV2D_FILTER, {}, device,
&net_def, NetMode::INIT); &net_def, NetMode::INIT);
for (size_t i = 0; i < output_names.size(); ++i) { for (size_t i = 0; i < output_names.size(); ++i) {
Conv3x3<half>(input_names[i], filter_tensor_img_name, Conv3x3<half>(input_names[i], filter_tensor_img_name,
output_names[i], {mem_map[output_names[i]]}, output_names[i], {mem_map[output_names[i]]},
&net_def); device, &net_def);
} }
for (size_t i = 0; i < output_names.size(); ++i) { for (size_t i = 0; i < output_names.size(); ++i) {
std::string output_name = MakeString("mace_output_node_", std::string output_name = MakeString("mace_output_node_",
output_names[i]); output_names[i]);
ImageToBuffer<float>(output_names[i], output_name, ImageToBuffer<float>(output_names[i], output_name,
mace::kernels::IN_OUT_CHANNEL, &net_def); mace::kernels::IN_OUT_CHANNEL,
device,
&net_def);
} }
MaceEngine engine(&net_def, device, input_names, output_names); MaceEngine engine(&net_def, device, input_names, output_names);
......
...@@ -62,27 +62,23 @@ def get_target_socs(configs): ...@@ -62,27 +62,23 @@ def get_target_socs(configs):
return target_socs return target_socs
def get_data_and_device_type(runtime): def parse_device_type(runtime):
data_type = ""
device_type = "" device_type = ""
if runtime == "dsp": if runtime == "dsp":
data_type = "DT_UINT8"
device_type = "HEXAGON" device_type = "HEXAGON"
elif runtime == "gpu": elif runtime == "gpu":
data_type = "DT_HALF"
device_type = "GPU" device_type = "GPU"
elif runtime == "cpu": elif runtime == "cpu":
data_type = "DT_FLOAT"
device_type = "CPU" device_type = "CPU"
return data_type, device_type return device_type
def get_hexagon_mode(configs): def get_hexagon_mode(configs):
runtime_list = [] runtime_list = []
for model_name in configs["models"]: for model_name in configs["models"]:
model_runtime = configs["models"][model_name]["runtime"] model_runtime = configs["models"][model_name].get("runtime", "")
runtime_list.append(model_runtime.lower()) runtime_list.append(model_runtime.lower())
global_runtime = "" global_runtime = ""
...@@ -114,7 +110,7 @@ def model_benchmark_stdout_processor(stdout, ...@@ -114,7 +110,7 @@ def model_benchmark_stdout_processor(stdout,
abi, abi,
serialno, serialno,
model_name, model_name,
runtime): device_type):
metrics = [0] * 3 metrics = [0] * 3
for line in stdout.split('\n'): for line in stdout.split('\n'):
line = line.strip() line = line.strip()
...@@ -138,14 +134,14 @@ def model_benchmark_stdout_processor(stdout, ...@@ -138,14 +134,14 @@ def model_benchmark_stdout_processor(stdout,
f.write("model_name,device_name,soc,abi,runtime," f.write("model_name,device_name,soc,abi,runtime,"
"init,warmup,run_avg\n") "init,warmup,run_avg\n")
data_str = "{model_name},{device_name},{soc},{abi},{runtime}," \ data_str = "{model_name},{device_name},{soc},{abi},{device_type}," \
"{init},{warmup},{run_avg}\n" \ "{init},{warmup},{run_avg}\n" \
.format( .format(
model_name=model_name, model_name=model_name,
device_name=device_name, device_name=device_name,
soc=target_soc, soc=target_soc,
abi=abi, abi=abi,
runtime=runtime, device_type=device_type,
init=metrics[0], init=metrics[0],
warmup=metrics[1], warmup=metrics[1],
run_avg=metrics[2] run_avg=metrics[2]
...@@ -154,8 +150,7 @@ def model_benchmark_stdout_processor(stdout, ...@@ -154,8 +150,7 @@ def model_benchmark_stdout_processor(stdout,
f.write(data_str) f.write(data_str)
def tuning_run(runtime, def tuning_run(target_abi,
target_abi,
serialno, serialno,
vlog_level, vlog_level,
embed_model_data, embed_model_data,
...@@ -205,7 +200,7 @@ def tuning_run(runtime, ...@@ -205,7 +200,7 @@ def tuning_run(runtime,
if running_round > 0 and FLAGS.collect_report: if running_round > 0 and FLAGS.collect_report:
model_benchmark_stdout_processor( model_benchmark_stdout_processor(
stdout, target_abi, serialno, model_name, runtime) stdout, target_abi, serialno, model_name, device_type)
def build_mace_run_prod(hexagon_mode, runtime, target_abi, def build_mace_run_prod(hexagon_mode, runtime, target_abi,
...@@ -222,7 +217,7 @@ def build_mace_run_prod(hexagon_mode, runtime, target_abi, ...@@ -222,7 +217,7 @@ def build_mace_run_prod(hexagon_mode, runtime, target_abi,
strip = "never" strip = "never"
debug = True debug = True
if runtime == "gpu": if not runtime or runtime == "gpu":
gen_opencl_and_tuning_code(target_abi, serialno, [], False) gen_opencl_and_tuning_code(target_abi, serialno, [], False)
sh_commands.bazel_build( sh_commands.bazel_build(
mace_run_target, mace_run_target,
...@@ -234,19 +229,14 @@ def build_mace_run_prod(hexagon_mode, runtime, target_abi, ...@@ -234,19 +229,14 @@ def build_mace_run_prod(hexagon_mode, runtime, target_abi,
sh_commands.update_mace_run_lib(model_output_dir, sh_commands.update_mace_run_lib(model_output_dir,
model_name, embed_model_data) model_name, embed_model_data)
tuning_run(runtime, target_abi, serialno, vlog_level, embed_model_data, device_type = parse_device_type("gpu")
tuning_run(target_abi, serialno, vlog_level, embed_model_data,
model_output_dir, input_nodes, output_nodes, input_shapes, model_output_dir, input_nodes, output_nodes, input_shapes,
output_shapes, model_name, device_type, running_round=0, output_shapes, model_name, device_type, running_round=0,
restart_round=1, out_of_range_check=False, restart_round=1, out_of_range_check=False,
phone_data_dir=phone_data_dir, tuning=tuning, phone_data_dir=phone_data_dir, tuning=tuning,
limit_opencl_kernel_time=limit_opencl_kernel_time) limit_opencl_kernel_time=limit_opencl_kernel_time)
tuning_run(runtime, target_abi, serialno, vlog_level, embed_model_data,
model_output_dir, input_nodes, output_nodes, input_shapes,
output_shapes, model_name, device_type, running_round=0,
restart_round=1, out_of_range_check=True,
phone_data_dir=phone_data_dir, tuning=False)
gen_opencl_and_tuning_code(target_abi, serialno, [model_output_dir], gen_opencl_and_tuning_code(target_abi, serialno, [model_output_dir],
True) True)
sh_commands.bazel_build( sh_commands.bazel_build(
...@@ -391,8 +381,7 @@ def parse_model_configs(): ...@@ -391,8 +381,7 @@ def parse_model_configs():
print("'platform' must be 'tensorflow' or 'caffe'") print("'platform' must be 'tensorflow' or 'caffe'")
exit(1) exit(1)
for key in ["model_file_path", "model_sha256_checksum", for key in ["model_file_path", "model_sha256_checksum"]:
"runtime"]:
value = model_config.get(key, "") value = model_config.get(key, "")
if value == "": if value == "":
print("CONFIG ERROR:") print("CONFIG ERROR:")
...@@ -529,6 +518,11 @@ def parse_args(): ...@@ -529,6 +518,11 @@ def parse_args():
type=str, type=str,
default="", default="",
help="Valgrind command args.") help="Valgrind command args.")
parser.add_argument(
"--validation_runtime",
type=str,
default="cpu",
help="validation runtime.")
return parser.parse_known_args() return parser.parse_known_args()
...@@ -541,9 +535,11 @@ def process_models(project_name, configs, embed_model_data, vlog_level, ...@@ -541,9 +535,11 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
print '===================', model_name, '===================' print '===================', model_name, '==================='
model_config = configs["models"][model_name] model_config = configs["models"][model_name]
input_file_list = model_config["validation_inputs_data"] input_file_list = model_config["validation_inputs_data"]
data_type, device_type = get_data_and_device_type( model_runtime = model_config.get("runtime", "")
model_config["runtime"]) model_device_type = parse_device_type(model_runtime)
run_device_type = model_device_type
if not run_device_type:
run_device_type = parse_device_type(FLAGS.validation_runtime)
# Create model build directory # Create model build directory
model_path_digest = md5sum(model_config["model_file_path"]) model_path_digest = md5sum(model_config["model_file_path"])
model_output_base_dir = "%s/%s/%s/%s/%s" % ( model_output_base_dir = "%s/%s/%s/%s/%s" % (
...@@ -581,7 +577,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level, ...@@ -581,7 +577,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
if FLAGS.mode == "build" or FLAGS.mode == "all": if FLAGS.mode == "build" or FLAGS.mode == "all":
build_mace_run_prod(hexagon_mode, build_mace_run_prod(hexagon_mode,
model_config["runtime"], model_runtime,
target_abi, target_abi,
serialno, serialno,
vlog_level, vlog_level,
...@@ -592,7 +588,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level, ...@@ -592,7 +588,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
model_config["input_shapes"], model_config["input_shapes"],
model_config["output_shapes"], model_config["output_shapes"],
model_name, model_name,
device_type, model_device_type,
FLAGS.round, FLAGS.round,
FLAGS.restart_round, FLAGS.restart_round,
FLAGS.tuning, FLAGS.tuning,
...@@ -607,8 +603,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level, ...@@ -607,8 +603,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
if FLAGS.mode == "run" or FLAGS.mode == "validate" or \ if FLAGS.mode == "run" or FLAGS.mode == "validate" or \
FLAGS.mode == "all": FLAGS.mode == "all":
tuning_run(model_config["runtime"], tuning_run(target_abi,
target_abi,
serialno, serialno,
vlog_level, vlog_level,
embed_model_data, embed_model_data,
...@@ -618,7 +613,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level, ...@@ -618,7 +613,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
model_config["input_shapes"], model_config["input_shapes"],
model_config["output_shapes"], model_config["output_shapes"],
model_name, model_name,
device_type, run_device_type,
FLAGS.round, FLAGS.round,
FLAGS.restart_round, FLAGS.restart_round,
FLAGS.out_of_range_check, FLAGS.out_of_range_check,
...@@ -641,7 +636,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level, ...@@ -641,7 +636,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
model_config["input_shapes"], model_config["input_shapes"],
model_config["output_shapes"], model_config["output_shapes"],
model_name, model_name,
device_type, run_device_type,
phone_data_dir, phone_data_dir,
FLAGS.omp_num_threads, FLAGS.omp_num_threads,
FLAGS.cpu_affinity_policy, FLAGS.cpu_affinity_policy,
...@@ -654,7 +649,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level, ...@@ -654,7 +649,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
model_file_path, model_file_path,
weight_file_path, weight_file_path,
model_config["platform"], model_config["platform"],
model_config["runtime"], run_device_type,
model_config["input_nodes"], model_config["input_nodes"],
model_config["output_nodes"], model_config["output_nodes"],
model_config["input_shapes"], model_config["input_shapes"],
...@@ -746,8 +741,7 @@ def main(unused_args): ...@@ -746,8 +741,7 @@ def main(unused_args):
for model_name in configs["models"]: for model_name in configs["models"]:
print '===================', model_name, '===================' print '===================', model_name, '==================='
model_config = configs["models"][model_name] model_config = configs["models"][model_name]
data_type, device_type = get_data_and_device_type( runtime = model_config.get("runtime", "")
model_config["runtime"])
# Create model build directory # Create model build directory
model_path_digest = md5sum(model_config["model_file_path"]) model_path_digest = md5sum(model_config["model_file_path"])
...@@ -778,8 +772,7 @@ def main(unused_args): ...@@ -778,8 +772,7 @@ def main(unused_args):
model_config["model_sha256_checksum"], model_config["model_sha256_checksum"],
",".join(model_config["input_nodes"]), ",".join(model_config["input_nodes"]),
",".join(model_config["output_nodes"]), ",".join(model_config["output_nodes"]),
data_type, runtime,
model_config["runtime"],
model_name, model_name,
":".join(model_config["input_shapes"]), ":".join(model_config["input_shapes"]),
model_config["dsp_mode"], model_config["dsp_mode"],
......
...@@ -465,7 +465,6 @@ def gen_model_code(model_codegen_dir, ...@@ -465,7 +465,6 @@ def gen_model_code(model_codegen_dir,
model_sha256_checksum, model_sha256_checksum,
input_nodes, input_nodes,
output_nodes, output_nodes,
data_type,
runtime, runtime,
model_tag, model_tag,
input_shapes, input_shapes,
...@@ -489,7 +488,6 @@ def gen_model_code(model_codegen_dir, ...@@ -489,7 +488,6 @@ def gen_model_code(model_codegen_dir,
"--output=%s" % model_codegen_dir + "/model.cc", "--output=%s" % model_codegen_dir + "/model.cc",
"--input_node=%s" % input_nodes, "--input_node=%s" % input_nodes,
"--output_node=%s" % output_nodes, "--output_node=%s" % output_nodes,
"--data_type=%s" % data_type,
"--runtime=%s" % runtime, "--runtime=%s" % runtime,
"--output_type=source", "--output_type=source",
"--template=%s" % "mace/python/tools", "--template=%s" % "mace/python/tools",
...@@ -703,7 +701,7 @@ def validate_model(abi, ...@@ -703,7 +701,7 @@ def validate_model(abi,
model_file_path, model_file_path,
weight_file_path, weight_file_path,
platform, platform,
runtime, device_type,
input_nodes, input_nodes,
output_nodes, output_nodes,
input_shapes, input_shapes,
...@@ -727,7 +725,7 @@ def validate_model(abi, ...@@ -727,7 +725,7 @@ def validate_model(abi,
if platform == "tensorflow": if platform == "tensorflow":
validate(platform, model_file_path, "", validate(platform, model_file_path, "",
"%s/%s" % (model_output_dir, input_file_name), "%s/%s" % (model_output_dir, input_file_name),
"%s/%s" % (model_output_dir, output_file_name), runtime, "%s/%s" % (model_output_dir, output_file_name), device_type,
":".join(input_shapes), ":".join(output_shapes), ":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes)) ",".join(input_nodes), ",".join(output_nodes))
elif platform == "caffe": elif platform == "caffe":
...@@ -743,7 +741,8 @@ def validate_model(abi, ...@@ -743,7 +741,8 @@ def validate_model(abi,
logger.error('There is no caffe python module.') logger.error('There is no caffe python module.')
validate(platform, model_file_path, weight_file_path, validate(platform, model_file_path, weight_file_path,
"%s/%s" % (model_output_dir, input_file_name), "%s/%s" % (model_output_dir, input_file_name),
"%s/%s" % (model_output_dir, output_file_name), runtime, "%s/%s" % (model_output_dir, output_file_name),
device_type,
":".join(input_shapes), ":".join(output_shapes), ":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes)) ",".join(input_nodes), ",".join(output_nodes))
elif caffe_env == common.CaffeEnvType.DOCKER: elif caffe_env == common.CaffeEnvType.DOCKER:
...@@ -806,7 +805,7 @@ def validate_model(abi, ...@@ -806,7 +805,7 @@ def validate_model(abi,
"--weight_file=/mace/%s" % weight_file_name, "--weight_file=/mace/%s" % weight_file_name,
"--input_file=/mace/%s" % input_file_name, "--input_file=/mace/%s" % input_file_name,
"--mace_out_file=/mace/%s" % output_file_name, "--mace_out_file=/mace/%s" % output_file_name,
"--mace_runtime=%s" % runtime, "--device_type=%s" % device_type,
"--input_node=%s" % ",".join(input_nodes), "--input_node=%s" % ",".join(input_nodes),
"--output_node=%s" % ",".join(output_nodes), "--output_node=%s" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes), "--input_shape=%s" % ":".join(input_shapes),
......
...@@ -44,7 +44,7 @@ def load_data(file): ...@@ -44,7 +44,7 @@ def load_data(file):
return np.empty([0]) return np.empty([0])
def compare_output(platform, mace_runtime, output_name, mace_out_value, def compare_output(platform, device_type, output_name, mace_out_value,
out_value): out_value):
if mace_out_value.size != 0: if mace_out_value.size != 0:
out_value = out_value.reshape(-1) out_value = out_value.reshape(-1)
...@@ -53,9 +53,9 @@ def compare_output(platform, mace_runtime, output_name, mace_out_value, ...@@ -53,9 +53,9 @@ def compare_output(platform, mace_runtime, output_name, mace_out_value,
similarity = (1 - spatial.distance.cosine(out_value, mace_out_value)) similarity = (1 - spatial.distance.cosine(out_value, mace_out_value))
print output_name, 'MACE VS', platform.upper( print output_name, 'MACE VS', platform.upper(
), 'similarity: ', similarity ), 'similarity: ', similarity
if (mace_runtime == "cpu" and similarity > 0.999) or \ if (device_type == "CPU" and similarity > 0.999) or \
(mace_runtime == "gpu" and similarity > 0.995) or \ (device_type == "GPU" and similarity > 0.995) or \
(mace_runtime == "dsp" and similarity > 0.930): (device_type == "HEXAGON" and similarity > 0.930):
print '===================Similarity Test Passed==================' print '===================Similarity Test Passed=================='
else: else:
print '===================Similarity Test Failed==================' print '===================Similarity Test Failed=================='
...@@ -65,7 +65,7 @@ def compare_output(platform, mace_runtime, output_name, mace_out_value, ...@@ -65,7 +65,7 @@ def compare_output(platform, mace_runtime, output_name, mace_out_value,
sys.exit(-1) sys.exit(-1)
def validate_tf_model(platform, mace_runtime, model_file, input_file, def validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes, output_names): mace_out_file, input_names, input_shapes, output_names):
import tensorflow as tf import tensorflow as tf
if not os.path.isfile(model_file): if not os.path.isfile(model_file):
...@@ -100,11 +100,11 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file, ...@@ -100,11 +100,11 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file,
output_file_name = common.formatted_file_name( output_file_name = common.formatted_file_name(
mace_out_file, output_names[i]) mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name) mace_out_value = load_data(output_file_name)
compare_output(platform, mace_runtime, output_names[i], compare_output(platform, device_type, output_names[i],
mace_out_value, output_values[i]) mace_out_value, output_values[i])
def validate_caffe_model(platform, mace_runtime, model_file, input_file, def validate_caffe_model(platform, device_type, model_file, input_file,
mace_out_file, weight_file, input_names, input_shapes, mace_out_file, weight_file, input_names, input_shapes,
output_names, output_shapes): output_names, output_shapes):
os.environ['GLOG_minloglevel'] = '1' # suprress Caffe verbose prints os.environ['GLOG_minloglevel'] = '1' # suprress Caffe verbose prints
...@@ -144,12 +144,12 @@ def validate_caffe_model(platform, mace_runtime, model_file, input_file, ...@@ -144,12 +144,12 @@ def validate_caffe_model(platform, mace_runtime, model_file, input_file,
output_file_name = common.formatted_file_name( output_file_name = common.formatted_file_name(
mace_out_file, output_names[i]) mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name) mace_out_value = load_data(output_file_name)
compare_output(platform, mace_runtime, output_names[i], mace_out_value, compare_output(platform, device_type, output_names[i], mace_out_value,
value) value)
def validate(platform, model_file, weight_file, input_file, mace_out_file, def validate(platform, model_file, weight_file, input_file, mace_out_file,
mace_runtime, input_shape, output_shape, input_node, output_node): device_type, input_shape, output_shape, input_node, output_node):
input_names = [name for name in input_node.split(',')] input_names = [name for name in input_node.split(',')]
input_shape_strs = [shape for shape in input_shape.split(':')] input_shape_strs = [shape for shape in input_shape.split(':')]
input_shapes = [[int(x) for x in shape.split(',')] input_shapes = [[int(x) for x in shape.split(',')]
...@@ -158,14 +158,14 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file, ...@@ -158,14 +158,14 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
assert len(input_names) == len(input_shapes) assert len(input_names) == len(input_shapes)
if platform == 'tensorflow': if platform == 'tensorflow':
validate_tf_model(platform, mace_runtime, model_file, input_file, validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes, mace_out_file, input_names, input_shapes,
output_names) output_names)
elif platform == 'caffe': elif platform == 'caffe':
output_shape_strs = [shape for shape in output_shape.split(':')] output_shape_strs = [shape for shape in output_shape.split(':')]
output_shapes = [[int(x) for x in shape.split(',')] output_shapes = [[int(x) for x in shape.split(',')]
for shape in output_shape_strs] for shape in output_shape_strs]
validate_caffe_model(platform, mace_runtime, model_file, input_file, validate_caffe_model(platform, device_type, model_file, input_file,
mace_out_file, weight_file, input_names, mace_out_file, weight_file, input_names,
input_shapes, output_names, output_shapes) input_shapes, output_names, output_shapes)
...@@ -194,7 +194,7 @@ def parse_args(): ...@@ -194,7 +194,7 @@ def parse_args():
default="", default="",
help="mace output file to load.") help="mace output file to load.")
parser.add_argument( parser.add_argument(
"--mace_runtime", type=str, default="gpu", help="mace runtime device.") "--device_type", type=str, default="", help="mace runtime device.")
parser.add_argument( parser.add_argument(
"--input_shape", type=str, default="1,64,64,3", help="input shape.") "--input_shape", type=str, default="1,64,64,3", help="input shape.")
parser.add_argument( parser.add_argument(
...@@ -214,7 +214,7 @@ if __name__ == '__main__': ...@@ -214,7 +214,7 @@ if __name__ == '__main__':
FLAGS.weight_file, FLAGS.weight_file,
FLAGS.input_file, FLAGS.input_file,
FLAGS.mace_out_file, FLAGS.mace_out_file,
FLAGS.mace_runtime, FLAGS.device_type,
FLAGS.input_shape, FLAGS.input_shape,
FLAGS.output_shape, FLAGS.output_shape,
FLAGS.input_node, FLAGS.input_node,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册