提交 c3837858 编写于 作者: 刘琦

Merge branch 'transform' into 'master'

Refactor model converter and transformer

See merge request !477
......@@ -119,11 +119,11 @@ MaceEngine::Impl::Impl(const NetDef *net_def,
LOG(INFO) << "MACE version: " << MaceVersion();
// Set storage path for internal usage
for (auto input_name : input_nodes) {
ws_->CreateTensor(MakeString("mace_input_node_", input_name, ":0"),
ws_->CreateTensor(MakeString("mace_input_node_", input_name),
GetDeviceAllocator(device_type_), DT_FLOAT);
}
for (auto output_name : output_nodes) {
ws_->CreateTensor(MakeString("mace_output_node_", output_name, ":0"),
ws_->CreateTensor(MakeString("mace_output_node_", output_name),
GetDeviceAllocator(device_type_), DT_FLOAT);
}
#ifdef MACE_ENABLE_HEXAGON
......@@ -182,7 +182,7 @@ MaceStatus MaceEngine::Impl::Run(
"The Inputs' shape must be 4-dimension with NHWC format,"
" please use 1 to fill missing dimensions");
Tensor *input_tensor =
ws_->GetTensor(MakeString("mace_input_node_", input.first, ":0"));
ws_->GetTensor(MakeString("mace_input_node_", input.first));
input_tensor->Resize(input.second.shape());
{
Tensor::MappingGuard input_guard(input_tensor);
......@@ -199,7 +199,7 @@ MaceStatus MaceEngine::Impl::Run(
" please use 1 to fill missing dimensions");
}
Tensor *output_tensor =
ws_->GetTensor(MakeString("mace_output_node_", output.first + ":0"));
ws_->GetTensor(MakeString("mace_output_node_", output.first));
output_tensors.push_back(output_tensor);
}
#ifdef MACE_ENABLE_HEXAGON
......@@ -223,7 +223,7 @@ MaceStatus MaceEngine::Impl::Run(
#endif
for (auto &output : *outputs) {
Tensor *output_tensor =
ws_->GetTensor(MakeString("mace_output_node_", output.first + ":0"));
ws_->GetTensor(MakeString("mace_output_node_", output.first));
// save output
if (output_tensor != nullptr && output.second.data() != nullptr) {
Tensor::MappingGuard output_guard(output_tensor);
......
......@@ -18,20 +18,20 @@ namespace mace {
namespace ops {
void Register_FullyConnected(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FC")
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FullyConnected")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
FullyConnectedOp<DeviceType::CPU, float>);
#ifdef MACE_ENABLE_OPENCL
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FC")
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FullyConnected")
.Device(DeviceType::GPU)
.TypeConstraint<float>("T")
.Build(),
FullyConnectedOp<DeviceType::GPU, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FC")
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FullyConnected")
.Device(DeviceType::GPU)
.TypeConstraint<half>("T")
.Build(),
......
......@@ -37,7 +37,7 @@ void FCBenchmark(
net.AddRandomInput<D, float>("Bias", {out_channel});
if (D == DeviceType::CPU) {
OpDefBuilder("FC", "FullyConnectedTest")
OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
......@@ -52,7 +52,7 @@ void FCBenchmark(
BufferToImage<D, T>(&net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FC", "FullyConnectedTest")
OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("InputImage")
.Input("WeightImage")
.Input("BiasImage")
......
......@@ -42,7 +42,7 @@ void Simple(const std::vector<index_t> &input_shape,
if (D == DeviceType::CPU) {
net.Transpose2D<D, float>("Weight", "WeightTranspose");
OpDefBuilder("FC", "FullyConnectedTest")
OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
......@@ -59,7 +59,7 @@ void Simple(const std::vector<index_t> &input_shape,
BufferToImage<D, float>(&net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FC", "FullyConnectedTest")
OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("InputImage")
.Input("WeightImage")
.Input("BiasImage")
......@@ -142,7 +142,7 @@ void Complex(const index_t batch,
"Weight", {out_channel, height * width * channels});
net.AddRandomInput<DeviceType::GPU, float>("Bias", {out_channel});
OpDefBuilder("FC", "FullyConnectedTest")
OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
......@@ -166,7 +166,7 @@ void Complex(const index_t batch,
BufferToImage<DeviceType::GPU, float>(&net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FC", "FullyConnectedTest")
OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("InputImage")
.Input("WeightImage")
.Input("BiasImage")
......@@ -231,7 +231,7 @@ void TestWXFormat(const index_t batch,
"Weight", {out_channel, height * width * channels});
net.AddRandomInput<DeviceType::GPU, float>("Bias", {out_channel});
OpDefBuilder("FC", "FullyConnectedTest")
OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
......@@ -255,7 +255,7 @@ void TestWXFormat(const index_t batch,
BufferToImage<DeviceType::GPU, T>(&net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FC", "FullyConnectedTest")
OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("InputImage")
.Input("WeightImage")
.Input("BiasImage")
......
......@@ -10,6 +10,7 @@ enum NetMode {
enum DeviceType {
CPU = 0; // In default, we will use CPU.
GPU = 2;
HEXAGON = 3;
}
enum DataType {
......
py_library(
name = "tf_converter_lib",
name = "converter_lib",
srcs = [
"convert_util.py",
"graph_util.py",
"tf_converter_lib.py",
"tf_dsp_converter_lib.py",
"converter_tool/base_converter.py",
"converter_tool/shape_inference.py",
"converter_tool/tensorflow_converter.py",
"converter_tool/caffe_converter.py",
"converter_tool/transformer.py",
],
srcs_version = "PY2AND3",
deps = [
":memory_optimizer",
"//mace/proto:mace_py",
],
)
py_library(
name = "caffe_converter_lib",
srcs = [
"caffe_converter_lib.py",
],
srcs_version = "PY2AND3",
deps = [
":memory_optimizer",
"//mace/third_party/caffe:caffe_py",
],
)
......@@ -37,22 +30,21 @@ py_library(
)
py_binary(
name = "converter",
srcs = ["converter.py"],
name = "memory_optimizer",
srcs = ["memory_optimizer.py"],
srcs_version = "PY2AND3",
deps = [
":caffe_converter_lib",
":source_converter_lib",
":tf_converter_lib",
"@six_archive//:six",
"//mace/proto:mace_py",
],
)
py_binary(
name = "memory_optimizer",
srcs = ["memory_optimizer.py"],
name = "converter",
srcs = ["converter.py"],
srcs_version = "PY2AND3",
deps = [
"//mace/proto:mace_py",
":converter_lib",
":source_converter_lib",
"@six_archive//:six",
],
)
此差异已折叠。
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from mace.proto import mace_pb2
......@@ -40,3 +41,8 @@ def tf_dtype_2_mace_dtype(tf_dtype):
if not mace_dtype:
raise Exception("Not supported tensorflow dtype: " + tf_dtype)
return mace_dtype
def mace_check(condition, msg):
    """Raise a plain ``Exception`` carrying ``msg`` when ``condition`` is falsy.

    Returns ``None`` without side effects when ``condition`` is truthy.
    """
    if condition:
        return
    raise Exception(msg)
......@@ -16,7 +16,16 @@ import argparse
import sys
import hashlib
import os.path
from mace.proto import mace_pb2
from mace.python.tools import tf_dsp_converter_lib
from mace.python.tools import memory_optimizer
from mace.python.tools import source_converter_lib
from mace.python.tools.converter_tool import base_converter as cvt
from mace.python.tools.converter_tool import tensorflow_converter
from mace.python.tools.converter_tool import caffe_converter
from mace.python.tools.converter_tool import transformer
# ./bazel-bin/mace/python/tools/tf_converter --model_file quantized_test.pb \
# --output quantized_test_dsp.pb \
......@@ -25,6 +34,12 @@ from mace.python.tools import source_converter_lib
FLAGS = None
# Lookup from the FLAGS.data_type string to the matching mace_pb2 data type
# constant.
data_type_map = {'DT_HALF': mace_pb2.DT_HALF,
'DT_FLOAT': mace_pb2.DT_FLOAT}
# Lookup from the FLAGS.runtime string to the matching mace_pb2 device type
# constant; the 'dsp' runtime is represented by HEXAGON.
device_type_map = {'cpu': mace_pb2.CPU,
'gpu': mace_pb2.GPU,
'dsp': mace_pb2.HEXAGON}
def file_checksum(fname):
hash_func = hashlib.sha256()
......@@ -34,6 +49,10 @@ def file_checksum(fname):
return hash_func.hexdigest()
def parse_int_array_from_str(ints_str):
    """Turn a comma-separated string such as ``"1,224,224,3"`` into ints.

    Each comma-separated field is handed to ``int()``, so an empty or
    non-numeric field raises ``ValueError``.
    """
    values = []
    for field in ints_str.split(','):
        values.append(int(field))
    return values
def main(unused_args):
if not os.path.isfile(FLAGS.model_file):
print("Input graph file '" + FLAGS.model_file + "' does not exist!")
......@@ -59,27 +78,64 @@ def main(unused_args):
(weight_checksum, FLAGS.weight_checksum))
sys.exit(-1)
if FLAGS.runtime == 'dsp':
print("DSP not support caffe model yet.")
sys.exit(-1)
if FLAGS.platform not in ['tensorflow', 'caffe']:
print ("platform %s is not supported." % FLAGS.platform)
sys.exit(-1)
if FLAGS.runtime not in ['cpu', 'gpu', 'dsp']:
print ("runtime %s is not supported." % FLAGS.runtime)
sys.exit(-1)
from mace.python.tools import caffe_converter_lib
output_graph_def = caffe_converter_lib.convert_to_mace_pb(
FLAGS.model_file, FLAGS.weight_file, FLAGS.input_node,
FLAGS.input_shape, FLAGS.output_node, FLAGS.data_type,
FLAGS.runtime, FLAGS.winograd)
elif FLAGS.platform == 'tensorflow':
if FLAGS.runtime == 'dsp':
from mace.python.tools import tf_dsp_converter_lib
if FLAGS.runtime == 'dsp':
if FLAGS.platform == 'tensorflow':
output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb(
FLAGS.model_file, FLAGS.input_node, FLAGS.output_node,
FLAGS.dsp_mode)
else:
from mace.python.tools import tf_converter_lib
output_graph_def = tf_converter_lib.convert_to_mace_pb(
FLAGS.model_file, FLAGS.input_node, FLAGS.input_shape,
FLAGS.output_node, FLAGS.data_type, FLAGS.runtime,
FLAGS.winograd)
print("%s does not support dsp runtime yet." % FLAGS.platform)
sys.exit(-1)
else:
option = cvt.ConverterOption()
option.data_type = data_type_map[FLAGS.data_type]
option.device = device_type_map[FLAGS.runtime]
option.winograd_enabled = bool(FLAGS.winograd)
# FLAGS.input_node lists the input names separated by ','; FLAGS.input_shape
# carries one comma-separated shape per input, the shapes separated by ':'.
input_node_names = FLAGS.input_node.split(',')
input_node_shapes = FLAGS.input_shape.split(':')
if len(input_node_names) != len(input_node_shapes):
    raise Exception('input node count and shape count do not match.')
for node_name, node_shape in zip(input_node_names, input_node_shapes):
    input_node = cvt.NodeInfo()
    input_node.name = node_name
    # Parse only this node's own shape entry. Parsing the whole
    # FLAGS.input_shape string raises ValueError as soon as it contains
    # more than one ':'-separated shape.
    input_node.shape = parse_int_array_from_str(node_shape)
    option.add_input_node(input_node)
output_node_names = FLAGS.output_node.split(',')
for i in xrange(len(output_node_names)):
output_node = cvt.NodeInfo()
output_node.name = output_node_names[i]
option.add_output_node(output_node)
print("Convert model to mace model.")
if FLAGS.platform == 'tensorflow':
converter = tensorflow_converter.TensorflowConverter(option,
FLAGS.model_file) # noqa
elif FLAGS.platform == 'caffe':
converter = caffe_converter.CaffeConverter(option,
FLAGS.model_file,
FLAGS.weight_file)
output_graph_def = converter.run()
print("Transform model to one that can better run on device.")
# TODO(liuqi/liyin): transform gpu/cpu and merge their ops
mace_transformer = transformer.Transformer(option, output_graph_def)
output_graph_def = mace_transformer.run()
print "start optimize memory."
if FLAGS.runtime == 'gpu':
memory_optimizer.optimize_gpu_memory(output_graph_def)
elif FLAGS.runtime == 'cpu':
memory_optimizer.optimize_cpu_memory(output_graph_def)
print "Memory optimization done."
if FLAGS.output_type == 'source':
source_converter_lib.convert_to_source(
......
from enum import Enum
from mace.proto import mace_pb2
class DataFormat(Enum):
    """Data layouts known to the converter.

    The integer value is what ConverterUtil.add_data_format_arg() writes
    into an op's 'data_format' argument.
    """

    NHWC = 0
    NCHW = 1
class FilterFormat(Enum):
    """Filter layouts known to the converter.

    The integer value is what ConverterUtil.set_filter_format() writes into
    the net's 'filter_format' argument.
    """

    HWIO = 0
    OIHW = 1
    HWOI = 2
class PaddingMode(Enum):
    """Symbolic padding modes together with their integer codes."""

    VALID = 0
    SAME = 1
    FULL = 2
class PoolingType(Enum):
    """Pooling kinds; the integer value is stored in the 'pooling_type' arg.

    Note that the numbering starts at 1, not 0.
    """

    AVG = 1
    MAX = 2
class ActivationType(Enum):
    """Activation kinds.

    The caffe converter stores the member *name* (e.g. 'RELU') in an op's
    'activation' argument.
    """

    NOOP = 0
    RELU = 1
    RELUX = 2
    PRELU = 3
    TANH = 4
    SIGMOID = 5
class EltwiseType(Enum):
    """Element-wise operation kinds.

    The integer value is what the caffe converter stores in an Eltwise op's
    'type' argument.
    """

    SUM = 0
    SUB = 1
    PROD = 2
    DIV = 3
    MIN = 4
    MAX = 5
    NEG = 6
    ABS = 7
    SQR_DIFF = 8
    POW = 9
# Op type names known to the converter. Every entry becomes a member of the
# MaceOp enum defined right after this list.
MaceSupportedOps = [
'Activation',
'AddN',
'BatchNorm',
'BatchToSpaceND',
'BiasAdd',
'ChannelShuffle',
'Concat',
'Conv2D',
'Deconv2D',
'DepthToSpace',
'DepthwiseConv2d',
'Dequantize',
'Eltwise',
'FoldedBatchNorm',
'FullyConnected',
'LocalResponseNorm',
'MatMul',
'Pad',
'Pooling',
'Proposal',
'PSROIAlign',
'Quantize',
'Requantize',
'Reshape',
'ResizeBilinear',
'Slice',
'Softmax',
'SpaceToBatchND',
'SpaceToDepth',
'Transpose',
'WinogradInverseTransform',
'WinogradTransform',
]
# String-mixin enum built with the functional API: each member's name and
# value are both the op type string, e.g. MaceOp.Conv2D.name == 'Conv2D'.
MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)
class MaceKeyword(object):
    """Namespace of the string constants shared by the converter modules.

    The first group names nodes and op types, the second group names op
    arguments. Values are used verbatim as names inside the generated model.
    """

    # --- node related strings ---
    mace_input_node_name = 'mace_input_node'
    mace_output_node_name = 'mace_output_node'
    mace_buffer_type = 'buffer_type'
    mace_mode = 'mode'
    mace_buffer_to_image = 'BufferToImage'
    mace_image_to_buffer = 'ImageToBuffer'

    # --- argument related strings ---
    mace_padding_str = 'padding'
    mace_padding_values_str = 'padding_values'
    mace_strides_str = 'strides'
    mace_dilations_str = 'dilations'
    mace_pooling_type_str = 'pooling_type'
    mace_global_pooling_str = 'global_pooling'
    mace_kernel_str = 'kernels'
    mace_data_format_str = 'data_format'
    mace_filter_format_str = 'filter_format'
    mace_element_type_str = 'type'
    mace_activation_type_str = 'activation'
    mace_activation_max_limit_str = 'max_limit'
    mace_resize_size_str = 'size'
    mace_batch_to_space_crops_str = 'crops'
    mace_paddings_str = 'paddings'
    mace_align_corners_str = 'align_corners'
    mace_space_batch_block_shape_str = 'block_shape'
    mace_space_depth_block_size_str = 'block_size'
    mace_constant_value_str = 'constant_value'
    mace_dims_str = 'dims'
    mace_axis_str = 'axis'
    mace_shape_str = 'shape'
    mace_winograd_filter_transformed = 'is_filter_transformed'
class ConverterInterface(object):
    """Common interface of the converters that turn external models into
    mace models."""

    def run(self):
        """Run the conversion. The base implementation always raises
        NotImplementedError('run'); concrete converters override it."""
        raise NotImplementedError('run')
class NodeInfo(object):
    """Plain record describing one node: its name and its shape."""

    def __init__(self):
        # A fresh node has no name and an empty shape.
        self._name = None
        self._shape = []

    @property
    def name(self):
        """Name of the node (``None`` until assigned)."""
        return self._name

    @name.setter
    def name(self, name):
        self._name = name

    @property
    def shape(self):
        """Shape of the node (``[]`` until assigned)."""
        return self._shape

    @shape.setter
    def shape(self, shape):
        self._shape = shape

    def __str__(self):
        # Rendered as "<name> <shape>".
        shape_text = str(self._shape)
        return '%s %s' % (self._name, shape_text)
class ConverterOption(object):
    """Options handed to the converter tool.

    Input and output nodes are kept in dicts keyed by node name. Assigning
    to ``input_nodes`` / ``output_nodes`` merges the given nodes into the
    existing dict; it does not replace it.
    """

    def __init__(self):
        self._input_nodes = {}
        self._output_nodes = {}
        self._data_type = mace_pb2.DT_FLOAT
        self._device = mace_pb2.CPU
        self._winograd_enabled = False

    # ---- input nodes -----------------------------------------------------
    @property
    def input_nodes(self):
        """Dict mapping input node name to node object."""
        return self._input_nodes

    @input_nodes.setter
    def input_nodes(self, input_nodes):
        for item in input_nodes:
            self._input_nodes[item.name] = item

    def add_input_node(self, input_node):
        """Register (or overwrite) a single input node under its name."""
        self._input_nodes[input_node.name] = input_node

    # ---- output nodes ----------------------------------------------------
    @property
    def output_nodes(self):
        """Dict mapping output node name to node object."""
        return self._output_nodes

    @output_nodes.setter
    def output_nodes(self, output_nodes):
        for item in output_nodes:
            self._output_nodes[item.name] = item

    def add_output_node(self, output_node):
        """Register (or overwrite) a single output node under its name."""
        self._output_nodes[output_node.name] = output_node

    # ---- scalar options --------------------------------------------------
    @property
    def data_type(self):
        """mace_pb2 data type constant (DT_FLOAT by default)."""
        return self._data_type

    @data_type.setter
    def data_type(self, data_type):
        self._data_type = data_type

    @property
    def device(self):
        """mace_pb2 device type constant (CPU by default)."""
        return self._device

    @device.setter
    def device(self, device):
        self._device = device

    @property
    def winograd_enabled(self):
        """Whether winograd is enabled (False by default)."""
        return self._winograd_enabled

    @winograd_enabled.setter
    def winograd_enabled(self, winograd_enabled):
        self._winograd_enabled = winograd_enabled
class ConverterUtil(object):
    """Static helpers for reading and writing named arguments on protos."""

    @staticmethod
    def get_arg(op, arg_name):
        """Return the first entry of ``op.arg`` named ``arg_name``, or None."""
        return next((arg for arg in op.arg if arg.name == arg_name), None)

    @staticmethod
    def add_data_format_arg(op, data_format):
        """Append a 'data_format' argument holding ``data_format.value``."""
        new_arg = op.arg.add()
        new_arg.name = MaceKeyword.mace_data_format_str
        new_arg.i = data_format.value

    @staticmethod
    def data_format(op):
        """Return the DataFormat stored on ``op``.

        Yields None when the argument is missing or holds an unknown value.
        """
        stored = ConverterUtil.get_arg(op, MaceKeyword.mace_data_format_str)
        if stored is None:
            return None
        for candidate in (DataFormat.NHWC, DataFormat.NCHW):
            if stored.i == candidate.value:
                return candidate
        return None

    @staticmethod
    def set_filter_format(net, filter_format):
        """Append a 'filter_format' argument holding ``filter_format.value``."""
        new_arg = net.arg.add()
        new_arg.name = MaceKeyword.mace_filter_format_str
        new_arg.i = filter_format.value

    @staticmethod
    def filter_format(net):
        """Return the FilterFormat stored on ``net``.

        Yields None when the argument is missing or holds an unknown value.
        """
        stored = ConverterUtil.get_arg(net, MaceKeyword.mace_filter_format_str)
        if stored is None:
            return None
        for candidate in (FilterFormat.HWIO, FilterFormat.HWOI,
                          FilterFormat.OIHW):
            if stored.i == candidate.value:
                return candidate
        return None
import math
import numpy as np
import google.protobuf.text_format
from mace.proto import mace_pb2
from mace.third_party.caffe import caffe_pb2
from mace.python.tools.converter_tool import base_converter
from mace.python.tools.converter_tool import shape_inference
from mace.python.tools.converter_tool.base_converter import PoolingType
from mace.python.tools.converter_tool.base_converter import ActivationType
from mace.python.tools.converter_tool.base_converter import EltwiseType
from mace.python.tools.converter_tool.base_converter import DataFormat
from mace.python.tools.converter_tool.base_converter import FilterFormat
from mace.python.tools.converter_tool.base_converter import MaceOp
from mace.python.tools.converter_tool.base_converter import MaceKeyword
from mace.python.tools.converter_tool.base_converter import ConverterUtil
from mace.python.tools.convert_util import mace_check
# Field names of caffe layer parameter messages. They are passed to
# HasField() to find out whether the model set the field explicitly.
caffe_group_str = 'group'
caffe_kernel_h_str = 'kernel_h'
caffe_kernel_w_str = 'kernel_w'
caffe_stride_h_str = 'stride_h'
caffe_stride_w_str = 'stride_w'
caffe_pad_h_str = 'pad_h'
caffe_pad_w_str = 'pad_w'
class CaffeOperator(object):
    """Pairs a caffe layer proto with its weights.

    ``layer`` holds the caffe layer message; ``blobs`` holds the layer's
    weight blobs, each converted to a float32 numpy ndarray on assignment.
    """

    def __init__(self):
        self._layer = None
        self._blobs = None

    @property
    def name(self):
        """Name of the wrapped layer."""
        return self._layer.name

    @property
    def type(self):
        """Type string of the wrapped layer."""
        return self._layer.type

    @property
    def layer(self):
        """The wrapped caffe layer proto."""
        return self._layer

    @layer.setter
    def layer(self, layer):
        self._layer = layer

    @property
    def blobs(self):
        """List of weight ndarrays (``None`` until weights are assigned)."""
        return self._blobs

    @blobs.setter
    def blobs(self, blobs):
        converted = []
        for blob in blobs:
            converted.append(self.blob_to_nparray(blob))
        self._blobs = converted

    def get_blob(self, index):
        """Return the weight ndarray at ``index`` after a bounds check."""
        mace_check(index < len(self._blobs), "blob out of index")
        return self._blobs[index]

    @staticmethod
    def blob_to_nparray(blob):
        """Convert a caffe blob proto into a float32 ndarray.

        When ``blob.num`` is non-zero the shape is taken from the
        num/channels/height/width fields, otherwise from ``blob.shape.dim``.
        """
        values = np.asarray(blob.data, dtype=np.float32)
        if blob.num != 0:
            dims = (blob.num, blob.channels, blob.height, blob.width)
        else:
            dims = blob.shape.dim
        return values.reshape(dims)
class CaffeNet(object):
    """Holds the caffe operators of one network.

    Every non-Input layer output is renamed to ``<name>#<n>`` so that each
    output name in the net is unique, even when caffe layers write in place
    (output name equal to input name). Outputs of 'Input' layers keep their
    name.
    """

    def __init__(self):
        # layer name -> CaffeOperator
        self._ops = {}
        # tensor name -> operators reading that tensor
        self._consumers = {}
        # original output name -> latest unique alias given to it; needed
        # because an in-place op reuses its input name as its output name
        self._alias_op_output_name = {}
        self._used_op_output_name = set()

    @property
    def ops(self):
        """All registered operators."""
        return self._ops.values()

    def get_op(self, op_name):
        """Operator registered under ``op_name``, or None."""
        return self._ops.get(op_name, None)

    def get_consumers(self, tensor_name):
        """Operators that read ``tensor_name`` (empty list when none)."""
        return self._consumers.get(tensor_name, [])

    def add_layer(self, layer):
        """Register ``layer`` and rewrite its bottom/top names in place."""
        op = CaffeOperator()
        op.layer = layer
        self._ops[layer.name] = op

        # Point inputs at the latest alias of whatever produced them.
        aliases = self._alias_op_output_name
        layer.bottom[:] = [aliases.get(name, name) for name in layer.bottom]

        for position, old_name in enumerate(list(layer.top)):
            if layer.type == 'Input':
                new_name = old_name
            else:
                # Pick the first "<old>#<n>" that has not been handed out.
                suffix = 0
                while (old_name + '#' + str(suffix)
                       in self._used_op_output_name):
                    suffix += 1
                new_name = old_name + '#' + str(suffix)
            layer.top[position] = new_name
            aliases[old_name] = new_name
            self._used_op_output_name.add(new_name)

        for input_tensor in layer.bottom:
            self._consumers.setdefault(input_tensor, []).append(op)

    def add_blob(self, weight):
        """Attach ``weight.blobs`` to the operator named ``weight.name``.

        Weights whose name matches no registered layer are ignored.
        """
        if weight.name not in self._ops:
            return
        self._ops[weight.name].blobs = list(weight.blobs)
class CaffeConverter(base_converter.ConverterInterface):
    """Converts a caffe model (prototxt plus binary weights) to a mace NetDef.

    Construction parses both files, drops non-test and Dropout layers and
    registers the remaining layers in a CaffeNet. ``run()`` then converts
    every layer, infers output shapes and restores user-visible tensor names.
    """

    # caffe pooling method -> mace pooling type
    pooling_type_mode = {
        caffe_pb2.PoolingParameter.AVE: PoolingType.AVG,
        caffe_pb2.PoolingParameter.MAX: PoolingType.MAX
    }
    # caffe eltwise operation -> mace eltwise type
    eltwise_type = {
        caffe_pb2.EltwiseParameter.PROD: EltwiseType.PROD,
        caffe_pb2.EltwiseParameter.SUM: EltwiseType.SUM,
        caffe_pb2.EltwiseParameter.MAX: EltwiseType.MAX,
    }
    # caffe layer type -> mace activation type
    activation_type = {
        'ReLU': ActivationType.RELU,
        'PReLU': ActivationType.PRELU,
        'TanH': ActivationType.TANH,
        # 'Sigmoid' layers are dispatched to convert_activation too, so the
        # lookup there needs this entry (it used to raise KeyError).
        'Sigmoid': ActivationType.SIGMOID,
    }

    def __init__(self, option, src_model_file, src_weight_file):
        """Parse the model.

        Args:
            option: ConverterOption; data_type, input_nodes and output_nodes
                are read from it.
            src_model_file: path of the caffe prototxt (text format).
            src_weight_file: path of the serialized caffe weights.
        """
        # caffe layer type -> bound conversion method
        self._op_converters = {
            'Input': self.convert_nop,
            'Convolution': self.convert_conv2d,
            'Eltwise': self.convert_elementwise,
            'Add': self.convert_add,
            'ReLU': self.convert_activation,
            'TanH': self.convert_activation,
            'Sigmoid': self.convert_activation,
            'PReLU': self.convert_activation,
            'Pooling': self.convert_pooling,
            'Concat': self.convert_concat,
            'Slice': self.convert_slice,
            'Softmax': self.convert_softmax,
            'InnerProduct': self.convert_fully_connected,
            'BatchNorm': self.convert_folded_batchnorm,
        }
        self._option = option
        self._mace_net_def = mace_pb2.NetDef()
        ConverterUtil.set_filter_format(self._mace_net_def, FilterFormat.OIHW)
        self._caffe_net = CaffeNet()
        self._caffe_layers = caffe_pb2.NetParameter()
        caffe_weights = caffe_pb2.NetParameter()

        # Parse the prototxt. It is a text file, so read it as text:
        # str() applied to bytes yields "b'...'" on Python 3.
        with open(src_model_file, 'r') as f:
            google.protobuf.text_format.Merge(f.read(), self._caffe_layers)
            self.filter_test_layers(self._caffe_layers)
            for layer in self._caffe_layers.layer:
                self._caffe_net.add_layer(layer)

        # Parse the binary weights and attach them to the matching layers.
        with open(src_weight_file, 'rb') as f:
            caffe_weights.ParseFromString(f.read())
            self.filter_test_layers(caffe_weights)
            for weight in caffe_weights.layer:
                self._caffe_net.add_blob(weight)

        # Operators already merged into another op (e.g. Scale into BatchNorm).
        self._skip_ops = []

    def run(self):
        """Convert all layers and return the resulting mace NetDef."""
        self.convert_ops()
        shape_inferer = shape_inference.ShapeInference(
            self._mace_net_def,
            self._option.input_nodes.values())
        shape_inferer.run()
        self.replace_output_tensor_name()
        return self._mace_net_def

    @staticmethod
    def replace_input_name(ops, src_name, dst_name):
        """Rename every input called ``src_name`` to ``dst_name`` in ``ops``."""
        for op in ops:
            for i in range(len(op.input)):
                if op.input[i] == src_name:
                    op.input[i] = dst_name

    def replace_output_tensor_name(self):
        """Undo the '#<n>' renaming where a user-visible name is expected."""
        consumers = {}
        for op in self._mace_net_def.op:
            for input_name in op.input:
                consumers.setdefault(input_name, []).append(op)

        # Walking backwards, give the last producer of each original top name
        # that original name back (and rewire its consumers accordingly).
        visited = set()
        for op in reversed(list(self._mace_net_def.op)):
            for i in range(len(op.output)):
                original_output_name = op.output[i].split('#')[0]
                if original_output_name not in visited:
                    self.replace_input_name(
                        consumers.get(op.output[i], []),
                        op.output[i],
                        original_output_name)
                    op.output[i] = original_output_name
                    visited.add(original_output_name)

        # If the user named an op (rather than a tensor) as output node, call
        # the op's first output after the op. The source name must be the
        # tensor name op.output[0]; passing the repeated field itself never
        # matched anything. All ops are scanned because ``consumers`` was
        # built before the renaming above and may be keyed by stale names.
        for op in self._mace_net_def.op:
            if op.name in self._option.output_nodes and len(op.output) > 0:
                self.replace_input_name(
                    self._mace_net_def.op,
                    op.output[0],
                    op.name)
                op.output[0] = op.name

    @staticmethod
    def filter_test_layers(layers):
        """Remove, in place, Dropout layers and layers not in the test phase."""
        phase_map = {0: 'train', 1: 'test'}
        # Removing while iterating is unsafe, so restart the scan after every
        # removal until a full pass removes nothing.
        while True:
            changed = False
            for layer in layers.layer:
                phase = 'test'
                if len(layer.include):
                    phase = phase_map[layer.include[0].phase]
                # NOTE(review): an ``exclude`` rule is mapped exactly like an
                # ``include`` rule, so a layer excluded from TRAIN is dropped
                # here - confirm this is intended.
                if len(layer.exclude):
                    phase = phase_map[layer.exclude[0].phase]
                if phase != 'test' or layer.type == 'Dropout':
                    print("Remove layer %s (%s)" % (layer.name, layer.type))
                    layers.layer.remove(layer)
                    changed = True
                    break
            if not changed:
                break

    @staticmethod
    def add_stride_pad_kernel_arg(param, op_def):
        """Add strides / padding_values (and pooling kernel) args to op_def.

        ``param`` is either a convolution parameter (repeated stride / pad /
        kernel_size) or a pooling parameter (scalar fields); the scalar case
        is detected through the TypeError raised by ``len()``. Padding values
        are doubled before being stored.
        """
        try:
            if len(param.stride) > 1 or len(param.kernel_size) > 1 or len(
                    param.pad) > 1:
                raise Exception(
                    'Mace does not support multiple stride/kernel_size/pad')
            stride = [param.stride[0],
                      param.stride[0]] if len(param.stride) else [1, 1]
            pad = [param.pad[0] * 2,
                   param.pad[0] * 2] if len(param.pad) else [0, 0]
            kernel = [param.kernel_size[0], param.kernel_size[0]] if len(
                param.kernel_size) else [0, 0]
        except TypeError:
            stride = [param.stride, param.stride]
            pad = [param.pad * 2, param.pad * 2]
            kernel = [param.kernel_size, param.kernel_size]

        # Explicit per-axis fields override the square defaults above.
        if param.HasField(caffe_stride_h_str) or param.HasField(
                caffe_stride_w_str):
            stride = [param.stride_h, param.stride_w]
        if param.HasField(caffe_pad_h_str) or param.HasField(caffe_pad_w_str):
            pad = [param.pad_h * 2, param.pad_w * 2]

        strides_arg = op_def.arg.add()
        strides_arg.name = MaceKeyword.mace_strides_str
        strides_arg.ints.extend(stride)
        padding_arg = op_def.arg.add()
        padding_arg.name = MaceKeyword.mace_padding_values_str
        padding_arg.ints.extend(pad)

        if op_def.type == MaceOp.Pooling.name:
            if param.HasField(caffe_kernel_h_str) or param.HasField(
                    caffe_kernel_w_str):
                kernel = [param.kernel_h, param.kernel_w]
            kernels_arg = op_def.arg.add()
            kernels_arg.name = MaceKeyword.mace_kernel_str
            kernels_arg.ints.extend(kernel)
            if param.HasField('global_pooling'):
                global_pooling_arg = op_def.arg.add()
                global_pooling_arg.name = MaceKeyword.mace_global_pooling_str
                global_pooling_arg.i = 1

    def convert_ops(self):
        """Dispatch every remaining caffe layer to its converter method."""
        for layer in self._caffe_layers.layer:
            caffe_op = self._caffe_net.get_op(layer.name)
            if caffe_op not in self._skip_ops:
                mace_check(layer.type in self._op_converters,
                           "Mace does not support caffe op type %s yet"
                           % layer.type)
                self._op_converters[layer.type](caffe_op)

    def add_tensor(self, name, shape, data_type, value):
        """Append a constant tensor with the given ndarray value to the net."""
        tensor = self._mace_net_def.tensors.add()
        tensor.name = name
        tensor.dims.extend(list(shape))
        tensor.data_type = data_type
        tensor.float_data.extend(value.flat)

    def convert_nop(self, layer):
        """Layers such as 'Input' produce no mace op."""
        pass

    def convert_general_op(self, caffe_op):
        """Create a mace op mirroring name, type, inputs and outputs.

        Also adds the 'T' data type argument and an NCHW data_format
        argument. Returns the new op so callers can refine it.
        """
        op = self._mace_net_def.op.add()
        op.name = caffe_op.name
        op.type = caffe_op.type
        op.input.extend(caffe_op.layer.bottom)
        op.output.extend(caffe_op.layer.top)

        data_type_arg = op.arg.add()
        data_type_arg.name = 'T'
        data_type_arg.i = self._option.data_type

        ConverterUtil.add_data_format_arg(op, DataFormat.NCHW)

        return op

    def convert_conv2d(self, caffe_op):
        """Convert a Convolution layer to Conv2D or DepthwiseConv2d."""
        op = self.convert_general_op(caffe_op)
        param = caffe_op.layer.convolution_param
        is_depthwise = False
        # group == 1 is an ordinary convolution even when the prototxt spells
        # the field out, so only group > 1 takes the depthwise path.
        if param.HasField(caffe_group_str) and param.group > 1:
            # The weights live in ``blobs`` (there is no ``blob`` attribute).
            # NOTE(review): caffe usually stores depthwise weights as
            # (channels, 1, kh, kw) while this check expects
            # (1, channels, kh, kw) - confirm which layout arrives here.
            filter_shape = caffe_op.blobs[0].shape
            mace_check(param.group == filter_shape[1] and
                       filter_shape[0] == 1,
                       "Mace do not support group convolution yet")
            is_depthwise = True

        if is_depthwise:
            op.type = MaceOp.DepthwiseConv2d.name
        else:
            op.type = MaceOp.Conv2D.name

        self.add_stride_pad_kernel_arg(param, op)
        # dilation is specific for convolution in caffe
        dilations = [1, 1]
        if len(param.dilation) > 0:
            dilation_arg = op.arg.add()
            dilation_arg.name = MaceKeyword.mace_dilations_str
            if len(param.dilation) == 1:
                dilations = [param.dilation[0], param.dilation[0]]
            elif len(param.dilation) == 2:
                dilations = [param.dilation[0], param.dilation[1]]
            dilation_arg.ints.extend(dilations)

        filter_tensor_name = op.name + '_filter'
        filter_data = caffe_op.blobs[0]
        self.add_tensor(filter_tensor_name, filter_data.shape,
                        mace_pb2.DT_FLOAT, filter_data)
        op.input.extend([filter_tensor_name])

        # A second blob, when present, is the bias.
        if len(caffe_op.blobs) == 2:
            bias_tensor_name = op.name + '_bias'
            bias_data = caffe_op.blobs[1]
            self.add_tensor(bias_tensor_name, bias_data.shape,
                            mace_pb2.DT_FLOAT,
                            bias_data)
            op.input.extend([bias_tensor_name])

    def convert_elementwise(self, caffe_op):
        """Convert an Eltwise layer (PROD / SUM / MAX, optional coeff)."""
        op = self.convert_general_op(caffe_op)
        param = caffe_op.layer.eltwise_param

        op.type = MaceOp.Eltwise.name
        type_arg = op.arg.add()
        type_arg.name = MaceKeyword.mace_element_type_str
        type_arg.i = self.eltwise_type[param.operation].value
        if len(param.coeff) > 0:
            coeff_arg = op.arg.add()
            coeff_arg.name = 'coeff'
            coeff_arg.floats.extend(list(param.coeff))

    def convert_add(self, caffe_op):
        """Convert an Add layer to AddN."""
        op = self.convert_general_op(caffe_op)
        op.type = MaceOp.AddN.name

    def convert_activation(self, caffe_op):
        """Convert ReLU / PReLU / TanH / Sigmoid to an Activation op."""
        op = self.convert_general_op(caffe_op)
        op.type = MaceOp.Activation.name

        type_arg = op.arg.add()
        type_arg.name = MaceKeyword.mace_activation_type_str
        type_arg.s = self.activation_type[caffe_op.type].name

        # PReLU carries its slope as a weight blob; pass it as an input.
        if caffe_op.type == 'PReLU':
            alpha_tensor_name = caffe_op.name + '_alpha'
            alpha_data = caffe_op.blobs[0]
            self.add_tensor(alpha_tensor_name, alpha_data.shape,
                            mace_pb2.DT_FLOAT, alpha_data)
            op.input.extend([alpha_tensor_name])

    def convert_folded_batchnorm(self, caffe_op):
        """Fold a BatchNorm layer and the Scale layer consuming it.

        The pair becomes one FoldedBatchNorm op with precomputed per-channel
        scale and offset tensors; the Scale operator is marked as skipped.
        """
        op = self.convert_general_op(caffe_op)
        op.type = MaceOp.FoldedBatchNorm.name

        scale_op = None
        for consumer in self._caffe_net.get_consumers(caffe_op.layer.top[0]):
            if consumer.type == 'Scale':
                scale_op = consumer
        mace_check(scale_op is not None, "batchnorm is not followed by scale")
        self._skip_ops.append(scale_op)

        epsilon_value = caffe_op.layer.batch_norm_param.eps
        # blobs[2][0] is the factor the stored mean/variance are divided by.
        mace_check(caffe_op.blobs[2][0] != 0, "batchnorm scalar is zero")
        mean_value = (1. / caffe_op.blobs[2][0]) * caffe_op.blobs[0]
        var_value = (1. / caffe_op.blobs[2][0]) * caffe_op.blobs[1]

        gamma_value = scale_op.blobs[0]
        beta_value = np.zeros_like(mean_value)
        if len(scale_op.blobs) == 2:
            beta_value = scale_op.blobs[1]

        # scale = gamma / sqrt(var + eps); offset = beta - mean * scale
        scale_value = (
            (1.0 / np.vectorize(math.sqrt)(var_value + epsilon_value)) *
            gamma_value).reshape(-1)
        offset_value = ((-mean_value * scale_value) + beta_value).reshape(-1)

        input_names = [op.name + '_scale', op.name + '_offset']
        self.add_tensor(input_names[0], scale_value.shape, mace_pb2.DT_FLOAT,
                        scale_value)
        self.add_tensor(input_names[1], offset_value.shape, mace_pb2.DT_FLOAT,
                        offset_value)
        op.input.extend(input_names)
        # The folded op produces what the Scale layer used to produce.
        op.output[:] = scale_op.layer.top[:]

    def convert_pooling(self, caffe_op):
        """Convert a Pooling layer (AVE or MAX)."""
        op = self.convert_general_op(caffe_op)
        param = caffe_op.layer.pooling_param

        op.type = MaceOp.Pooling.name
        self.add_stride_pad_kernel_arg(param, op)
        pooling_type_arg = op.arg.add()
        pooling_type_arg.name = MaceKeyword.mace_pooling_type_str
        pooling_type_arg.i = self.pooling_type_mode[param.pool].value

    def convert_softmax(self, caffe_op):
        """Softmax needs nothing beyond the general op conversion."""
        self.convert_general_op(caffe_op)

    def convert_concat(self, caffe_op):
        """Convert a Concat layer; only axis 1 is accepted."""
        op = self.convert_general_op(caffe_op)
        param = caffe_op.layer.concat_param
        op.type = MaceOp.Concat.name

        axis_arg = op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        axis_arg.i = 1
        if param.HasField('axis'):
            axis_arg.i = param.axis
        elif param.HasField('concat_dim'):
            axis_arg.i = param.concat_dim
        mace_check(axis_arg.i == 1, "only support concat at channel dimension")

    def convert_slice(self, caffe_op):
        """Convert a Slice layer; only axis 1 without slice points."""
        op = self.convert_general_op(caffe_op)
        op.type = MaceOp.Slice.name

        if caffe_op.layer.HasField('slice_param'):
            param = caffe_op.layer.slice_param
            mace_check(not param.HasField('axis') or param.axis == 1,
                       "Mace do not support slice with axis %d" % param.axis)
            mace_check(len(param.slice_point) == 0,
                       "Mace do not support slice with slice_point")

        axis_arg = op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        axis_arg.i = 1

    def convert_fully_connected(self, caffe_op):
        """Convert an InnerProduct layer to FullyConnected.

        The weight blob is reshaped to (num_output, -1) before being stored.
        """
        op = self.convert_general_op(caffe_op)
        param = caffe_op.layer.inner_product_param
        op.type = MaceOp.FullyConnected.name

        mace_check(param.axis == 1 and not param.transpose,
                   "Do not support non-default axis and transpose")
        mace_check(caffe_op.blobs[0].ndim in [2, 4],
                   "Unexpected fc weigth ndim.")
        if caffe_op.blobs[0].ndim == 4:
            # The check passes only for [1, 1, *, *]; the message used to say
            # the opposite.
            mace_check(list(caffe_op.blobs[0].shape[:2]) == [1, 1],
                       "Only support 4D weight with shape [1, 1, *, *]")

        weight_tensor_name = op.name + '_weight'
        weight_data = caffe_op.blobs[0].reshape(param.num_output, -1)
        self.add_tensor(weight_tensor_name, weight_data.shape,
                        mace_pb2.DT_FLOAT,
                        weight_data)
        op.input.extend([weight_tensor_name])

        # A second blob, when present, is the bias.
        if len(caffe_op.blobs) == 2:
            bias_tensor_name = op.name + '_bias'
            bias_data = caffe_op.blobs[1]
            self.add_tensor(bias_tensor_name, bias_data.shape,
                            mace_pb2.DT_FLOAT,
                            bias_data)
            op.input.extend([bias_tensor_name])
import math
import numpy as np
from mace.python.tools.converter_tool.transformer import Transformer
from mace.python.tools.converter_tool.base_converter import DataFormat
from mace.python.tools.converter_tool.base_converter import FilterFormat
from mace.python.tools.converter_tool.base_converter import MaceOp
from mace.python.tools.converter_tool.base_converter import MaceKeyword
from mace.python.tools.converter_tool.base_converter import ConverterUtil
from mace.python.tools.convert_util import mace_check
class ShapeInference(object):
    """Infers and records the output shapes of the ops in a converted net.

    Currently we only use it to infer caffe shape, we use tensorflow engine
    to infer tensorflow op shapes, since tensorflow has too many ops.

    Each inferred shape is appended to the op's ``output_shape`` field and
    cached by output tensor name so that later ops can look up the shapes
    of their inputs.
    """

    def __init__(self, net, input_nodes):
        """Seed the shape cache from the net inputs and constant tensors.

        Args:
            net: net definition whose ``op`` and ``tensors`` are read.
            input_nodes: iterable of objects with ``name`` and ``shape``.
        """
        self._op_shape_inference = {
            MaceOp.Conv2D.name: self.infer_shape_conv_pool_shape,
            MaceOp.Eltwise.name: self.infer_shape_general,
            MaceOp.FoldedBatchNorm.name: self.infer_shape_general,
            MaceOp.AddN.name: self.infer_shape_general,
            MaceOp.Activation.name: self.infer_shape_general,
            MaceOp.Pooling.name: self.infer_shape_conv_pool_shape,
            MaceOp.Concat.name: self.infer_shape_concat,
            MaceOp.Slice.name: self.infer_shape_slice,
            MaceOp.Softmax.name: self.infer_shape_general,
            MaceOp.FullyConnected.name: self.infer_shape_fully_connected,
        }

        self._net = net
        self._output_shape_cache = {}
        for input_node in input_nodes:
            # Work on a copy so the caller's shape is left untouched.
            input_shape = input_node.shape[:]
            # NOTE(review): the original comment here read "transpose input
            # from NCHW to NHWC", but the permutation [0, 3, 1, 2] and the
            # NCHW-only inference below suggest NHWC -> NCHW; confirm
            # against Transformer.transpose_shape.
            Transformer.transpose_shape(input_shape, [0, 3, 1, 2])
            self._output_shape_cache[input_node.name] = input_shape
        for tensor in net.tensors:
            self._output_shape_cache[tensor.name] = list(tensor.dims)

    def run(self):
        """Infer output shapes for every op of the net, in net order."""
        for op in self._net.op:
            mace_check(op.type in self._op_shape_inference,
                       "Mace does not support caffe op type %s yet"
                       % op.type)
            self._op_shape_inference[op.type](op)

    def add_output_shape(self, op, shapes):
        """Record ``shapes`` (one per op output) on the op and in the cache."""
        mace_check(len(op.output) == len(shapes),
                   "Op %s (%s) output count is different from "
                   "output shape count" % (
                       op.name, op.type))
        for output_name, shape in zip(op.output, shapes):
            output_shape = op.output_shape.add()
            output_shape.dims.extend(shape)
            self._output_shape_cache[output_name] = shape

    def infer_shape_general(self, op):
        """Output shape equals the shape of the first input (if any)."""
        if len(op.input) > 0:
            mace_check(op.input[0] in self._output_shape_cache,
                       "%s does not exist" % op.input[0])
            input_shape = self._output_shape_cache[op.input[0]]
            self.add_output_shape(op, [input_shape])

    def infer_shape_conv_pool_shape(self, op):
        """Infer the output shape of a Conv2D or Pooling op.

        Only NCHW data with an OIHW filter layout is supported. Pooling
        rounds the spatial size up (ceil), convolution rounds it down.
        """
        input_shape = self._output_shape_cache[op.input[0]]
        output_shape = np.zeros_like(input_shape)
        # op.type holds the enum *name* (a string), so compare with .name;
        # comparing with the enum member itself never matches.
        is_pooling = op.type == MaceOp.Pooling.name
        if is_pooling:
            # Pooling has no filter tensor: build a pseudo filter shape from
            # the kernel size (or the whole feature map for global pooling).
            filter_shape = list(
                ConverterUtil.get_arg(op, MaceKeyword.mace_kernel_str).ints)
            is_global = ConverterUtil.get_arg(
                op, MaceKeyword.mace_global_pooling_str) is not None
            if ConverterUtil.data_format(op) == DataFormat.NCHW:
                filter_shape = [input_shape[1], input_shape[1]] + filter_shape
                if is_global:
                    filter_shape[2] = input_shape[2]
                    filter_shape[3] = input_shape[3]
            else:  # NHWC
                filter_shape = filter_shape + [input_shape[1], input_shape[1]]
                if is_global:
                    filter_shape[0] = input_shape[1]
                    filter_shape[1] = input_shape[2]
        else:
            filter_shape = self._output_shape_cache[op.input[1]]

        paddings = ConverterUtil.get_arg(
            op, MaceKeyword.mace_padding_values_str).ints
        strides = ConverterUtil.get_arg(op, MaceKeyword.mace_strides_str).ints
        dilations_arg = ConverterUtil.get_arg(op,
                                              MaceKeyword.mace_dilations_str)
        if dilations_arg is not None:
            dilations = dilations_arg.ints
        else:
            dilations = [1, 1]
        round_func = math.ceil if is_pooling else math.floor

        def spatial_size(in_size, padding, kernel, dilation, stride):
            # Effective kernel extent is kernel + (kernel - 1)*(dilation - 1).
            span = in_size + padding - kernel - (kernel - 1) * (dilation - 1)
            return int(round_func(span / float(stride))) + 1

        output_shape[0] = input_shape[0]
        if ConverterUtil.data_format(op) == DataFormat.NCHW \
                and ConverterUtil.filter_format(self._net) == FilterFormat.OIHW:  # noqa
            # filter format: OIHW
            output_shape[1] = filter_shape[0]
            output_shape[2] = spatial_size(input_shape[2], paddings[0],
                                           filter_shape[2], dilations[0],
                                           strides[0])
            output_shape[3] = spatial_size(input_shape[3], paddings[1],
                                           filter_shape[3], dilations[1],
                                           strides[1])
        else:
            mace_check(False,
                       "Mace can only infer shape for"
                       " NCHW input and OIHW filter")

        self.add_output_shape(op, [output_shape])

    def infer_shape_concat(self, op):
        """Output shape is input 0's shape with the axis sizes summed."""
        axis = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str).i
        # Copy before editing: the cached list belongs to the producer of
        # input 0. Start from zero so input 0 is not counted twice.
        output_shape = list(self._output_shape_cache[op.input[0]])
        output_shape[axis] = 0
        for input_node in op.input:
            input_shape = self._output_shape_cache[input_node]
            output_shape[axis] += input_shape[axis]
        self.add_output_shape(op, [output_shape])

    def infer_shape_slice(self, op):
        """Split the input evenly along the axis, one part per output."""
        axis = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str).i
        # Copy so the cached input shape is not modified; floor division
        # keeps the dimension an integer on both Python 2 and 3.
        sliced_shape = list(self._output_shape_cache[op.input[0]])
        sliced_shape[axis] //= len(op.output)
        output_shapes = [list(sliced_shape) for _ in op.output]
        self.add_output_shape(op, output_shapes)

    def infer_shape_fully_connected(self, op):
        """Output is [batch, weight_shape[0], 1, 1]; NCHW only."""
        input_shape = self._output_shape_cache[op.input[0]]
        weight_shape = self._output_shape_cache[op.input[1]]
        if ConverterUtil.data_format(op) == DataFormat.NCHW:
            output_shape = [input_shape[0], weight_shape[0], 1, 1]
        else:
            mace_check(False, "format %s is not supported"
                       % ConverterUtil.data_format(op))
        self.add_output_shape(op, [output_shape])
import math
import numpy as np
import tensorflow as tf
from mace.proto import mace_pb2
from mace.python.tools.converter_tool import base_converter
from mace.python.tools.converter_tool.base_converter import PoolingType
from mace.python.tools.converter_tool.base_converter import PaddingMode
from mace.python.tools.converter_tool.base_converter import ActivationType
from mace.python.tools.converter_tool.base_converter import EltwiseType
from mace.python.tools.converter_tool.base_converter import DataFormat
from mace.python.tools.converter_tool.base_converter import FilterFormat
from mace.python.tools.converter_tool.base_converter import MaceOp
from mace.python.tools.converter_tool.base_converter import MaceKeyword
from mace.python.tools.converter_tool.base_converter import ConverterUtil
from mace.python.tools.convert_util import mace_check
from tensorflow.core.framework import tensor_shape_pb2
# String keys for TensorFlow node attributes. The converter in this file
# passes several of them to tf_op.get_attr() (padding, strides, dilations,
# ksize, epsilon, align_corners, block_size).
tf_padding_str = 'padding'
tf_strides_str = 'strides'
tf_dilations_str = 'dilations'
# NOTE(review): 'data_format' is not read anywhere in the visible converter
# code - confirm whether it is still needed.
tf_data_format_str = 'data_format'
tf_kernel_str = 'ksize'
tf_epsilon_str = 'epsilon'
tf_align_corners = 'align_corners'
tf_block_size = 'block_size'
class TensorflowConverter(base_converter.ConverterInterface):
    """Converts a frozen TensorFlow GraphDef into a MACE NetDef.

    We use tensorflow engine to infer op output shapes, since they are of
    too many types. Each supported TensorFlow op type is mapped to a
    ``convert_*`` method that appends an op (or constant tensor) to the
    resulting net. Constant inputs that a converter folds into op arguments
    are recorded in ``self._skip_tensor`` so that ``convert_tensor`` does not
    also emit them as net tensors.
    """

    # TensorFlow 'padding' attribute value -> MACE padding mode.
    padding_mode = {
        'VALID': PaddingMode.VALID,
        'SAME': PaddingMode.SAME,
        'FULL': PaddingMode.FULL
    }
    # TensorFlow pooling op type -> MACE pooling type.
    pooling_type_mode = {
        'AvgPool': PoolingType.AVG,
        'MaxPool': PoolingType.MAX
    }
    # TensorFlow element-wise op type -> MACE eltwise type.
    eltwise_type = {
        'Add': EltwiseType.SUM,
        'Sub': EltwiseType.SUB,
        'Mul': EltwiseType.PROD,
        'Div': EltwiseType.DIV,
        'Min': EltwiseType.MIN,
        'Max': EltwiseType.MAX,
        'Neg': EltwiseType.NEG,
        'Abs': EltwiseType.ABS,
        'RealDiv': EltwiseType.DIV,
        'SquaredDifference': EltwiseType.SQR_DIFF,
        'Pow': EltwiseType.POW
    }
    # TensorFlow activation op type -> MACE activation type.
    activation_type = {
        'Relu': ActivationType.RELU,
        'Relu6': ActivationType.RELUX,
        'Tanh': ActivationType.TANH,
        'Sigmoid': ActivationType.SIGMOID
    }

    def __init__(self, option, src_model_file):
        """Load the frozen graph and prepare the op dispatch table.

        Args:
            option: converter options; ``input_nodes``, ``output_nodes`` and
                ``data_type`` are read by this class.
            src_model_file: path of the serialized ``GraphDef`` to convert.
        """
        self._op_converters = {
            'Conv2D': self.convert_conv2d,
            'DepthwiseConv2dNative': self.convert_conv2d,
            'Conv2DBackpropInput': self.convert_conv2d,
            'BiasAdd': self.convert_biasadd,
            'Add': self.convert_add,
            'Sub': self.convert_elementwise,
            'Mul': self.convert_elementwise,
            'Div': self.convert_elementwise,
            'Min': self.convert_elementwise,
            'Max': self.convert_elementwise,
            'Neg': self.convert_elementwise,
            'Abs': self.convert_elementwise,
            'RealDiv': self.convert_elementwise,
            'SquaredDifference': self.convert_elementwise,
            'Pow': self.convert_elementwise,
            'Relu': self.convert_activation,
            'Relu6': self.convert_activation,
            'Tanh': self.convert_activation,
            'Sigmoid': self.convert_activation,
            'FusedBatchNorm': self.convert_fused_batchnorm,
            'AvgPool': self.convert_pooling,
            'MaxPool': self.convert_pooling,
            'Squeeze': self.convert_identity,
            'Reshape': self.convert_reshape,
            'Shape': self.convert_nop,
            'Softmax': self.convert_softmax,
            'ResizeBilinear': self.convert_resize_bilinear,
            'Placeholder': self.convert_nop,
            'SpaceToBatchND': self.convert_space_batch,
            'BatchToSpaceND': self.convert_space_batch,
            'DepthToSpace': self.convert_space_depth,
            'SpaceToDepth': self.convert_space_depth,
            'Pad': self.convert_pad,
            'ConcatV2': self.convert_concat,
            'Mean': self.convert_mean,
            # Const converter_tool should be placed at the end
            'Const': self.convert_tensor,
        }
        self._option = option
        self._mace_net_def = mace_pb2.NetDef()
        ConverterUtil.set_filter_format(self._mace_net_def, FilterFormat.HWIO)
        tf_graph_def = tf.GraphDef()
        with tf.gfile.Open(src_model_file, 'rb') as f:
            tf_graph_def.ParseFromString(f.read())
        # Pin the user-provided input shapes before importing the graph so
        # TensorFlow's shape inference sees them.
        self.add_shape_info(tf_graph_def)

        with tf.Session() as session:
            with session.graph.as_default() as graph:
                tf.import_graph_def(tf_graph_def, name='')
                self._tf_graph = graph

        self._skip_tensor = set()

    def run(self):
        """Convert every op of the graph and return the MACE net."""
        # Bind the session to the imported graph explicitly, so tensor
        # .eval() calls do not depend on what the default graph is by now.
        with tf.Session(graph=self._tf_graph) as session:
            self.convert_ops()

        self.replace_input_output_tensor_name()
        return self._mace_net_def

    def replace_input_output_tensor_name(self):
        """Strip the ':0' suffix from net input/output tensor names.

        Only names whose prefix is listed in ``option.input_nodes`` (for op
        inputs) or ``option.output_nodes`` (for op outputs) are renamed.
        """
        for op in self._mace_net_def.op:
            for i, input_name in enumerate(op.input):
                if input_name[-2:] == ':0':
                    op_name = input_name[:-2]
                    if op_name in self._option.input_nodes:
                        op.input[i] = op_name
            for i, output_name in enumerate(op.output):
                if output_name[-2:] == ':0':
                    op_name = output_name[:-2]
                    if op_name in self._option.output_nodes:
                        op.output[i] = op_name

    def add_shape_info(self, tf_graph_def):
        """Overwrite the 'shape' attr of input nodes with the option shapes."""
        for node in tf_graph_def.node:
            if node.name in self._option.input_nodes:
                del node.attr['shape'].shape.dim[:]
                node.attr['shape'].shape.dim.extend([
                    tensor_shape_pb2.TensorShapeProto.Dim(size=i) for i in
                    self._option.input_nodes[node.name].shape
                ])

    @staticmethod
    def get_scope(tensor_name):
        """Return the name up to the last '/', or the name if it has none."""
        idx = tensor_name.rfind('/')
        if idx == -1:
            return tensor_name
        else:
            return tensor_name[:idx]

    def convert_ops(self):
        """Dispatch every graph operation to its registered converter."""
        for tf_op in self._tf_graph.get_operations():
            mace_check(tf_op.type in self._op_converters,
                       "Mace does not support tensorflow op type %s yet"
                       % tf_op.type)
            self._op_converters[tf_op.type](tf_op)

    def convert_tensor(self, tf_op):
        """Emit a Const op as a net tensor unless it is marked as skipped.

        Only float32 and int32 constants are supported.
        """
        output_name = tf_op.outputs[0].name
        if output_name not in self._skip_tensor:
            tensor = self._mace_net_def.tensors.add()
            tensor.name = output_name
            tf_tensor = tf_op.outputs[0].eval()
            tensor.dims.extend(list(tf_tensor.shape))

            tf_dt = tf_op.get_attr('dtype')
            if tf_dt == tf.float32:
                tensor.data_type = mace_pb2.DT_FLOAT
                tensor.float_data.extend(tf_tensor.astype(np.float32).flat)
            elif tf_dt == tf.int32:
                tensor.data_type = mace_pb2.DT_INT32
                tensor.int32_data.extend(tf_tensor.astype(np.int32).flat)
            else:
                mace_check(False, "Not supported tensor type: %s" % tf_dt.name)

    def add_tensor(self, name, shape, data_type, value):
        """Add a constant tensor to the net.

        The values are always written to ``float_data``, whatever
        ``data_type`` says.
        """
        tensor = self._mace_net_def.tensors.add()
        tensor.name = name
        tensor.dims.extend(list(shape))
        tensor.data_type = data_type
        tensor.float_data.extend(value.flat)

    def convert_nop(self, tf_op):
        """Ignore the op; nothing is added to the net."""
        pass

    def convert_general_op(self, tf_op):
        """Create a MACE op mirroring ``tf_op`` and return it.

        Copies name, type, inputs, outputs and the TensorFlow-inferred
        output shapes, then adds the 'T' data type argument and the NHWC
        data format argument. Callers refine ``op.type`` and add arguments.
        """
        op = self._mace_net_def.op.add()
        op.name = tf_op.name
        op.type = tf_op.type
        op.input.extend([tf_input.name for tf_input in tf_op.inputs])
        op.output.extend([tf_output.name for tf_output in tf_op.outputs])
        for tf_output in tf_op.outputs:
            output_shape = op.output_shape.add()
            output_shape.dims.extend(tf_output.shape.as_list())
            op.output_type.append(self._option.data_type)

        data_type_arg = op.arg.add()
        data_type_arg.name = 'T'
        data_type_arg.i = self._option.data_type

        ConverterUtil.add_data_format_arg(op, DataFormat.NHWC)

        return op

    def convert_identity(self, tf_op):
        """Convert the op to an 'Identity' op."""
        op = self.convert_general_op(tf_op)
        op.type = 'Identity'

    def convert_conv2d(self, tf_op):
        """Convert Conv2D / DepthwiseConv2dNative / Conv2DBackpropInput."""
        op = self.convert_general_op(tf_op)
        if tf_op.type == 'DepthwiseConv2dNative':
            op.type = MaceOp.DepthwiseConv2d.name
        elif tf_op.type == 'Conv2DBackpropInput':
            op.type = MaceOp.Deconv2D.name
        else:
            op.type = MaceOp.Conv2D.name

        padding_arg = op.arg.add()
        padding_arg.name = MaceKeyword.mace_padding_str
        padding_arg.i = self.padding_mode[tf_op.get_attr(tf_padding_str)].value
        strides_arg = op.arg.add()
        strides_arg.name = MaceKeyword.mace_strides_str
        # Keep only the two middle entries of the 4-element attribute.
        strides_arg.ints.extend(tf_op.get_attr(tf_strides_str)[1:3])
        if op.type != MaceOp.Deconv2D.name:
            dilation_arg = op.arg.add()
            dilation_arg.name = MaceKeyword.mace_dilations_str
            dilation_arg.ints.extend(tf_op.get_attr(tf_dilations_str)[1:3])

    def convert_elementwise(self, tf_op):
        """Convert an element-wise op into a MACE Eltwise op."""
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Eltwise.name

        type_arg = op.arg.add()
        type_arg.name = MaceKeyword.mace_element_type_str
        type_arg.i = self.eltwise_type[tf_op.type].value

    def convert_biasadd(self, tf_op):
        """Convert BiasAdd into a MACE BiasAdd op."""
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.BiasAdd.name

    def convert_add(self, tf_op):
        """Convert Add: Eltwise for two inputs, AddN otherwise."""
        if len(tf_op.inputs) == 2:
            self.convert_elementwise(tf_op)
        else:
            op = self.convert_general_op(tf_op)
            op.type = MaceOp.AddN.name

    def convert_activation(self, tf_op):
        """Convert Relu / Relu6 / Tanh / Sigmoid into a MACE Activation op."""
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Activation.name

        type_arg = op.arg.add()
        type_arg.name = MaceKeyword.mace_activation_type_str
        type_arg.s = self.activation_type[tf_op.type].name

        if tf_op.type == 'Relu6':
            # Relu6 maps to RELUX with an explicit upper limit of 6.
            limit_arg = op.arg.add()
            limit_arg.name = MaceKeyword.mace_activation_max_limit_str
            limit_arg.f = 6.0

    def convert_fused_batchnorm(self, tf_op):
        """Fold FusedBatchNorm into a FoldedBatchNorm op (scale + offset).

        scale = gamma / sqrt(var + epsilon), offset = beta - mean * scale.
        The four constant inputs are replaced by the two folded tensors and
        only the first output is kept.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.FoldedBatchNorm.name

        gamma_value = tf_op.inputs[1].eval().astype(np.float32)
        beta_value = tf_op.inputs[2].eval().astype(np.float32)
        mean_value = tf_op.inputs[3].eval().astype(np.float32)
        var_value = tf_op.inputs[4].eval().astype(np.float32)
        epsilon_value = tf_op.get_attr(tf_epsilon_str)

        scale_name = self.get_scope(tf_op.name) + '/scale:0'
        offset_name = self.get_scope(tf_op.name) + '/offset:0'
        scale_value = (
            (1.0 / np.vectorize(math.sqrt)(
                var_value + epsilon_value)) * gamma_value)
        offset_value = (-mean_value * scale_value) + beta_value
        self.add_tensor(scale_name, scale_value.shape, mace_pb2.DT_FLOAT,
                        scale_value)
        self.add_tensor(offset_name, offset_value.shape, mace_pb2.DT_FLOAT,
                        offset_value)
        self._skip_tensor.update([inp.name for inp in tf_op.inputs][1:])

        del op.input[1:]
        op.input.extend([scale_name, offset_name])
        del op.output[1:]
        del op.output_shape[1:]
        del op.output_type[1:]

    def convert_pooling(self, tf_op):
        """Convert AvgPool / MaxPool into a MACE Pooling op."""
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Pooling.name
        pooling_type_arg = op.arg.add()
        pooling_type_arg.name = MaceKeyword.mace_pooling_type_str
        pooling_type_arg.i = self.pooling_type_mode[tf_op.type].value
        padding_arg = op.arg.add()
        padding_arg.name = MaceKeyword.mace_padding_str
        padding_arg.i = self.padding_mode[tf_op.get_attr(tf_padding_str)].value
        strides_arg = op.arg.add()
        strides_arg.name = MaceKeyword.mace_strides_str
        strides_arg.ints.extend(tf_op.get_attr(tf_strides_str)[1:3])
        kernels_arg = op.arg.add()
        kernels_arg.name = MaceKeyword.mace_kernel_str
        kernels_arg.ints.extend(tf_op.get_attr(tf_kernel_str)[1:3])

    def convert_softmax(self, tf_op):
        """Convert Softmax into a MACE Softmax op."""
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Softmax.name

    def convert_resize_bilinear(self, tf_op):
        """Convert ResizeBilinear, folding the size input into an argument."""
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.ResizeBilinear.name
        del op.input[1:]

        size_arg = op.arg.add()
        size_arg.name = MaceKeyword.mace_resize_size_str
        size_value = tf_op.inputs[1].eval().astype(np.int32)
        size_arg.ints.extend(size_value)
        # add(), not update(): update() would insert the name's characters.
        self._skip_tensor.add(tf_op.inputs[1].name)
        align_corners_arg = op.arg.add()
        align_corners_arg.name = MaceKeyword.mace_align_corners_str
        align_corners_arg.i = tf_op.get_attr(tf_align_corners)

    def convert_space_batch(self, tf_op):
        """Convert SpaceToBatchND / BatchToSpaceND.

        The block-shape and crops/paddings inputs are folded into arguments.
        """
        # Parenthesised single-string print works on Python 2 and 3 alike.
        print("""You might want to try 'flatten_atrous_conv' in
transform graph to turn atrous conv2d into regular conv2d.
This may give you performance benefit on GPU.
(see https://github.com/tensorflow/tensorflow/blob/master/
tensorflow/tools/graph_transforms/README.md#flatten_atrous_conv)
""")

        op = self.convert_general_op(tf_op)
        del op.input[1:]

        size_arg = op.arg.add()
        size_arg.name = MaceKeyword.mace_space_batch_block_shape_str
        size_value = tf_op.inputs[1].eval().astype(np.int32)
        size_arg.ints.extend(size_value)

        crops_or_paddings_arg = op.arg.add()
        if op.type == 'BatchToSpaceND':
            op.type = MaceOp.BatchToSpaceND.name
            crops_or_paddings_arg.name = \
                MaceKeyword.mace_batch_to_space_crops_str
        else:
            op.type = MaceOp.SpaceToBatchND.name
            crops_or_paddings_arg.name = MaceKeyword.mace_paddings_str
        crops_or_paddings_value = tf_op.inputs[2].eval().astype(np.int32).flat
        crops_or_paddings_arg.ints.extend(crops_or_paddings_value)

        self._skip_tensor.add(tf_op.inputs[1].name)
        self._skip_tensor.add(tf_op.inputs[2].name)

    def convert_space_depth(self, tf_op):
        """Convert SpaceToDepth / DepthToSpace with its block size."""
        op = self.convert_general_op(tf_op)
        if op.type == 'SpaceToDepth':
            op.type = MaceOp.SpaceToDepth.name
        else:
            op.type = MaceOp.DepthToSpace.name

        size_arg = op.arg.add()
        size_arg.name = MaceKeyword.mace_space_depth_block_size_str
        size_arg.i = tf_op.get_attr(tf_block_size)

    def convert_pad(self, tf_op):
        """Convert Pad, folding paddings (and constant value) into args."""
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Pad.name
        del op.input[1:]

        paddings_arg = op.arg.add()
        paddings_arg.name = MaceKeyword.mace_paddings_str
        paddings_value = tf_op.inputs[1].eval().astype(np.int32).flat
        paddings_arg.ints.extend(paddings_value)
        self._skip_tensor.add(tf_op.inputs[1].name)

        if len(tf_op.inputs) == 3:
            constant_value_arg = op.arg.add()
            constant_value_arg.name = MaceKeyword.mace_constant_value_str
            # The constant is stored as an int32 (fractional part dropped).
            constant_value = tf_op.inputs[2].eval().astype(np.int32).flat[0]
            constant_value_arg.i = constant_value
            self._skip_tensor.add(tf_op.inputs[2].name)

    def convert_concat(self, tf_op):
        """Convert ConcatV2; only concatenation along axis 3 is accepted."""
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Concat.name
        # The last input of ConcatV2 is the axis, not data.
        del op.input[-1]

        axis_arg = op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        axis = tf_op.inputs[-1].eval().astype(np.int32)
        axis_arg.i = axis

        mace_check(axis == 3, "only support concat at channel dimension")

        self._skip_tensor.add(tf_op.inputs[-1].name)

    def convert_reshape(self, tf_op):
        """Convert Reshape, folding the target shape into an argument.

        A constant shape is used with every -1 replaced by 1; when the shape
        comes from a Shape op, the static shape of that op's input is used.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Reshape.name
        del op.input[1:]

        shape_arg = op.arg.add()
        shape_arg.name = MaceKeyword.mace_shape_str
        shape_value = []
        if tf_op.inputs[1].op.type == 'Const':
            shape_value = list(tf_op.inputs[1].eval().astype(np.int32))
            shape_value = [1 if dim == -1 else dim for dim in shape_value]
            self._skip_tensor.add(tf_op.inputs[-1].name)
        elif tf_op.inputs[1].op.type == 'Shape':
            shape_value = list(tf_op.inputs[1].op.inputs[0].shape.as_list())

        shape_arg.ints.extend(shape_value)

    def convert_mean(self, tf_op):
        """Convert Mean over dims [1, 2] into a full-window average pooling."""
        op = self.convert_general_op(tf_op)
        del op.input[1:]

        # Flatten first so a scalar or short axis list fails the check
        # below instead of raising IndexError.
        reduce_dims = np.array(tf_op.inputs[1].eval()).flatten().tolist()
        mace_check(reduce_dims == [1, 2],
                   "Mean only support reduce dim 1, 2")

        op.type = MaceOp.Pooling.name
        pooling_type_arg = op.arg.add()
        pooling_type_arg.name = MaceKeyword.mace_pooling_type_str
        pooling_type_arg.i = PoolingType.AVG.value
        padding_arg = op.arg.add()
        padding_arg.name = MaceKeyword.mace_padding_str
        padding_arg.i = PaddingMode.VALID.value
        strides_arg = op.arg.add()
        strides_arg.name = MaceKeyword.mace_strides_str
        strides_arg.ints.extend([1, 1])
        kernels_arg = op.arg.add()
        kernels_arg.name = MaceKeyword.mace_kernel_str
        # The kernel covers dims 1 and 2 of the input entirely.
        kernels_arg.ints.extend(tf_op.inputs[0].shape.as_list()[1:3])

        self._skip_tensor.add(tf_op.inputs[1].name)
此差异已折叠。
......@@ -129,7 +129,7 @@ class MemoryOptimizer(object):
self.idle_mem.remove(mem_id)
if mem_id == -1:
mem_id = self.total_mem_count
mem_id = self.mem_id_base() + self.total_mem_count
self.total_mem_count += 1
self.mem_block[mem_id] = op_mem_block
......@@ -147,10 +147,13 @@ class MemoryOptimizer(object):
self.add_net_mem_blocks()
print('total op: %d', len(self.net_def.op))
print('origin mem: %d, optimized mem: %d',
print("total op: %d" % len(self.net_def.op))
print("origin mem: %d, optimized mem: %d" % (
self.get_total_origin_mem_size(),
self.get_total_optimized_mem_size())
self.get_total_optimized_mem_size()))
# Base value added to the running block count when a new memory id is
# allocated; this implementation applies no offset.
def mem_id_base(self):
return 0
class GPUMemoryOptimizer(MemoryOptimizer):
......@@ -189,6 +192,9 @@ class GPUMemoryOptimizer(MemoryOptimizer):
block.x = self.mem_block[mem][0]
block.y = self.mem_block[mem][1]
# Memory id base for this optimizer: ids start at 20000.
# NOTE(review): presumably chosen to keep GPU ids apart from the ids of the
# base optimizer (which start at 0) - confirm.
def mem_id_base(self):
return 20000
def optimize_gpu_memory(net_def):
mem_optimizer = GPUMemoryOptimizer(net_def)
......
......@@ -84,11 +84,20 @@ def obfuscate_name(net_def):
op.output[i] = in_out_map[op.output[i]]
def normalize_op_name(op_name):
    """Drop everything from the last ':' onwards (e.g. 'conv:0' -> 'conv').

    Names that contain no ':' are returned unchanged.
    """
    head, colon, _ = op_name.rpartition(':')
    return head if colon else op_name
def rename_tensor(net_def):
tensor_map = {}
for t in net_def.tensors:
if t.name not in tensor_map:
tensor_map[t.name] = "_" + t.name[:-2].replace("/", "_")
tensor_map[t.name] = "_" + normalize_op_name(t.name).replace("/",
"_")
t.name = tensor_map[t.name]
for op in net_def.op:
for i in range(len(op.input)):
......@@ -118,6 +127,8 @@ class TensorInfo:
elif t.data_type == mace_pb2.DT_UINT8:
self.data = bytearray(
np.array(t.int32_data).astype(np.uint8).tolist())
else:
raise Exception('Tensor data type %s not supported' % t.data_type)
def stringfy(value):
......
此差异已折叠。
......@@ -152,7 +152,7 @@ void CheckOutputs(const NetDef &net_def,
memcpy(input_data.data(), input.second.data().get(),
data_size * sizeof(float));
std::string input_name = MakeString("mace_input_node_",
input.first, ":0");
input.first);
net.AddInputFromArray<D, float>(input_name, input.second.shape(),
input_data);
}
......@@ -181,7 +181,7 @@ void CheckOutputs(const NetDef &net_def,
float *data = tmp_tensor->mutable_data<float>();
memcpy(data, output.second.data().get(), data_size * sizeof(float));
std::string output_name = MakeString("mace_output_node_",
output.first, ":0");
output.first);
ops::test::ExpectTensorNear<float>(*tmp_tensor,
*net.GetOutput(output_name.data()),
1e-5);
......@@ -265,7 +265,7 @@ void MaceRunFunc(const int in_out_size) {
for (size_t i = 0; i < input_names.size(); ++i) {
std::string input_name = MakeString("mace_input_node_",
input_names[i], ":0");
input_names[i]);
BufferToImage<half>(input_name, input_names[i],
mace::kernels::IN_OUT_CHANNEL,
{mem_map[input_names[i]]},
......@@ -281,7 +281,7 @@ void MaceRunFunc(const int in_out_size) {
}
for (size_t i = 0; i < output_names.size(); ++i) {
std::string output_name = MakeString("mace_output_node_",
output_names[i], ":0");
output_names[i]);
ImageToBuffer<float>(output_names[i], output_name,
mace::kernels::IN_OUT_CHANNEL, &net_def);
}
......
......@@ -162,7 +162,7 @@ void CheckOutputs(const NetDef &net_def,
memcpy(input_data.data(), input.second.data().get(),
data_size * sizeof(float));
std::string input_name = MakeString("mace_input_node_",
input.first, ":0");
input.first);
net.AddInputFromArray<D, float>(input_name, input.second.shape(),
input_data);
}
......@@ -191,7 +191,7 @@ void CheckOutputs(const NetDef &net_def,
float *data = tmp_tensor->mutable_data<float>();
memcpy(data, output.second.data().get(), data_size * sizeof(float));
std::string output_name = MakeString("mace_output_node_",
output.first, ":0");
output.first);
ops::test::ExpectTensorNear<float>(*tmp_tensor,
*net.GetOutput(output_name.data()),
1e-5);
......@@ -275,7 +275,7 @@ void MaceRun(const int in_out_size,
for (size_t i = 0; i < input_names.size(); ++i) {
std::string input_name = MakeString("mace_input_node_",
input_names[i], ":0");
input_names[i]);
BufferToImage<half>(input_name, input_names[i],
mace::kernels::IN_OUT_CHANNEL,
{mem_map[input_names[i]]},
......@@ -291,7 +291,7 @@ void MaceRun(const int in_out_size,
}
for (size_t i = 0; i < output_names.size(); ++i) {
std::string output_name = MakeString("mace_output_node_",
output_names[i], ":0");
output_names[i]);
ImageToBuffer<float>(output_names[i], output_name,
mace::kernels::IN_OUT_CHANNEL, &net_def);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册