From 217b3fd066395196b1aef9b4b4aa69d7e3ce916e Mon Sep 17 00:00:00 2001 From: liuqi Date: Tue, 3 Jul 2018 17:19:56 +0800 Subject: [PATCH] Fix bug: need tensorflow when convert caffe model. --- mace/python/tools/convert_util.py | 20 -------------------- mace/python/tools/converter.py | 6 +++--- mace/python/tools/tf_dsp_converter_lib.py | 18 +++++++++++++++++- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/mace/python/tools/convert_util.py b/mace/python/tools/convert_util.py index 2a685d32..18791987 100644 --- a/mace/python/tools/convert_util.py +++ b/mace/python/tools/convert_util.py @@ -13,26 +13,6 @@ # limitations under the License. -import tensorflow as tf -from mace.proto import mace_pb2 - -TF_DTYPE_2_MACE_DTYPE_MAP = { - tf.float32: mace_pb2.DT_FLOAT, - tf.half: mace_pb2.DT_HALF, - tf.int32: mace_pb2.DT_INT32, - tf.qint32: mace_pb2.DT_INT32, - tf.quint8: mace_pb2.DT_UINT8, - tf.uint8: mace_pb2.DT_UINT8, -} - - -def tf_dtype_2_mace_dtype(tf_dtype): - mace_dtype = TF_DTYPE_2_MACE_DTYPE_MAP.get(tf_dtype, None) - if not mace_dtype: - raise Exception("Not supported tensorflow dtype: " + tf_dtype) - return mace_dtype - - def mace_check(condition, msg): if not condition: raise Exception(msg) diff --git a/mace/python/tools/converter.py b/mace/python/tools/converter.py index fb5b8753..1da64262 100644 --- a/mace/python/tools/converter.py +++ b/mace/python/tools/converter.py @@ -19,12 +19,9 @@ import os.path import copy from mace.proto import mace_pb2 -from mace.python.tools import tf_dsp_converter_lib from mace.python.tools import memory_optimizer from mace.python.tools import model_saver from mace.python.tools.converter_tool import base_converter as cvt -from mace.python.tools.converter_tool import tensorflow_converter -from mace.python.tools.converter_tool import caffe_converter from mace.python.tools.converter_tool import transformer from mace.python.tools.convert_util import mace_check @@ -101,6 +98,7 @@ def main(unused_args): if FLAGS.runtime == 'dsp': if FLAGS.platform == 'tensorflow': + from mace.python.tools import tf_dsp_converter_lib output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb( FLAGS.model_file, FLAGS.input_node, FLAGS.output_node, FLAGS.dsp_mode) @@ -132,9 +130,11 @@ def main(unused_args): option.add_output_node(output_node) if FLAGS.platform == 'tensorflow': + from mace.python.tools.converter_tool import tensorflow_converter converter = tensorflow_converter.TensorflowConverter( option, FLAGS.model_file) elif FLAGS.platform == 'caffe': + from mace.python.tools.converter_tool import caffe_converter converter = caffe_converter.CaffeConverter(option, FLAGS.model_file, FLAGS.weight_file) diff --git a/mace/python/tools/tf_dsp_converter_lib.py b/mace/python/tools/tf_dsp_converter_lib.py index 87bc92d0..30236d5c 100644 --- a/mace/python/tools/tf_dsp_converter_lib.py +++ b/mace/python/tools/tf_dsp_converter_lib.py @@ -19,13 +19,29 @@ from tensorflow import gfile from operator import mul from dsp_ops import DspOps from mace.python.tools import graph_util -from mace.python.tools.convert_util import tf_dtype_2_mace_dtype # converter --input ../libcv/quantized_model.pb \ # --output quantized_model_dsp.pb \ # --runtime dsp --input_node input_node \ # --output_node output_node +TF_DTYPE_2_MACE_DTYPE_MAP = { + tf.float32: mace_pb2.DT_FLOAT, + tf.half: mace_pb2.DT_HALF, + tf.int32: mace_pb2.DT_INT32, + tf.qint32: mace_pb2.DT_INT32, + tf.quint8: mace_pb2.DT_UINT8, + tf.uint8: mace_pb2.DT_UINT8, +} + + +def tf_dtype_2_mace_dtype(tf_dtype): + mace_dtype = TF_DTYPE_2_MACE_DTYPE_MAP.get(tf_dtype, None) + if not mace_dtype: + raise Exception("Not supported tensorflow dtype: " + tf_dtype) + return mace_dtype + + padding_mode = { 'NA': 0, 'SAME': 1, -- GitLab