From aabcf3d872fc423fbf49fdb0ff71a027071fc1b4 Mon Sep 17 00:00:00 2001 From: oneTaken Date: Thu, 20 Sep 2018 09:58:04 +0800 Subject: [PATCH] add friendly check for batch normalization layer to better show what's wrong (#199) * add friendly check for batch normalization layer * coding style * reduce the too long line length --- mace/python/tools/converter_tool/tensorflow_converter.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 2e373138..f6f0877d 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -39,6 +39,7 @@ tf_dilations_str = 'dilations' tf_data_format_str = 'data_format' tf_kernel_str = 'ksize' tf_epsilon_str = 'epsilon' +tf_is_training_str = 'is_training' tf_align_corners = 'align_corners' tf_block_size = 'block_size' tf_squeeze_dims = 'squeeze_dims' @@ -516,6 +517,10 @@ class TensorflowConverter(base_converter.ConverterInterface): op = self.convert_general_op(tf_op) op.type = MaceOp.FoldedBatchNorm.name + is_training = tf_op.get_attr(tf_is_training_str) + assert is_training is False, 'Only support batch normalization ' \ + 'with is_training False, but got %s' % is_training + gamma_value = tf_op.inputs[1].eval().astype(np.float32) beta_value = tf_op.inputs[2].eval().astype(np.float32) mean_value = tf_op.inputs[3].eval().astype(np.float32) -- GitLab