Commit 92f6d362 authored by yulianfei, committed by liukai6

add FP16_GATHER_WEIGHT

Parent 00febe43
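In short: the new FP16_GATHER_WEIGHT transformer rule stores the const weight tensor of a Gather (embedding lookup) op as float16 and inserts a Cast back to float32 behind it, so the graph goes from Gather(fp32 weights) -> consumers to Gather(fp16 weights) -> Cast(fp16 -> fp32) -> consumers. The float16_t CPU kernels for Cast and Gather are registered only when NEON is enabled on Android.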
@@ -14,6 +14,10 @@
 #include "mace/core/operator.h"

+#if defined(MACE_ENABLE_NEON) && defined(__ANDROID__)
+#include <arm_neon.h>
+#endif
+
 namespace mace {
 namespace ops {
@@ -55,6 +59,10 @@ void RegisterCast(OpRegistryBase *op_registry) {
                    DeviceType::CPU, float);
   MACE_REGISTER_OP(op_registry, "Cast", CastOp,
                    DeviceType::CPU, int32_t);
+#if defined(MACE_ENABLE_NEON) && defined(__ANDROID__)
+  MACE_REGISTER_OP(op_registry, "Cast", CastOp,
+                   DeviceType::CPU, float16_t);
+#endif
 }
 }  // namespace ops
......
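For intuition about the added float16_t Cast path, here is a minimal NumPy sketch (illustrative only, not MACE code) of what halving and dehalving a weight tensor does to storage and precision; the table shape is made up for illustration:

import numpy as np

# Hypothetical embedding table, values kept in [0.5, 1.5) so every
# entry is a normal float16 number.
table_fp32 = (np.random.rand(50_000, 512) + 0.5).astype(np.float32)

# Halve: float16 stores 2 bytes per element instead of 4.
table_fp16 = table_fp32.astype(np.float16)
print(table_fp32.nbytes / 2**20, "MiB fp32")   # ~97.7 MiB
print(table_fp16.nbytes / 2**20, "MiB fp16")   # ~48.8 MiB

# Dehalve: cast back to float32, as the inserted Cast op does at runtime.
roundtrip = table_fp16.astype(np.float32)

# float16 keeps an 11-bit significand, so the relative rounding error
# here stays below about 2**-11 (~5e-4).
print(np.max(np.abs(roundtrip - table_fp32) / table_fp32))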
@@ -93,6 +93,10 @@ void RegisterGather(OpRegistryBase *op_registry) {
   MACE_REGISTER_OP(op_registry, "Gather", GatherOp,
                    DeviceType::CPU, uint8_t);
 #endif  // MACE_ENABLE_QUANTIZE
+#if defined(MACE_ENABLE_NEON) && defined(__ANDROID__)
+  MACE_REGISTER_OP(op_registry, "Gather", GatherOp,
+                   DeviceType::CPU, float16_t);
+#endif
 }
 }  // namespace ops
......
@@ -315,6 +315,7 @@ class TransformerRule(Enum):
     UPDATE_DATA_FORMAT = 39
     QUANTIZE_SPECIFIC_OPS_ONLY = 40
     FP16_MATMUL_WEIGHT = 41
+    FP16_GATHER_WEIGHT = 42

 class ConverterInterface(object):
......
@@ -108,6 +108,8 @@ class Transformer(base_converter.ConverterInterface):
                 self.quantize_specific_ops_only,
             TransformerRule.FP16_MATMUL_WEIGHT:
                 self.fp16_matmul_weight,
+            TransformerRule.FP16_GATHER_WEIGHT:
+                self.fp16_gather_weight,
         }
         self._option = option
@@ -1855,6 +1857,50 @@ class Transformer(base_converter.ConverterInterface):
                         op.quantize_info[i].minval,
                         op.quantize_info[i].maxval))

+    def fp16_gather_weight(self):
+        for op in self._model.op:
+            if op.type != MaceOp.Gather.name:
+                continue
+            if op.input[0] not in self._consts:
+                raise KeyError("Not in const tensor: " + str(op.input[0]))
+
+            const_tensor = self._consts[op.input[0]]
+            if const_tensor.data_type == mace_pb2.DT_FLOAT16:
+                print(str(const_tensor.name) + " is already float16")
+                continue
+
+            print("FP16 Embedding Lookup Weights: %s" % const_tensor.name)
+
+            op_outputs = [x for x in op.output]
+            new_gather_name = op.name + '_fp16'
+            new_gather_output_name = new_gather_name + ":0"
+            dehalve_name = op.name
+
+            # store the weight tensor in half precision
+            const_tensor.data_type = mace_pb2.DT_FLOAT16
+
+            # rename the Gather op and give it an intermediate fp16 output
+            op.name = new_gather_name
+            op.output[:] = [new_gather_output_name]
+            data_type_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_op_data_type_str)  # noqa
+            if data_type_arg is None:
+                data_type_arg = op.arg.add()
+                data_type_arg.name = MaceKeyword.mace_op_data_type_str
+            data_type_arg.i = mace_pb2.DT_FLOAT16
+
+            # add a "dehalve" Cast op under the original name so existing
+            # consumers now read the fp32 output of the Cast
+            dehalve_op = self._model.op.add()
+            dehalve_op.name = dehalve_name
+            dehalve_op.type = MaceOp.Cast.name
+            dehalve_op.input.extend([new_gather_output_name])
+            dehalve_op.output.extend(op_outputs)
+            dehalve_op.output_shape.extend(op.output_shape)
+            dehalve_op.output_type.extend([mace_pb2.DT_FLOAT])
+            data_type_arg = dehalve_op.arg.add()
+            data_type_arg.name = MaceKeyword.mace_op_data_type_str
+            data_type_arg.i = mace_pb2.DT_FLOAT16
+
     def fp16_matmul_weight(self):
         if self._option.device != DeviceType.CPU.value:
             return
......
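To make the rewrite pattern concrete, here is a self-contained toy sketch of the before/after graph for a Gather op named "embed" (the ToyOp class and all names are hypothetical stand-ins, not mace_pb2 messages):

# Hypothetical stand-in for a graph op; not the real protobuf message.
class ToyOp:
    def __init__(self, name, op_type, inputs, outputs):
        self.name, self.type = name, op_type
        self.input, self.output = list(inputs), list(outputs)

# Before: Gather reads fp32 "weights" and feeds "embed:0" to consumers.
gather = ToyOp("embed", "Gather", ["weights", "ids"], ["embed:0"])

# After fp16_gather_weight:
#   * the weights tensor is re-tagged DT_FLOAT16,
#   * the Gather is renamed "embed_fp16" with output "embed_fp16:0",
#   * a Cast op takes over the original name and output,
# so every op that consumed "embed:0" now reads the Cast's fp32 result
# without any of its inputs being rewritten.
gather.name, gather.output = "embed_fp16", ["embed_fp16:0"]
cast = ToyOp("embed", "Cast", ["embed_fp16:0"], ["embed:0"])

for op in (gather, cast):
    print(op.type, op.name, op.input, "->", op.output)

Reusing the original op name and output for the Cast is what keeps the pass local: downstream ops keep their input lists unchanged.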