提交 2dbd1eb7 编写于 作者: B Bin Li 提交者: 叶剑武

Refactor hexagon converter

上级 0c08cdf7
...@@ -60,41 +60,6 @@ HexagonSupportedOps = [ ...@@ -60,41 +60,6 @@ HexagonSupportedOps = [
HexagonOp = Enum('HexagonOp', [(op, op) for op in HexagonSupportedOps], HexagonOp = Enum('HexagonOp', [(op, op) for op in HexagonSupportedOps],
type=str) type=str)
class HexagonOps(object):
def __init__(self):
self.hexagon_ops = {
MaceOp.BatchToSpaceND.name: HexagonOp.BatchToSpaceND_8.name,
MaceOp.DepthToSpace.name: HexagonOp.DepthToSpace_8.name,
MaceOp.Concat.name: HexagonOp.QuantizedConcat_8.name,
MaceOp.Conv2D.name: HexagonOp.Supernode_8x8p32to8.name,
MaceOp.DepthwiseConv2d.name:
HexagonOp.DepthwiseSupernode_8x8p32to8.name,
MaceOp.Dequantize.name: HexagonOp.DequantizeOUTPUT_8tof.name,
MaceOp.Eltwise.name: [HexagonOp.QuantizedAdd_8p8to8.name,
HexagonOp.QuantizedSub_8p8to8.name,
HexagonOp.QuantizedMul_8x8to8.name],
MaceOp.Identity.name: HexagonOp.Nop.name,
MaceOp.Quantize.name: HexagonOp.QuantizeINPUT_f_to_8.name,
MaceOp.Pooling.name: [HexagonOp.QuantizedAvgPool_8.name,
HexagonOp.QuantizedMaxPool_8.name],
MaceOp.Reduce.name: HexagonOp.QuantizedAvgPool_8.name,
MaceOp.ResizeBilinear.name:
HexagonOp.QuantizedResizeBilinear_8.name,
MaceOp.SpaceToBatchND.name: HexagonOp.SpaceToBatchND_8.name,
MaceOp.SpaceToDepth.name: HexagonOp.SpaceToDepth_8.name,
MaceOp.Softmax.name: HexagonOp.QuantizedSoftmax_8.name,
}
def has_op(self, tf_op):
return tf_op in self.hexagon_ops
def map_nn_op(self, tf_op):
if tf_op not in self.hexagon_ops:
raise Exception('Could not map nn op for: ', tf_op)
return self.hexagon_ops[tf_op]
padding_mode = { padding_mode = {
PaddingMode.NA: 0, PaddingMode.NA: 0,
PaddingMode.SAME: 1, PaddingMode.SAME: 1,
...@@ -124,9 +89,24 @@ class HexagonConverter(base_converter.ConverterInterface): ...@@ -124,9 +89,24 @@ class HexagonConverter(base_converter.ConverterInterface):
def __init__(self, option, model, quantize_activation_info): def __init__(self, option, model, quantize_activation_info):
self._option = option self._option = option
self._model = model self._model = model
self._hexagon_ops = HexagonOps()
self._consts = {} self._consts = {}
self._quantize_activation_info = quantize_activation_info self._quantize_activation_info = quantize_activation_info
self._op_converters = {
MaceOp.BatchToSpaceND.name: self.convert_batchspace,
MaceOp.Concat.name: self.convert_concat,
MaceOp.Conv2D.name: self.convert_conv2d,
MaceOp.DepthToSpace.name: self.convert_depthspace,
MaceOp.DepthwiseConv2d.name: self.convert_conv2d,
MaceOp.Dequantize.name: self.convert_dequantize,
MaceOp.Eltwise.name: self.convert_elementwise,
MaceOp.Pooling.name: self.convert_pooling,
MaceOp.Quantize.name: self.convert_quantize,
MaceOp.Reduce.name: self.convert_reduce,
MaceOp.ResizeBilinear.name: self.convert_resizebilinear,
MaceOp.Softmax.name: self.convert_softmax,
MaceOp.SpaceToBatchND.name: self.convert_batchspace,
MaceOp.SpaceToDepth.name: self.convert_depthspace,
}
def run(self): def run(self):
if self._option.device == DeviceType.HTA.value: if self._option.device == DeviceType.HTA.value:
...@@ -155,221 +135,6 @@ class HexagonConverter(base_converter.ConverterInterface): ...@@ -155,221 +135,6 @@ class HexagonConverter(base_converter.ConverterInterface):
self._quantize_activation_info[tensors[i]] = \ self._quantize_activation_info[tensors[i]] = \
self._quantize_activation_info[node_name] self._quantize_activation_info[node_name]
def convert_ops(self):
print("Convert mace graph to hexagon.")
for op in self._model.op:
if not self._hexagon_ops.has_op(op.type):
raise Exception('Unsupported op: ', op)
self.add_port_for_tensors(op.input)
self.add_port_for_tensors(op.output)
if op.type == MaceOp.Conv2D.name \
or op.type == MaceOp.DepthwiseConv2d.name:
channels = op.output_shape[0].dims[3]
if len(op.input) < 3:
print('Supernode requires biasadd, we add it.')
bias_data = np.zeros(channels, dtype=int)
bias_tensor = self._model.tensors.add()
bias_tensor.data_type = mace_pb2.DT_INT32
bias_tensor.dims.extend([channels])
bias_tensor.int32_data.extend(bias_data)
bias_tensor.minval = 0
bias_tensor.maxval = 0
bias_tensor.name = op.name + "/bias:0"
bias = bias_tensor.name
self._consts[bias] = bias_tensor
else:
bias = op.input.pop()
self.add_min_max_const_node(op, op.input[0])
self.add_min_max_const_node(op, op.input[1])
strides_arg = ConverterUtil.get_arg(op, 'strides')
mace_check(strides_arg is not None,
"Missing strides of Conv or Depthwise Conv.")
strides = self.add_shape_const_node(
op, [1, strides_arg.ints[0], strides_arg.ints[1], 1],
MaceKeyword.mace_strides_str)
op.input.extend([strides, bias])
self.add_min_max_const_node(op, bias)
self.add_min_max_const_node(
op, op.output[0], True, True, False)
elif op.type == MaceOp.Eltwise.name:
self.add_min_max_const_node(op, op.input[0])
self.add_min_max_const_node(op, op.input[1])
element_type = \
ConverterUtil.get_arg(op,
MaceKeyword.mace_element_type_str).i
if element_type == EltwiseType.SUM.value \
or element_type == EltwiseType.SUB.value:
self.add_min_max_const_node(
op, op.output[0], True, True, False)
elif op.type == MaceOp.BatchToSpaceND.name \
or op.type == MaceOp.SpaceToBatchND.name:
strides_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_space_batch_block_shape_str)
strides_tensor = self._model.tensors.add()
strides_tensor.name = op.name + '/strides:0'
strides_tensor.data_type = mace_pb2.DT_INT32
strides_tensor.dims.extend([1, 1, 1, len(strides_arg.ints)])
strides_tensor.int32_data.extend(strides_arg.ints)
if op.type == MaceOp.BatchToSpaceND.name:
pad_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_batch_to_space_crops_str)
else:
pad_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_paddings_str)
pad_tensor = self._model.tensors.add()
pad_tensor.name = op.name + '/pad:0'
pad_tensor.data_type = mace_pb2.DT_INT32
pad_tensor.dims.extend([1, 1, len(pad_arg.ints) // 2, 2])
pad_tensor.int32_data.extend(pad_arg.ints)
op.input.extend([strides_tensor.name, pad_tensor.name])
self.add_min_max_const_node(op, op.input[0])
elif op.type == MaceOp.DepthToSpace.name \
or op.type == MaceOp.SpaceToDepth.name:
size_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_space_depth_block_size_str)
size_tensor = self._model.tensors.add()
size_tensor.name = op.name + '/block_size:0'
size_tensor.data_type = mace_pb2.DT_INT32
size_tensor.dims.extend([1])
size_tensor.int32_data.extend([size_arg.i])
op.input.extend([size_tensor.name])
self.add_min_max_const_node(op, op.input[0])
elif op.type == MaceOp.Pooling.name:
self.add_min_max_const_node(op, op.input[0])
window_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_kernel_str)
window_tensor = self._model.tensors.add()
window_tensor.name = op.name + '/window:0'
window_tensor.data_type = mace_pb2.DT_INT32
window_tensor.dims.extend(
[1, window_arg.ints[0], window_arg.ints[1], 1])
strides_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_strides_str)
strides_tensor = self._model.tensors.add()
strides_tensor.name = op.name + '/strides:0'
strides_tensor.data_type = mace_pb2.DT_INT32
strides_tensor.dims.extend(
[1, strides_arg.ints[0], strides_arg.ints[1], 1])
op.input.extend([window_tensor.name, strides_tensor.name])
elif op.type == MaceOp.Reduce.name:
self.add_min_max_const_node(op, op.input[0])
reduce_type_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_reduce_type_str)
mace_check(reduce_type_arg.i == ReduceType.MEAN.value,
"Hexagon Reduce only supports Mean now.")
keep_dims_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_keepdims_str)
mace_check(keep_dims_arg.i == 1,
"Hexagon Reduce Mean only supports keep dims now.")
axis_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str)
mace_check(1 <= len(axis_arg.ints) <= 2,
"Hexagon Reduce Mean only supports spatial now.")
for i in axis_arg.ints:
mace_check(1 <= i <= 2,
"Hexagon Reduce Mean only supports spatial now")
producer_op_name, _ = get_op_and_port_from_tensor(op.input[0])
input_dims = None
for producer_op in self._model.op:
if producer_op.name == producer_op_name:
input_dims = producer_op.output_shape[0].dims
break
mace_check(input_dims is not None, "Missing input shape.")
window_tensor = self._model.tensors.add()
window_tensor.name = op.name + '/window:0'
window_tensor.data_type = mace_pb2.DT_INT32
if len(axis_arg.ints) == 1:
dim1, dim2 = (input_dims[1], 1) \
if axis_arg.ints[0] == 1 else (1, input_dims[2])
else:
dim1, dim2 = input_dims[1], input_dims[2]
window_tensor.dims.extend([1, dim1, dim2, 1])
strides_tensor = self._model.tensors.add()
strides_tensor.name = op.name + '/strides:0'
strides_tensor.data_type = mace_pb2.DT_INT32
strides_tensor.dims.extend([1, dim1, dim2, 1])
op.input.extend([window_tensor.name, strides_tensor.name])
elif op.type == MaceOp.ResizeBilinear.name:
newdim_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_resize_size_str)
newdim_tensor = self._model.tensors.add()
newdim_tensor.name = op.name + '/newdim:0'
newdim_tensor.data_type = mace_pb2.DT_INT32
newdim_tensor.dims.extend([len(newdim_arg.ints)])
newdim_tensor.int32_data.extend(newdim_arg.ints)
op.input.extend([newdim_tensor.name])
self.add_min_max_const_node(op, op.input[0])
align_corners_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_align_corners_str)
align_corners_tensor = self._model.tensors.add()
align_corners_tensor.name = op.name + '/align_corners:0'
align_corners_tensor.data_type = mace_pb2.DT_INT32
align_corners_tensor.dims.extend([1])
align_corners_tensor.int32_data.extend([align_corners_arg.i])
op.input.extend([align_corners_tensor.name])
elif op.type == MaceOp.Concat.name:
inputs = copy.deepcopy(op.input)
for ipt in inputs:
self.add_min_max_const_node(op, ipt, True, False)
for ipt in inputs:
self.add_min_max_const_node(op, ipt, False, True)
dim_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_axis_str)
dim_tensor = self._model.tensors.add()
dim_tensor.name = op.name + '/dim:0'
dim_tensor.data_type = mace_pb2.DT_INT32
dim_tensor.dims.extend([1])
dim_tensor.int32_data.extend([dim_arg.i])
op.input.insert(0, dim_tensor.name)
elif op.type in [MaceOp.Softmax.name,
MaceOp.Dequantize.name]:
self.add_min_max_const_node(op, op.input[0])
if op.type != MaceOp.Dequantize.name:
min_output_shape = op.output_shape.add()
min_output_shape.dims.extend([1])
max_output_shape = op.output_shape.add()
max_output_shape.dims.extend([1])
op.output_type.extend(
[mace_pb2.DT_UINT8, mace_pb2.DT_FLOAT, mace_pb2.DT_FLOAT])
for i in range(len(op.output_shape)):
out_max_byte_size = reduce(mul, op.output_shape[i].dims)
if op.output_type[i] == mace_pb2.DT_FLOAT:
out_max_byte_size *= 4
op.out_max_byte_size.extend([out_max_byte_size])
op.padding = padding_mode[PaddingMode.NA]
arg = ConverterUtil.get_arg(op, MaceKeyword.mace_padding_str)
if arg is not None:
op.padding = padding_mode[PaddingMode(arg.i)]
if op.type == MaceOp.Eltwise.name:
element_type = \
ConverterUtil.get_arg(op,
MaceKeyword.mace_element_type_str).i
if element_type == EltwiseType.SUM.value:
op.type = HexagonOp.QuantizedAdd_8p8to8.name
elif element_type == EltwiseType.SUB.value:
op.type = HexagonOp.QuantizedSub_8p8to8.name
elif element_type == EltwiseType.PROD.value:
op.type = HexagonOp.QuantizedMul_8x8to8.name
else:
mace_check(False,
"Hexagon does not support elementwise %s"
% EltwiseType(element_type).name)
elif op.type == MaceOp.Pooling.name:
pooling_type_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_pooling_type_str)
if PoolingType(pooling_type_arg.i) == PoolingType.AVG:
op.type = HexagonOp.QuantizedAvgPool_8.name
else:
op.type = HexagonOp.QuantizedMaxPool_8.name
else:
op.type = self._hexagon_ops.map_nn_op(op.type)
def add_const_node(self, name, val): def add_const_node(self, name, val):
if name not in self._consts: if name not in self._consts:
tensor = self._model.tensors.add() tensor = self._model.tensors.add()
...@@ -379,6 +144,18 @@ class HexagonConverter(base_converter.ConverterInterface): ...@@ -379,6 +144,18 @@ class HexagonConverter(base_converter.ConverterInterface):
tensor.dims.extend([1]) tensor.dims.extend([1])
tensor.float_data.extend([val]) tensor.float_data.extend([val])
def add_arg_const_node(self, op, name, dims, data=None, insert_index=None):
arg_tensor = self._model.tensors.add()
arg_tensor.name = op.name + name
arg_tensor.data_type = mace_pb2.DT_INT32
arg_tensor.dims.extend(dims)
if data:
arg_tensor.int32_data.extend(data)
if insert_index is None:
op.input.append(arg_tensor.name)
else:
op.input.insert(insert_index, arg_tensor.name)
def add_min_max_const_node( def add_min_max_const_node(
self, this_op, tensor_name, add_min=True, add_max=True, self, this_op, tensor_name, add_min=True, add_max=True,
diff_port=True): diff_port=True):
...@@ -412,14 +189,6 @@ class HexagonConverter(base_converter.ConverterInterface): ...@@ -412,14 +189,6 @@ class HexagonConverter(base_converter.ConverterInterface):
self.add_const_node(max_tensor_name, maxval) self.add_const_node(max_tensor_name, maxval)
this_op.input.extend([max_tensor_name]) this_op.input.extend([max_tensor_name])
def add_shape_const_node(self, op, values, name):
tensor = self._model.tensors.add()
node_name = op.name + '/' + name
tensor.name = node_name + ':0'
tensor.data_type = mace_pb2.DT_INT32
tensor.dims.extend(values)
return tensor.name
def add_constant_min_max_for_first_op(self, op): def add_constant_min_max_for_first_op(self, op):
minval = self._quantize_activation_info[op.input[0]].minval minval = self._quantize_activation_info[op.input[0]].minval
maxval = self._quantize_activation_info[op.input[0]].maxval maxval = self._quantize_activation_info[op.input[0]].maxval
...@@ -536,3 +305,222 @@ class HexagonConverter(base_converter.ConverterInterface): ...@@ -536,3 +305,222 @@ class HexagonConverter(base_converter.ConverterInterface):
node_input = op.node_input.add() node_input = op.node_input.add()
node_input.node_id = node_id node_input.node_id = node_id
node_input.output_port = int(port) node_input.output_port = int(port)
def convert_ops(self):
print("Convert mace graph to hexagon.")
for op in self._model.op:
mace_check(op.type in self._op_converters,
"Mace Hexagon does not support op type %s yet"
% op.type)
self.pre_convert(op)
self._op_converters[op.type](op)
self.post_convert(op)
def pre_convert(self, op):
self.add_port_for_tensors(op.input)
self.add_port_for_tensors(op.output)
def post_convert(self, op):
if op.type != MaceOp.Dequantize.name:
min_output_shape = op.output_shape.add()
min_output_shape.dims.extend([1])
max_output_shape = op.output_shape.add()
max_output_shape.dims.extend([1])
op.output_type.extend(
[mace_pb2.DT_UINT8, mace_pb2.DT_FLOAT, mace_pb2.DT_FLOAT])
for i in range(len(op.output_shape)):
out_max_byte_size = reduce(mul, op.output_shape[i].dims)
if op.output_type[i] == mace_pb2.DT_FLOAT:
out_max_byte_size *= 4
op.out_max_byte_size.extend([out_max_byte_size])
op.padding = padding_mode[PaddingMode.NA]
arg = ConverterUtil.get_arg(op, MaceKeyword.mace_padding_str)
if arg is not None:
op.padding = padding_mode[PaddingMode(arg.i)]
def convert_batchspace(self, op):
strides_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_space_batch_block_shape_str)
self.add_arg_const_node(
op, '/strides:0', [1, 1, 1, len(strides_arg.ints)],
strides_arg.ints)
if op.type == MaceOp.BatchToSpaceND.name:
pad_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_batch_to_space_crops_str)
else:
pad_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_paddings_str)
self.add_arg_const_node(
op, '/pad:0', [1, 1, len(pad_arg.ints) // 2, 2], pad_arg.ints)
self.add_min_max_const_node(op, op.input[0])
if op.type == MaceOp.BatchToSpaceND.name:
op.type = HexagonOp.BatchToSpaceND_8.name
else:
op.type = HexagonOp.SpaceToBatchND_8.name
def convert_concat(self, op):
inputs = copy.deepcopy(op.input)
for ipt in inputs:
self.add_min_max_const_node(op, ipt, True, False)
for ipt in inputs:
self.add_min_max_const_node(op, ipt, False, True)
dim_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_axis_str)
self.add_arg_const_node(op, '/dim:0', [1], [dim_arg.i], 0)
op.type = HexagonOp.QuantizedConcat_8.name
def convert_conv2d(self, op):
channels = op.output_shape[0].dims[3]
if len(op.input) < 3:
print('Supernode requires biasadd, we add it.')
bias_data = np.zeros(channels, dtype=int)
bias_tensor = self._model.tensors.add()
bias_tensor.data_type = mace_pb2.DT_INT32
bias_tensor.dims.extend([channels])
bias_tensor.int32_data.extend(bias_data)
bias_tensor.minval = 0
bias_tensor.maxval = 0
bias_tensor.name = op.name + "/bias:0"
bias = bias_tensor.name
self._consts[bias] = bias_tensor
else:
bias = op.input.pop()
self.add_min_max_const_node(op, op.input[0])
self.add_min_max_const_node(op, op.input[1])
strides_arg = ConverterUtil.get_arg(op, 'strides')
mace_check(strides_arg is not None,
"Missing strides of Conv or Depthwise Conv.")
self.add_arg_const_node(
op, '/strides:0', [1, strides_arg.ints[0], strides_arg.ints[1], 1])
op.input.append(bias)
self.add_min_max_const_node(op, bias)
self.add_min_max_const_node(
op, op.output[0], True, True, False)
if op.type == MaceOp.DepthwiseConv2d.name:
op.type = HexagonOp.DepthwiseSupernode_8x8p32to8.name
else:
op.type = HexagonOp.Supernode_8x8p32to8.name
def convert_depthspace(self, op):
size_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_space_depth_block_size_str)
self.add_arg_const_node(op, '/block_size:0', [1], [size_arg.i])
self.add_min_max_const_node(op, op.input[0])
if op.type == MaceOp.DepthToSpace.name:
op.type = HexagonOp.DepthToSpace_8.name
else:
op.type = HexagonOp.SpaceToDepth_8.name
def convert_dequantize(self, op):
self.add_min_max_const_node(op, op.input[0])
op.type = HexagonOp.DequantizeOUTPUT_8tof.name
def convert_elementwise(self, op):
self.add_min_max_const_node(op, op.input[0])
self.add_min_max_const_node(op, op.input[1])
element_type = \
ConverterUtil.get_arg(op,
MaceKeyword.mace_element_type_str).i
if element_type == EltwiseType.SUM.value:
self.add_min_max_const_node(
op, op.output[0], True, True, False)
op.type = HexagonOp.QuantizedAdd_8p8to8.name
elif element_type == EltwiseType.SUB.value:
self.add_min_max_const_node(
op, op.output[0], True, True, False)
op.type = HexagonOp.QuantizedSub_8p8to8.name
elif element_type == EltwiseType.PROD.value:
op.type = HexagonOp.QuantizedMul_8x8to8.name
else:
mace_check(False,
"Hexagon does not support elementwise %s"
% EltwiseType(element_type).name)
def convert_pooling(self, op):
self.add_min_max_const_node(op, op.input[0])
window_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_kernel_str)
self.add_arg_const_node(
op, '/window:0', [1, window_arg.ints[0], window_arg.ints[1], 1])
strides_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_strides_str)
self.add_arg_const_node(
op, '/strides:0', [1, strides_arg.ints[0], strides_arg.ints[1], 1])
pooling_type_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_pooling_type_str)
if PoolingType(pooling_type_arg.i) == PoolingType.AVG:
op.type = HexagonOp.QuantizedAvgPool_8.name
else:
op.type = HexagonOp.QuantizedMaxPool_8.name
def convert_quantize(self, op):
op.type = HexagonOp.QuantizeINPUT_f_to_8.name
def convert_reduce(self, op):
self.add_min_max_const_node(op, op.input[0])
reduce_type_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_reduce_type_str)
mace_check(reduce_type_arg.i == ReduceType.MEAN.value,
"Hexagon Reduce only supports Mean now.")
keep_dims_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_keepdims_str)
mace_check(keep_dims_arg.i == 1,
"Hexagon Reduce Mean only supports keep dims now.")
axis_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str)
mace_check(1 <= len(axis_arg.ints) <= 2,
"Hexagon Reduce Mean only supports spatial now.")
for i in axis_arg.ints:
mace_check(1 <= i <= 2,
"Hexagon Reduce Mean only supports spatial now")
producer_op_name, _ = get_op_and_port_from_tensor(op.input[0])
input_dims = None
for producer_op in self._model.op:
if producer_op.name == producer_op_name:
input_dims = producer_op.output_shape[0].dims
break
mace_check(input_dims is not None, "Missing input shape.")
if len(axis_arg.ints) == 1:
dim1, dim2 = (input_dims[1], 1) \
if axis_arg.ints[0] == 1 else (1, input_dims[2])
else:
dim1, dim2 = input_dims[1], input_dims[2]
self.add_arg_const_node(op, '/window:0', [1, dim1, dim2, 1])
self.add_arg_const_node(op, '/strides:0', [1, dim1, dim2, 1])
op.type = HexagonOp.QuantizedAvgPool_8.name
def convert_resizebilinear(self, op):
newdim_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_resize_size_str)
self.add_arg_const_node(
op, '/newdim:0', [len(newdim_arg.ints)], newdim_arg.ints)
self.add_min_max_const_node(op, op.input[0])
align_corners_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_align_corners_str)
self.add_arg_const_node(
op, '/align_corners:0', [1], [align_corners_arg.i])
op.type = HexagonOp.QuantizedResizeBilinear_8.name
def convert_softmax(self, op):
self.add_min_max_const_node(op, op.input[0])
op.type = HexagonOp.QuantizedSoftmax_8.name
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册