提交 1140132d 编写于 作者: L liuqi

Support tensorflow fc layer.

上级 fc1c855e
......@@ -210,7 +210,7 @@ class Transformer(base_converter.ConverterInterface):
return False
def safe_remove_node(self, op, replace_op):
def safe_remove_node(self, op, replace_op, remove_input_tensor=False):
"""remove op.
1. change the inputs of its consumers to the outputs of replace_op
2. if the op is output node, change output node to replace op"""
......@@ -250,6 +250,12 @@ class Transformer(base_converter.ConverterInterface):
op.output[i])
replace_op.output[i] = op.output[i]
if remove_input_tensor:
for input_name in op.input:
if input_name in self._consts:
const_tensor = self._consts[input_name]
self._model.tensors.remove(const_tensor)
self._model.op.remove(op)
def add_in_out_tensor_info(self):
......@@ -903,6 +909,14 @@ class Transformer(base_converter.ConverterInterface):
filter_data = filter_data.transpose(3, 2, 0, 1)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
if op.type == MaceOp.FullyConnected.name:
weight = self._consts[op.input[1]]
if len(weight.dims) == 4:
weight_data = np.array(weight.float_data).reshape(
weight.dims)
weight_data = weight_data.transpose(3, 2, 0, 1)
weight.float_data[:] = weight_data.flat
weight.dims[:] = weight_data.shape
self.set_filter_format(FilterFormat.OIHW)
......@@ -914,12 +928,13 @@ class Transformer(base_converter.ConverterInterface):
for op in net.op:
if op.type == MaceOp.FullyConnected.name:
weight = self._consts[op.input[1]]
input_op = self._producer[op.input[0]]
input_shape = list(input_op.output_shape[0].dims)
input_data_format = ConverterUtil.data_format(input_op)
weight.dims[:] = [weight.dims[0]] + input_shape[1:]
if input_data_format == DataFormat.NHWC:
self.transpose_shape(weight.dims, [0, 3, 1, 2])
if len(weight.dims) == 2:
input_op = self._producer[op.input[0]]
input_shape = list(input_op.output_shape[0].dims)
input_data_format = ConverterUtil.data_format(input_op)
weight.dims[:] = [weight.dims[0]] + input_shape[1:]
if input_data_format == DataFormat.NHWC:
self.transpose_shape(weight.dims, [0, 3, 1, 2])
return False
......@@ -1073,34 +1088,57 @@ class Transformer(base_converter.ConverterInterface):
and self._producer[consumer.input[1]].type
== 'Shape'):
self.safe_remove_node(
self._producer[consumer.input[1]], None)
self._producer[consumer.input[1]], None,
remove_input_tensor=True)
# remove consumer reshape
self.safe_remove_node(consumer, op)
self.safe_remove_node(consumer, op,
remove_input_tensor=True)
# remove producer reshape
self.safe_remove_node(producer,
self._producer.get(producer.input[0],
None))
None),
remove_input_tensor=True)
return True
return False
def transform_matmul_to_fc(self):
net = self._model
filter_format = self.filter_format()
for op in net.op:
if op.type == MaceOp.MatMul.name:
input_shape = self.get_tensor_shape(op.input[0])
if len(input_shape) == 4:
_, h, w, _ = self.sort_feature_map_shape(input_shape,
ConverterUtil.data_format(self._producer[op.input[0]])) # noqa
if h == 1 and w == 1 and op.input[1] in self._consts:
weight = self._consts[op.input[1]]
if len(weight.dims) == 2:
op.type = MaceOp.FullyConnected.name
weight_data = np.array(weight.float_data).reshape(
weight.dims)
weight_data = weight_data.transpose(1, 0)
weight.float_data[:] = weight_data.flat
weight.dims[:] = weight_data.shape
# transform reshape + matmul -> fc
# work for TensorFlow
if op.type == MaceOp.MatMul.name and \
filter_format == FilterFormat.HWIO:
producer = self._producer[op.input[0]]
weight = self._consts[op.input[1]]
if len(weight.dims) == 2 \
and producer.type == MaceOp.Reshape.name \
and len(producer.output) == 1 \
and producer.input[1] in self._consts \
and len(producer.output_shape[0].dims) == 2:
input_op = self._producer[producer.input[0]]
input_shape = input_op.output_shape[0].dims
feature_size = np.prod(input_shape[1:])
self.safe_remove_node(producer, input_op,
remove_input_tensor=True)
if feature_size == producer.output_shape[0].dims[1]:
print 'convert reshape and matmul to fc'
op.type = MaceOp.FullyConnected.name
weight_data = np.array(weight.float_data).reshape(
weight.dims)
weight.dims[:] = input_shape[1:] + \
[weight_data.shape[1]]
return True
elif len(weight.dims) == 2 and \
len(producer.output_shape[0].dims) == 2 and \
weight.dims[0] == producer.output_shape[0].dims[1]:
print 'convert matmul to fc'
op.type = MaceOp.FullyConnected.name
weight_data = np.array(weight.float_data).reshape(
weight.dims)
weight.dims[:] = [1, 1] + list(weight_data.shape)
return True
return False
......@@ -1139,9 +1177,6 @@ class Transformer(base_converter.ConverterInterface):
print("transform global conv to fc %s(%s)"
% (op.name, op.type))
op.type = MaceOp.FullyConnected.name
filter.dims[:] = [out_channels,
in_channels * filter_width
* filter_height][:]
def add_device(self):
# TODO(liuqi) add device definition in OperatorDef
......
......@@ -9,8 +9,12 @@ def _git_version_conf_impl(repository_ctx):
generated_files_path = repository_ctx.path("gen")
unused_var = repository_ctx.path(Label("//:.git/HEAD"))
unused_var = repository_ctx.path(Label("//:.git/refs/heads/master"))
ret = repository_ctx.execute(
["test", "-f", "%s/.git/logs/HEAD" % mace_root_path])
if ret.return_code == 0:
unused_var = repository_ctx.path(Label("//:.git/HEAD"))
unused_var = repository_ctx.path(Label("//:.git/refs/heads/master"))
repository_ctx.execute([
'bash', '%s/mace/tools/git/gen_version_source.sh' % mace_root_path
......
......@@ -5,8 +5,15 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
"BUILD",
Label("//repository/opencl-kernel:BUILD.tpl"))
unused_var = repository_ctx.path(Label("//:.git/HEAD"))
unused_var = repository_ctx.path(Label("//:.git/refs/heads/master"))
mace_root_path = str(repository_ctx.path(Label("@mace//:BUILD")))[:-len("BUILD")]
ret = repository_ctx.execute(
["test", "-f", "%s/.git/logs/HEAD" % mace_root_path])
if ret.return_code == 0:
unused_var = repository_ctx.path(Label("//:.git/HEAD"))
unused_var = repository_ctx.path(Label("//:.git/refs/heads/master"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/activation.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/addn.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/batch_norm.cl"))
......@@ -33,7 +40,6 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/space_to_batch.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/winograd_transform.cl"))
mace_root_path = str(repository_ctx.path(Label("@mace//:BUILD")))[:-len("BUILD")]
generated_files_path = repository_ctx.path("gen")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册