diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 8c272ac2442db86b33e9617d7cf9e05f1ebb477d..fca6ca95b335f030cbf7ebfe92a376ed1512e2c7 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -96,6 +96,7 @@ TFSupportedOps = [ 'Pack', 'Cast', 'ArgMax', + 'Split', ] TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str) @@ -192,6 +193,7 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.Stack.name: self.convert_stack, TFOpType.Cast.name: self.convert_cast, TFOpType.ArgMax.name: self.convert_argmax, + TFOpType.Split.name: self.convert_split, } self._option = option self._mace_net_def = mace_pb2.NetDef() @@ -759,3 +761,21 @@ class TensorflowConverter(base_converter.ConverterInterface): op = self.convert_general_op(tf_op) op.type = MaceOp.ArgMax.name op.output_type.extend([mace_pb2.DT_INT32]) + + def convert_split(self, tf_op): + # inputs: [dim, input] + axis = tf_op.inputs[0].eval().astype(np.int32) + axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis + mace_check(axis == 3, 'Split with %d axis only support' % axis) + input_shape = self.infer_tensor_shape(tf_op.inputs[1]) + mace_check(len(input_shape) == 4 and (input_shape[3] % 4 == 0), + "The input's 4th dimension should be a multiple of 4") + op = self.convert_general_op(tf_op) + op.type = MaceOp.Slice.name + del op.input[0] + + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + axis_arg.i = axis + + self._skip_tensor.add(tf_op.inputs[0].name) diff --git a/mace/python/tools/model_header.jinja2 b/mace/python/tools/model_header.jinja2 index c90c4f08b266e209fecf37fd5969efb73a8f0e05..2b7dd5fb036def99dd7f05884cfe5bb6e74849b2 100644 --- a/mace/python/tools/model_header.jinja2 +++ b/mace/python/tools/model_header.jinja2 @@ -25,17 +25,14 @@ namespace mace { namespace {{tag}} { -const unsigned char *LoadModelData(const std::string &model_data_file); +extern const unsigned char *LoadModelData(); -const std::shared_ptr CreateNet(); +extern const std::shared_ptr CreateNet(); -const std::string ModelName(); - -const std::string ModelChecksum(); - -const std::string ModelBuildTime(); - -const std::string ModelBuildOptions(); +extern const std::string ModelName(); +extern const std::string ModelChecksum(); +extern const std::string ModelBuildTime(); +extern const std::string ModelBuildOptions(); } // namespace {{ tag }} } // namespace mace