Commit 997b06f7 authored by yejianwu

Fix folding of batch norm into conv, deconv and depthwise conv

Parent bfbe1a30
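This commit makes the BN-folding passes handle convolution ops that already carry a constant bias input: the BatchNorm scale is folded into the filter, and the offset is folded into the existing bias instead of being appended as a second bias input. A minimal numpy sketch of the arithmetic these passes implement (illustrative names, not MACE APIs):

```python
import numpy as np

def fold_bn_into_conv(filter_o, scale, offset, bias=None):
    """Fold y = scale * conv(x, W, b) + offset into a single conv.

    filter_o: weights with the output channel on the leading axis
    scale, offset: per-output-channel BatchNorm scale and offset
    bias: existing conv bias, or None when the op has no bias input
    """
    folded_filter = filter_o * scale.reshape(-1, 1, 1, 1)
    if bias is None:
        folded_bias = offset                   # BN offset simply becomes the bias
    else:
        folded_bias = bias * scale + offset    # b' = b * scale + offset (the case this fix adds)
    return folded_filter, folded_bias
```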
@@ -91,7 +91,7 @@ MaceStatus OpenCLAllocator::NewImage(const std::vector<size_t> &image_shape,
if (error != CL_SUCCESS) {
LOG(WARNING) << "Allocate OpenCL image with shape: ["
<< image_shape[0] << ", " << image_shape[1]
<< "] failed because of"
<< "] failed because of "
<< OpenCLErrorToString(error);
delete cl_image;
*result = nullptr;
@@ -415,12 +415,12 @@ class ConverterOption(object):
TransformerRule.FOLD_RESHAPE,
TransformerRule.TRANSFORM_MATMUL_TO_FC,
TransformerRule.FOLD_BATCHNORM,
TransformerRule.TRANSFORM_ADD_TO_BIASADD,
TransformerRule.FOLD_BIASADD,
TransformerRule.FOLD_CONV_AND_BN,
TransformerRule.FOLD_DECONV_AND_BN,
TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
TransformerRule.TRANSFORM_ADD_TO_BIASADD,
TransformerRule.REARRANGE_BATCH_TO_SPACE,
TransformerRule.FOLD_BIASADD,
TransformerRule.FLATTEN_ATROUS_CONV,
TransformerRule.FOLD_ACTIVATION,
TransformerRule.FOLD_SQRDIFF_MEAN,
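The FOLD_CONV_AND_BN, FOLD_DECONV_AND_BN and FOLD_DEPTHWISE_CONV_AND_BN rules now run before TRANSFORM_ADD_TO_BIASADD, REARRANGE_BATCH_TO_SPACE and FOLD_BIASADD. The rules listed here are applied in order by the transformer, so the reorder changes which graph shape each folding pass sees; a rough sketch of such a driver loop (assumed structure, not the actual MACE entry point):

```python
def run_transformers(net, ordered_rules, rule_to_pass):
    # Apply each registered pass in the order given by ConverterOption;
    # a pass that returns True has changed the graph and is run again.
    for rule in ordered_rules:
        changed = True
        while changed:
            changed = rule_to_pass[rule](net)
    return net
```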
@@ -57,6 +57,7 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule.TRANSFORM_MATMUL_TO_FC:
self.transform_matmul_to_fc,
TransformerRule.FOLD_BATCHNORM: self.fold_batchnorm,
TransformerRule.FOLD_BIASADD: self.fold_biasadd,
TransformerRule.FOLD_CONV_AND_BN:
self.fold_conv_and_bn, # data_format related
TransformerRule.FOLD_DECONV_AND_BN:
@@ -67,7 +68,6 @@ class Transformer(base_converter.ConverterInterface):
self.transform_add_to_biasadd,
TransformerRule.REARRANGE_BATCH_TO_SPACE:
self.rearrange_batch_to_space,
TransformerRule.FOLD_BIASADD: self.fold_biasadd,
TransformerRule.FLATTEN_ATROUS_CONV: self.flatten_atrous_conv,
TransformerRule.FOLD_ACTIVATION: self.fold_activation,
TransformerRule.FOLD_SQRDIFF_MEAN: self.fold_squared_diff_mean,
@@ -546,10 +546,14 @@ class Transformer(base_converter.ConverterInterface):
if (op.type == MaceOp.Conv2D.name) \
and self.consumer_count(op.output[0]) == 1:
consumer_op = self._consumers[op.output[0]][0]
if consumer_op.type == MaceOp.BatchNorm.name:
input_len = len(op.input)
if consumer_op.type == MaceOp.BatchNorm.name and \
(input_len == 2 or
(input_len == 3 and op.input[-1] in self._consts)):
print("Fold conv and bn: %s(%s)" % (op.name, op.type))
filter = self._consts[op.input[1]]
scale = self._consts[consumer_op.input[1]]
offset = self._consts[consumer_op.input[2]]
idx = 0
filter_format = self.filter_format()
if filter_format == FilterFormat.HWIO:
@@ -570,12 +574,20 @@ class Transformer(base_converter.ConverterInterface):
mace_check(False, "filter format %s not supported" %
filter_format)
# change BN to BiasAdd
consumer_op.type = MaceOp.BiasAdd.name
del consumer_op.input[1]
if len(op.input) == 3:
conv_bias = self._consts[op.input[2]]
for c in six.moves.range(conv_bias.dims[0]):
conv_bias.float_data[c] *= scale.float_data[c]
conv_bias.float_data[c] += offset.float_data[c]
net.tensors.remove(offset)
else:
op.input.extend([consumer_op.input[2]])
# remove scale tensor
# remove bn
del consumer_op.input[:]
net.tensors.remove(scale)
self.safe_remove_node(consumer_op, op)
return True
return False
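fold_conv_and_bn now fires only when the Conv2D has no bias (two inputs) or a constant bias (three inputs, with the last one in self._consts); when a bias exists, the BN scale and offset are folded into it in place. The guard condition, restated as a standalone helper (hypothetical name, same logic as the input_len check above):

```python
def can_fold_conv_bn(op_inputs, consts):
    # Fold only when the conv has no bias yet, or its bias is a constant
    # tensor that can be rewritten in place.
    return len(op_inputs) == 2 or \
        (len(op_inputs) == 3 and op_inputs[-1] in consts)
```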
@@ -586,10 +598,21 @@ class Transformer(base_converter.ConverterInterface):
if (op.type in [MaceOp.Deconv2D.name, MaceOp.DepthwiseDeconv2d]) \
and self.consumer_count(op.output[0]) == 1:
consumer_op = self._consumers[op.output[0]][0]
if consumer_op.type == MaceOp.BatchNorm.name:
framework = ConverterUtil.get_arg(
op, MaceKeyword.mace_framework_type_str).i
input_len = len(op.input)
if consumer_op.type == MaceOp.BatchNorm.name and \
(framework == FrameworkType.CAFFE.value and
(input_len == 2 or
(input_len == 3 and
op.input[-1] in self._consts))) or \
(framework == FrameworkType.TENSORFLOW.value and
(input_len == 3 or (input_len == 4 and
op.input[-1] in self._consts))):
print("Fold deconv and bn: %s(%s)" % (op.name, op.type))
filter = self._consts[op.input[1]]
scale = self._consts[consumer_op.input[1]]
offset = self._consts[consumer_op.input[2]]
idx = 0
filter_format = self.filter_format()
# in deconv op O and I channel is switched
@@ -613,12 +636,27 @@ class Transformer(base_converter.ConverterInterface):
mace_check(False, "filter format %s not supported" %
filter_format)
# change BN to BiasAdd
consumer_op.type = MaceOp.BiasAdd.name
del consumer_op.input[1]
bias_dim = -1
if framework == FrameworkType.CAFFE.value \
and len(op.input) == 3:
bias_dim = 2
if framework == FrameworkType.TENSORFLOW.value \
and len(op.input) == 4:
bias_dim = 3
if bias_dim != -1:
conv_bias = self._consts[op.input[bias_dim]]
for c in six.moves.range(conv_bias.dims[0]):
conv_bias.float_data[c] *= scale.float_data[c]
conv_bias.float_data[c] += offset.float_data[c]
net.tensors.remove(offset)
else:
op.input.extend([consumer_op.input[2]])
# remove scale tensor
del consumer_op.input[:]
net.tensors.remove(scale)
self.safe_remove_node(consumer_op, op)
return True
return False
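For deconvolution the position of an existing bias depends on the source framework: a Caffe Deconv2D carries (input, filter[, bias]), while a TensorFlow deconv carries an extra non-bias input (presumably the output shape) before the optional bias, which is why the rule accepts two or three inputs for Caffe and three or four for TensorFlow. A hedged sketch of the bias-index selection used above:

```python
def deconv_bias_index(framework, num_inputs):
    # Caffe:      [input, filter, (bias)]            -> bias at index 2
    # TensorFlow: [input, filter, <shape?>, (bias)]  -> bias at index 3
    if framework == "CAFFE" and num_inputs == 3:
        return 2
    if framework == "TENSORFLOW" and num_inputs == 4:
        return 3
    return -1  # no constant bias to fold into
```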
@@ -629,11 +667,15 @@ class Transformer(base_converter.ConverterInterface):
if op.type == MaceOp.DepthwiseConv2d.name \
and self.consumer_count(op.output[0]) == 1:
consumer_op = self._consumers[op.output[0]][0]
if consumer_op.type == MaceOp.BatchNorm.name:
input_len = len(op.input)
if consumer_op.type == MaceOp.BatchNorm.name and \
(input_len == 2 or
(input_len == 3 and op.input[-1] in self._consts)):
print("Fold depthwise conv and bn: %s(%s)"
% (op.name, op.type))
filter = self._consts[op.input[1]]
scale = self._consts[consumer_op.input[1]]
offset = self._consts[consumer_op.input[2]]
idx = 0
filter_format = self.filter_format()
@@ -657,12 +699,20 @@ class Transformer(base_converter.ConverterInterface):
mace_check(False, "filter format %s not supported" %
filter_format)
# change BN to BiasAdd
consumer_op.type = MaceOp.BiasAdd.name
del consumer_op.input[1]
if len(op.input) == 3:
conv_bias = self._consts[op.input[2]]
for c in six.moves.range(conv_bias.dims[0]):
conv_bias.float_data[c] *= scale.float_data[c]
conv_bias.float_data[c] += offset.float_data[c]
net.tensors.remove(offset)
else:
op.input.extend([consumer_op.input[2]])
# remove scale tensor
# remove bn
del consumer_op.input[:]
net.tensors.remove(scale)
self.safe_remove_node(consumer_op, op)
return True
return False
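The same in-place folding is applied to DepthwiseConv2d. A quick numerical check of the identity behind all three passes, as a numpy sketch with made-up shapes (a 1x1 per-channel conv stands in for the real op):

```python
import numpy as np

np.random.seed(0)
channels = 4
x = np.random.randn(1, channels)       # one pixel, per-channel activations
w = np.random.randn(channels)          # 1x1 depthwise weights
b = np.random.randn(channels)          # existing conv bias
scale = np.random.randn(channels)      # BN scale  (gamma / sqrt(var + eps))
offset = np.random.randn(channels)     # BN offset (beta - mean * scale)

y_conv_bn = (x * w + b) * scale + offset            # conv followed by batch norm
y_folded = x * (w * scale) + (b * scale + offset)   # folded conv only
assert np.allclose(y_conv_bn, y_folded)
```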