提交 a3075116 编写于 作者: 李寅

Merge branch 'support-tf-split' into 'master'

Support tensorflow split op.

See merge request !710
...@@ -96,6 +96,7 @@ TFSupportedOps = [ ...@@ -96,6 +96,7 @@ TFSupportedOps = [
'Pack', 'Pack',
'Cast', 'Cast',
'ArgMax', 'ArgMax',
'Split',
] ]
TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str) TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str)
...@@ -192,6 +193,7 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -192,6 +193,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.Stack.name: self.convert_stack, TFOpType.Stack.name: self.convert_stack,
TFOpType.Cast.name: self.convert_cast, TFOpType.Cast.name: self.convert_cast,
TFOpType.ArgMax.name: self.convert_argmax, TFOpType.ArgMax.name: self.convert_argmax,
TFOpType.Split.name: self.convert_split,
} }
self._option = option self._option = option
self._mace_net_def = mace_pb2.NetDef() self._mace_net_def = mace_pb2.NetDef()
...@@ -759,3 +761,21 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -759,3 +761,21 @@ class TensorflowConverter(base_converter.ConverterInterface):
op = self.convert_general_op(tf_op) op = self.convert_general_op(tf_op)
op.type = MaceOp.ArgMax.name op.type = MaceOp.ArgMax.name
op.output_type.extend([mace_pb2.DT_INT32]) op.output_type.extend([mace_pb2.DT_INT32])
def convert_split(self, tf_op):
# inputs: [dim, input]
axis = tf_op.inputs[0].eval().astype(np.int32)
axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis
mace_check(axis == 3, 'Split with %d axis only support' % axis)
input_shape = self.infer_tensor_shape(tf_op.inputs[1])
mace_check(len(input_shape) == 4 and (input_shape[3] % 4 == 0),
"The input's 4th dimension should be a multiple of 4")
op = self.convert_general_op(tf_op)
op.type = MaceOp.Slice.name
del op.input[0]
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = axis
self._skip_tensor.add(tf_op.inputs[0].name)
...@@ -25,17 +25,14 @@ namespace mace { ...@@ -25,17 +25,14 @@ namespace mace {
namespace {{tag}} { namespace {{tag}} {
const unsigned char *LoadModelData(const std::string &model_data_file); extern const unsigned char *LoadModelData();
const std::shared_ptr<NetDef> CreateNet(); extern const std::shared_ptr<NetDef> CreateNet();
const std::string ModelName(); extern const std::string ModelName();
extern const std::string ModelChecksum();
const std::string ModelChecksum(); extern const std::string ModelBuildTime();
extern const std::string ModelBuildOptions();
const std::string ModelBuildTime();
const std::string ModelBuildOptions();
} // namespace {{ tag }} } // namespace {{ tag }}
} // namespace mace } // namespace mace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册