提交 1140132d 编写于 作者: L liuqi

Support tensorflow fc layer.

上级 fc1c855e
...@@ -210,7 +210,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -210,7 +210,7 @@ class Transformer(base_converter.ConverterInterface):
return False return False
def safe_remove_node(self, op, replace_op): def safe_remove_node(self, op, replace_op, remove_input_tensor=False):
"""remove op. """remove op.
1. change the inputs of its consumers to the outputs of replace_op 1. change the inputs of its consumers to the outputs of replace_op
2. if the op is output node, change output node to replace op""" 2. if the op is output node, change output node to replace op"""
...@@ -250,6 +250,12 @@ class Transformer(base_converter.ConverterInterface): ...@@ -250,6 +250,12 @@ class Transformer(base_converter.ConverterInterface):
op.output[i]) op.output[i])
replace_op.output[i] = op.output[i] replace_op.output[i] = op.output[i]
if remove_input_tensor:
for input_name in op.input:
if input_name in self._consts:
const_tensor = self._consts[input_name]
self._model.tensors.remove(const_tensor)
self._model.op.remove(op) self._model.op.remove(op)
def add_in_out_tensor_info(self): def add_in_out_tensor_info(self):
...@@ -903,6 +909,14 @@ class Transformer(base_converter.ConverterInterface): ...@@ -903,6 +909,14 @@ class Transformer(base_converter.ConverterInterface):
filter_data = filter_data.transpose(3, 2, 0, 1) filter_data = filter_data.transpose(3, 2, 0, 1)
filter.float_data[:] = filter_data.flat filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape filter.dims[:] = filter_data.shape
if op.type == MaceOp.FullyConnected.name:
weight = self._consts[op.input[1]]
if len(weight.dims) == 4:
weight_data = np.array(weight.float_data).reshape(
weight.dims)
weight_data = weight_data.transpose(3, 2, 0, 1)
weight.float_data[:] = weight_data.flat
weight.dims[:] = weight_data.shape
self.set_filter_format(FilterFormat.OIHW) self.set_filter_format(FilterFormat.OIHW)
...@@ -914,12 +928,13 @@ class Transformer(base_converter.ConverterInterface): ...@@ -914,12 +928,13 @@ class Transformer(base_converter.ConverterInterface):
for op in net.op: for op in net.op:
if op.type == MaceOp.FullyConnected.name: if op.type == MaceOp.FullyConnected.name:
weight = self._consts[op.input[1]] weight = self._consts[op.input[1]]
input_op = self._producer[op.input[0]] if len(weight.dims) == 2:
input_shape = list(input_op.output_shape[0].dims) input_op = self._producer[op.input[0]]
input_data_format = ConverterUtil.data_format(input_op) input_shape = list(input_op.output_shape[0].dims)
weight.dims[:] = [weight.dims[0]] + input_shape[1:] input_data_format = ConverterUtil.data_format(input_op)
if input_data_format == DataFormat.NHWC: weight.dims[:] = [weight.dims[0]] + input_shape[1:]
self.transpose_shape(weight.dims, [0, 3, 1, 2]) if input_data_format == DataFormat.NHWC:
self.transpose_shape(weight.dims, [0, 3, 1, 2])
return False return False
...@@ -1073,34 +1088,57 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1073,34 +1088,57 @@ class Transformer(base_converter.ConverterInterface):
and self._producer[consumer.input[1]].type and self._producer[consumer.input[1]].type
== 'Shape'): == 'Shape'):
self.safe_remove_node( self.safe_remove_node(
self._producer[consumer.input[1]], None) self._producer[consumer.input[1]], None,
remove_input_tensor=True)
# remove consumer reshape # remove consumer reshape
self.safe_remove_node(consumer, op) self.safe_remove_node(consumer, op,
remove_input_tensor=True)
# remove producer reshape # remove producer reshape
self.safe_remove_node(producer, self.safe_remove_node(producer,
self._producer.get(producer.input[0], self._producer.get(producer.input[0],
None)) None),
remove_input_tensor=True)
return True return True
return False return False
def transform_matmul_to_fc(self): def transform_matmul_to_fc(self):
net = self._model net = self._model
filter_format = self.filter_format()
for op in net.op: for op in net.op:
if op.type == MaceOp.MatMul.name: # transform reshape + matmul -> fc
input_shape = self.get_tensor_shape(op.input[0]) # work for TensorFlow
if len(input_shape) == 4: if op.type == MaceOp.MatMul.name and \
_, h, w, _ = self.sort_feature_map_shape(input_shape, filter_format == FilterFormat.HWIO:
ConverterUtil.data_format(self._producer[op.input[0]])) # noqa producer = self._producer[op.input[0]]
if h == 1 and w == 1 and op.input[1] in self._consts: weight = self._consts[op.input[1]]
weight = self._consts[op.input[1]] if len(weight.dims) == 2 \
if len(weight.dims) == 2: and producer.type == MaceOp.Reshape.name \
op.type = MaceOp.FullyConnected.name and len(producer.output) == 1 \
weight_data = np.array(weight.float_data).reshape( and producer.input[1] in self._consts \
weight.dims) and len(producer.output_shape[0].dims) == 2:
weight_data = weight_data.transpose(1, 0) input_op = self._producer[producer.input[0]]
weight.float_data[:] = weight_data.flat input_shape = input_op.output_shape[0].dims
weight.dims[:] = weight_data.shape feature_size = np.prod(input_shape[1:])
self.safe_remove_node(producer, input_op,
remove_input_tensor=True)
if feature_size == producer.output_shape[0].dims[1]:
print 'convert reshape and matmul to fc'
op.type = MaceOp.FullyConnected.name
weight_data = np.array(weight.float_data).reshape(
weight.dims)
weight.dims[:] = input_shape[1:] + \
[weight_data.shape[1]]
return True
elif len(weight.dims) == 2 and \
len(producer.output_shape[0].dims) == 2 and \
weight.dims[0] == producer.output_shape[0].dims[1]:
print 'convert matmul to fc'
op.type = MaceOp.FullyConnected.name
weight_data = np.array(weight.float_data).reshape(
weight.dims)
weight.dims[:] = [1, 1] + list(weight_data.shape)
return True
return False return False
...@@ -1139,9 +1177,6 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1139,9 +1177,6 @@ class Transformer(base_converter.ConverterInterface):
print("transform global conv to fc %s(%s)" print("transform global conv to fc %s(%s)"
% (op.name, op.type)) % (op.name, op.type))
op.type = MaceOp.FullyConnected.name op.type = MaceOp.FullyConnected.name
filter.dims[:] = [out_channels,
in_channels * filter_width
* filter_height][:]
def add_device(self): def add_device(self):
# TODO(liuqi) add device definition in OperatorDef # TODO(liuqi) add device definition in OperatorDef
......
...@@ -9,8 +9,12 @@ def _git_version_conf_impl(repository_ctx): ...@@ -9,8 +9,12 @@ def _git_version_conf_impl(repository_ctx):
generated_files_path = repository_ctx.path("gen") generated_files_path = repository_ctx.path("gen")
unused_var = repository_ctx.path(Label("//:.git/HEAD")) ret = repository_ctx.execute(
unused_var = repository_ctx.path(Label("//:.git/refs/heads/master")) ["test", "-f", "%s/.git/logs/HEAD" % mace_root_path])
if ret.return_code == 0:
unused_var = repository_ctx.path(Label("//:.git/HEAD"))
unused_var = repository_ctx.path(Label("//:.git/refs/heads/master"))
repository_ctx.execute([ repository_ctx.execute([
'bash', '%s/mace/tools/git/gen_version_source.sh' % mace_root_path 'bash', '%s/mace/tools/git/gen_version_source.sh' % mace_root_path
......
...@@ -5,8 +5,15 @@ def _opencl_encrypt_kernel_impl(repository_ctx): ...@@ -5,8 +5,15 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
"BUILD", "BUILD",
Label("//repository/opencl-kernel:BUILD.tpl")) Label("//repository/opencl-kernel:BUILD.tpl"))
unused_var = repository_ctx.path(Label("//:.git/HEAD")) mace_root_path = str(repository_ctx.path(Label("@mace//:BUILD")))[:-len("BUILD")]
unused_var = repository_ctx.path(Label("//:.git/refs/heads/master"))
ret = repository_ctx.execute(
["test", "-f", "%s/.git/logs/HEAD" % mace_root_path])
if ret.return_code == 0:
unused_var = repository_ctx.path(Label("//:.git/HEAD"))
unused_var = repository_ctx.path(Label("//:.git/refs/heads/master"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/activation.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/activation.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/addn.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/addn.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/batch_norm.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/batch_norm.cl"))
...@@ -33,7 +40,6 @@ def _opencl_encrypt_kernel_impl(repository_ctx): ...@@ -33,7 +40,6 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/space_to_batch.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/space_to_batch.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/winograd_transform.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/winograd_transform.cl"))
mace_root_path = str(repository_ctx.path(Label("@mace//:BUILD")))[:-len("BUILD")]
generated_files_path = repository_ctx.path("gen") generated_files_path = repository_ctx.path("gen")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册