diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 2e373138be34d386c002234dda5dacc0bee7297e..f6f0877de9b6ecda45a3117d2b22241ad0306203 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -39,6 +39,7 @@ tf_dilations_str = 'dilations' tf_data_format_str = 'data_format' tf_kernel_str = 'ksize' tf_epsilon_str = 'epsilon' +tf_is_training_str = 'is_training' tf_align_corners = 'align_corners' tf_block_size = 'block_size' tf_squeeze_dims = 'squeeze_dims' @@ -516,6 +517,10 @@ class TensorflowConverter(base_converter.ConverterInterface): op = self.convert_general_op(tf_op) op.type = MaceOp.FoldedBatchNorm.name + is_training = tf_op.get_attr(tf_is_training_str) + assert is_training is False, 'Only support batch normalization ' \ + 'with is_training False, but got %s' % is_training + gamma_value = tf_op.inputs[1].eval().astype(np.float32) beta_value = tf_op.inputs[2].eval().astype(np.float32) mean_value = tf_op.inputs[3].eval().astype(np.float32)