diff --git a/mace/ops/cast.cc b/mace/ops/cast.cc index 9604bb90edf88a5d057e17f967a6e23447831d35..940959a93f0333033e26a0825f28cf0f735f1bb3 100644 --- a/mace/ops/cast.cc +++ b/mace/ops/cast.cc @@ -14,6 +14,10 @@ #include "mace/core/operator.h" +#if defined(MACE_ENABLE_NEON) && defined(__ANDROID__) +#include <arm_neon.h> +#endif + namespace mace { namespace ops { @@ -55,6 +59,10 @@ void RegisterCast(OpRegistryBase *op_registry) { DeviceType::CPU, float); MACE_REGISTER_OP(op_registry, "Cast", CastOp, DeviceType::CPU, int32_t); +#if defined(MACE_ENABLE_NEON) && defined(__ANDROID__) + MACE_REGISTER_OP(op_registry, "Cast", CastOp, + DeviceType::CPU, float16_t); +#endif } } // namespace ops diff --git a/mace/ops/gather.cc b/mace/ops/gather.cc index 0c0551cd396af2279f47b245c371df4989143a98..2114290b66ff8d2d256bc7e9dcce02b298331112 100644 --- a/mace/ops/gather.cc +++ b/mace/ops/gather.cc @@ -93,6 +93,10 @@ void RegisterGather(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "Gather", GatherOp, DeviceType::CPU, uint8_t); #endif // MACE_ENABLE_QUANTIZE +#if defined(MACE_ENABLE_NEON) && defined(__ANDROID__) + MACE_REGISTER_OP(op_registry, "Gather", GatherOp, + DeviceType::CPU, float16_t); +#endif } } // namespace ops diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 871e02935eeef8ec0d1370c4070df28cb178133e..ff01f59790c9be6a57d766ae183deb5ac5754a64 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -315,6 +315,7 @@ class TransformerRule(Enum): UPDATE_DATA_FORMAT = 39 QUANTIZE_SPECIFIC_OPS_ONLY = 40 FP16_MATMUL_WEIGHT = 41 + FP16_GATHER_WEIGHT = 42 class ConverterInterface(object): diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 7e0327356e33360f322f7e6978a2ce370b3b03f9..cb095643de1c973a210fc8d3fc700c0ddd5a02eb 100644 --- a/mace/python/tools/converter_tool/transformer.py 
+++ b/mace/python/tools/converter_tool/transformer.py @@ -108,6 +108,8 @@ class Transformer(base_converter.ConverterInterface): self.quantize_specific_ops_only, TransformerRule.FP16_MATMUL_WEIGHT: self.fp16_matmul_weight, + TransformerRule.FP16_GATHER_WEIGHT: + self.fp16_gather_weight, } self._option = option @@ -1855,6 +1857,50 @@ class Transformer(base_converter.ConverterInterface): op.quantize_info[i].minval, op.quantize_info[i].maxval)) + def fp16_gather_weight(self): + for op in self._model.op: + if op.type != MaceOp.Gather.name: + continue + if op.input[0] not in self._consts: + raise KeyError("Not in const tensor: " + str(op.input[0])) + + const_tensor = self._consts[op.input[0]] + if const_tensor.data_type == mace_pb2.DT_FLOAT16: + print(str(const_tensor.name) + " is already float16") + continue + + print("FP16 Embedding Lookup Weights: %s" % const_tensor.name) + + op_outputs = [x for x in op.output] + new_gather_name = op.name + '_fp16' + new_gather_output_name = new_gather_name + ":0" + dehalve_name = op.name + + # fp16 weights + const_tensor.data_type = mace_pb2.DT_FLOAT16 + + # change gather + op.name = new_gather_name + op.output[:] = [new_gather_output_name] + # op.output.extend([new_gather_output_name]) + data_type_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_op_data_type_str) # noqa + if data_type_arg is None: + data_type_arg = op.arg.add() + data_type_arg.name = MaceKeyword.mace_op_data_type_str + data_type_arg.i = mace_pb2.DT_FLOAT16 + + # add dehalve + dehalve_op = self._model.op.add() + dehalve_op.name = dehalve_name + dehalve_op.type = MaceOp.Cast.name + dehalve_op.input.extend([new_gather_output_name]) + dehalve_op.output.extend(op_outputs) + dehalve_op.output_shape.extend(op.output_shape) + dehalve_op.output_type.extend([mace_pb2.DT_FLOAT]) + data_type_arg = dehalve_op.arg.add() + data_type_arg.name = MaceKeyword.mace_op_data_type_str + data_type_arg.i = mace_pb2.DT_FLOAT16 + def fp16_matmul_weight(self): if self._option.device != 
DeviceType.CPU.value: return