提交 92f6d362 编写于 作者: Y yulianfei 提交者: liukai6

add FP16_GATHER_WEIGHT

上级 00febe43
...@@ -14,6 +14,10 @@ ...@@ -14,6 +14,10 @@
#include "mace/core/operator.h" #include "mace/core/operator.h"
#if defined(MACE_ENABLE_NEON) && defined(__ANDROID__)
#include <arm_neon.h>
#endif
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -55,6 +59,10 @@ void RegisterCast(OpRegistryBase *op_registry) { ...@@ -55,6 +59,10 @@ void RegisterCast(OpRegistryBase *op_registry) {
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "Cast", CastOp, MACE_REGISTER_OP(op_registry, "Cast", CastOp,
DeviceType::CPU, int32_t); DeviceType::CPU, int32_t);
#if defined(MACE_ENABLE_NEON) && defined(__ANDROID__)
MACE_REGISTER_OP(op_registry, "Cast", CastOp,
DeviceType::CPU, float16_t);
#endif
} }
} // namespace ops } // namespace ops
......
...@@ -93,6 +93,10 @@ void RegisterGather(OpRegistryBase *op_registry) { ...@@ -93,6 +93,10 @@ void RegisterGather(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Gather", GatherOp, MACE_REGISTER_OP(op_registry, "Gather", GatherOp,
DeviceType::CPU, uint8_t); DeviceType::CPU, uint8_t);
#endif // MACE_ENABLE_QUANTIZE #endif // MACE_ENABLE_QUANTIZE
#if defined(MACE_ENABLE_NEON) && defined(__ANDROID__)
MACE_REGISTER_OP(op_registry, "Gather", GatherOp,
DeviceType::CPU, float16_t);
#endif
} }
} // namespace ops } // namespace ops
......
...@@ -315,6 +315,7 @@ class TransformerRule(Enum): ...@@ -315,6 +315,7 @@ class TransformerRule(Enum):
UPDATE_DATA_FORMAT = 39 UPDATE_DATA_FORMAT = 39
QUANTIZE_SPECIFIC_OPS_ONLY = 40 QUANTIZE_SPECIFIC_OPS_ONLY = 40
FP16_MATMUL_WEIGHT = 41 FP16_MATMUL_WEIGHT = 41
FP16_GATHER_WEIGHT = 42
class ConverterInterface(object): class ConverterInterface(object):
......
...@@ -108,6 +108,8 @@ class Transformer(base_converter.ConverterInterface): ...@@ -108,6 +108,8 @@ class Transformer(base_converter.ConverterInterface):
self.quantize_specific_ops_only, self.quantize_specific_ops_only,
TransformerRule.FP16_MATMUL_WEIGHT: TransformerRule.FP16_MATMUL_WEIGHT:
self.fp16_matmul_weight, self.fp16_matmul_weight,
TransformerRule.FP16_GATHER_WEIGHT:
self.fp16_gather_weight,
} }
self._option = option self._option = option
...@@ -1855,6 +1857,50 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1855,6 +1857,50 @@ class Transformer(base_converter.ConverterInterface):
op.quantize_info[i].minval, op.quantize_info[i].minval,
op.quantize_info[i].maxval)) op.quantize_info[i].maxval))
def fp16_gather_weight(self):
    """Convert Gather (embedding-lookup) weights to fp16.

    For every Gather op whose weight (input 0) is a const fp32 tensor:
    halve the weight to DT_FLOAT16, run the Gather itself in fp16, and
    insert a Cast op — reusing the Gather's original name and outputs —
    that de-halves the gathered rows back to DT_FLOAT so downstream
    consumers are unaffected.

    Raises:
        KeyError: if a Gather op's weight input is not a const tensor.
    """
    # NOTE(review): we append Cast ops to self._model.op while iterating
    # it; the new ops are visited but skipped by the Gather type check.
    for op in self._model.op:
        if op.type != MaceOp.Gather.name:
            continue
        if op.input[0] not in self._consts:
            raise KeyError("Not in const tensor: " + str(op.input[0]))
        const_tensor = self._consts[op.input[0]]
        if const_tensor.data_type == mace_pb2.DT_FLOAT16:
            print(str(const_tensor.name) + " is already float16")
            continue
        print("FP16 Embedding Lookup Weights: %s" % const_tensor.name)

        op_outputs = list(op.output)
        new_gather_name = op.name + '_fp16'
        new_gather_output_name = new_gather_name + ":0"
        dehalve_name = op.name

        # fp16 weights: halve the const tensor in place.
        const_tensor.data_type = mace_pb2.DT_FLOAT16

        # Rename the Gather and redirect its output so its original
        # name/outputs can be taken over by the de-halving Cast below.
        op.name = new_gather_name
        op.output[:] = [new_gather_output_name]
        data_type_arg = ConverterUtil.get_arg(
            op, MaceKeyword.mace_op_data_type_str)
        if data_type_arg is None:
            data_type_arg = op.arg.add()
            data_type_arg.name = MaceKeyword.mace_op_data_type_str
        data_type_arg.i = mace_pb2.DT_FLOAT16

        # Add the de-halving Cast: fp16 gather output -> fp32, wired to
        # the original op name and outputs so graph edges stay valid.
        dehalve_op = self._model.op.add()
        dehalve_op.name = dehalve_name
        dehalve_op.type = MaceOp.Cast.name
        dehalve_op.input.extend([new_gather_output_name])
        dehalve_op.output.extend(op_outputs)
        dehalve_op.output_shape.extend(op.output_shape)
        dehalve_op.output_type.extend([mace_pb2.DT_FLOAT])
        data_type_arg = dehalve_op.arg.add()
        data_type_arg.name = MaceKeyword.mace_op_data_type_str
        data_type_arg.i = mace_pb2.DT_FLOAT16
def fp16_matmul_weight(self): def fp16_matmul_weight(self):
if self._option.device != DeviceType.CPU.value: if self._option.device != DeviceType.CPU.value:
return return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册