提交 217b3fd0 编写于 作者: L liuqi

Fix bug: need tensorflow when convert caffe model.

上级 82bf1768
......@@ -13,26 +13,6 @@
# limitations under the License.
import tensorflow as tf
from mace.proto import mace_pb2
TF_DTYPE_2_MACE_DTYPE_MAP = {
tf.float32: mace_pb2.DT_FLOAT,
tf.half: mace_pb2.DT_HALF,
tf.int32: mace_pb2.DT_INT32,
tf.qint32: mace_pb2.DT_INT32,
tf.quint8: mace_pb2.DT_UINT8,
tf.uint8: mace_pb2.DT_UINT8,
}
def tf_dtype_2_mace_dtype(tf_dtype):
mace_dtype = TF_DTYPE_2_MACE_DTYPE_MAP.get(tf_dtype, None)
if not mace_dtype:
raise Exception("Not supported tensorflow dtype: " + tf_dtype)
return mace_dtype
def mace_check(condition, msg):
if not condition:
raise Exception(msg)
......@@ -19,12 +19,9 @@ import os.path
import copy
from mace.proto import mace_pb2
from mace.python.tools import tf_dsp_converter_lib
from mace.python.tools import memory_optimizer
from mace.python.tools import model_saver
from mace.python.tools.converter_tool import base_converter as cvt
from mace.python.tools.converter_tool import tensorflow_converter
from mace.python.tools.converter_tool import caffe_converter
from mace.python.tools.converter_tool import transformer
from mace.python.tools.convert_util import mace_check
......@@ -101,6 +98,7 @@ def main(unused_args):
if FLAGS.runtime == 'dsp':
if FLAGS.platform == 'tensorflow':
from mace.python.tools import tf_dsp_converter_lib
output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb(
FLAGS.model_file, FLAGS.input_node, FLAGS.output_node,
FLAGS.dsp_mode)
......@@ -132,9 +130,11 @@ def main(unused_args):
option.add_output_node(output_node)
if FLAGS.platform == 'tensorflow':
from mace.python.tools.converter_tool import tensorflow_converter
converter = tensorflow_converter.TensorflowConverter(
option, FLAGS.model_file)
elif FLAGS.platform == 'caffe':
from mace.python.tools.converter_tool import caffe_converter
converter = caffe_converter.CaffeConverter(option,
FLAGS.model_file,
FLAGS.weight_file)
......
......@@ -19,13 +19,29 @@ from tensorflow import gfile
from operator import mul
from dsp_ops import DspOps
from mace.python.tools import graph_util
from mace.python.tools.convert_util import tf_dtype_2_mace_dtype
# converter --input ../libcv/quantized_model.pb \
# --output quantized_model_dsp.pb \
# --runtime dsp --input_node input_node \
# --output_node output_node
TF_DTYPE_2_MACE_DTYPE_MAP = {
tf.float32: mace_pb2.DT_FLOAT,
tf.half: mace_pb2.DT_HALF,
tf.int32: mace_pb2.DT_INT32,
tf.qint32: mace_pb2.DT_INT32,
tf.quint8: mace_pb2.DT_UINT8,
tf.uint8: mace_pb2.DT_UINT8,
}
def tf_dtype_2_mace_dtype(tf_dtype):
mace_dtype = TF_DTYPE_2_MACE_DTYPE_MAP.get(tf_dtype, None)
if not mace_dtype:
raise Exception("Not supported tensorflow dtype: " + tf_dtype)
return mace_dtype
padding_mode = {
'NA': 0,
'SAME': 1,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册