diff --git a/mace/python/tools/convert_util.py b/mace/python/tools/convert_util.py index 2a685d322148a055fcec1a44d6dc52b09974a4ac..18791987f8f6ee9bdedca6a4107570d1a01a900a 100644 --- a/mace/python/tools/convert_util.py +++ b/mace/python/tools/convert_util.py @@ -13,26 +13,6 @@ # limitations under the License. -import tensorflow as tf -from mace.proto import mace_pb2 - -TF_DTYPE_2_MACE_DTYPE_MAP = { - tf.float32: mace_pb2.DT_FLOAT, - tf.half: mace_pb2.DT_HALF, - tf.int32: mace_pb2.DT_INT32, - tf.qint32: mace_pb2.DT_INT32, - tf.quint8: mace_pb2.DT_UINT8, - tf.uint8: mace_pb2.DT_UINT8, -} - - -def tf_dtype_2_mace_dtype(tf_dtype): - mace_dtype = TF_DTYPE_2_MACE_DTYPE_MAP.get(tf_dtype, None) - if not mace_dtype: - raise Exception("Not supported tensorflow dtype: " + tf_dtype) - return mace_dtype - - def mace_check(condition, msg): if not condition: raise Exception(msg) diff --git a/mace/python/tools/converter.py b/mace/python/tools/converter.py index fb5b8753ef0f4c604ea32e5f54b1a1c0da89eb7a..1da642629f97003456ba73a0a9a32ad7b7e8de6e 100644 --- a/mace/python/tools/converter.py +++ b/mace/python/tools/converter.py @@ -19,12 +19,9 @@ import os.path import copy from mace.proto import mace_pb2 -from mace.python.tools import tf_dsp_converter_lib from mace.python.tools import memory_optimizer from mace.python.tools import model_saver from mace.python.tools.converter_tool import base_converter as cvt -from mace.python.tools.converter_tool import tensorflow_converter -from mace.python.tools.converter_tool import caffe_converter from mace.python.tools.converter_tool import transformer from mace.python.tools.convert_util import mace_check @@ -101,6 +98,7 @@ def main(unused_args): if FLAGS.runtime == 'dsp': if FLAGS.platform == 'tensorflow': + from mace.python.tools import tf_dsp_converter_lib output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb( FLAGS.model_file, FLAGS.input_node, FLAGS.output_node, FLAGS.dsp_mode) @@ -132,9 +130,11 @@ def main(unused_args): option.add_output_node(output_node) if FLAGS.platform == 'tensorflow': + from mace.python.tools.converter_tool import tensorflow_converter converter = tensorflow_converter.TensorflowConverter( option, FLAGS.model_file) elif FLAGS.platform == 'caffe': + from mace.python.tools.converter_tool import caffe_converter converter = caffe_converter.CaffeConverter(option, FLAGS.model_file, FLAGS.weight_file) diff --git a/mace/python/tools/tf_dsp_converter_lib.py b/mace/python/tools/tf_dsp_converter_lib.py index 87bc92d0e3fea2384f2d733dde7bb3bc14e810dc..30236d5ce29c102eca3ce1140aa0221aca37f5cb 100644 --- a/mace/python/tools/tf_dsp_converter_lib.py +++ b/mace/python/tools/tf_dsp_converter_lib.py @@ -19,13 +19,29 @@ from tensorflow import gfile from operator import mul from dsp_ops import DspOps from mace.python.tools import graph_util -from mace.python.tools.convert_util import tf_dtype_2_mace_dtype # converter --input ../libcv/quantized_model.pb \ # --output quantized_model_dsp.pb \ # --runtime dsp --input_node input_node \ # --output_node output_node +TF_DTYPE_2_MACE_DTYPE_MAP = { + tf.float32: mace_pb2.DT_FLOAT, + tf.half: mace_pb2.DT_HALF, + tf.int32: mace_pb2.DT_INT32, + tf.qint32: mace_pb2.DT_INT32, + tf.quint8: mace_pb2.DT_UINT8, + tf.uint8: mace_pb2.DT_UINT8, +} + + +def tf_dtype_2_mace_dtype(tf_dtype): + mace_dtype = TF_DTYPE_2_MACE_DTYPE_MAP.get(tf_dtype, None) + if not mace_dtype: + raise Exception("Not supported tensorflow dtype: " + tf_dtype) + return mace_dtype + + padding_mode = { 'NA': 0, 'SAME': 1,