提交 d2c2897a 编写于 作者: L liyin

Quantize matmul only, add gather u8 test.

上级 0c3cc381
...@@ -926,7 +926,9 @@ class EltwiseOp : public Operation { ...@@ -926,7 +926,9 @@ class EltwiseOp : public Operation {
const Tensor *input1, const Tensor *input1,
Tensor *output) { Tensor *output) {
bool swapped = false; bool swapped = false;
if (input0->size() < input1->size()) { if (input0->dim_size() < input1->dim_size()
|| (input0->dim_size() == input1->dim_size()
&& input0->size() < input1->size())) {
std::swap(input0, input1); std::swap(input0, input1);
swapped = true; swapped = true;
} }
......
...@@ -23,53 +23,67 @@ namespace test { ...@@ -23,53 +23,67 @@ namespace test {
class GatherOpTest : public OpsTestBase {}; class GatherOpTest : public OpsTestBase {};
namespace { namespace {
template<typename T>
void TestGather(const std::vector<index_t> &weight_shape, void TestGather(const std::vector<index_t> &weight_shape,
const std::vector<float> &weight, const std::vector<T> &weight,
const std::vector<index_t> &input_shape, const std::vector<index_t> &input_shape,
const std::vector<int32_t> &input, const std::vector<int32_t> &input,
const int axis, const int axis,
const std::vector<index_t> &output_shape, const std::vector<index_t> &output_shape,
const std::vector<float> &output) { const std::vector<T> &output) {
OpsTestNet net; OpsTestNet net;
net.AddInputFromArray<CPU, float>("Params", weight_shape, weight); net.AddInputFromArray<CPU, T>("Params", weight_shape, weight);
net.AddInputFromArray<CPU, int32_t>("Indices", input_shape, input); net.AddInputFromArray<CPU, int32_t>("Indices", input_shape, input);
OpDefBuilder("Gather", "GatherTest") OpDefBuilder("Gather", "GatherTest")
.Input("Params") .Input("Params")
.Input("Indices") .Input("Indices")
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("axis", axis) .AddIntArg("axis", axis)
.Output("Output") .Output("Output")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(CPU); net.RunOp(CPU);
auto expected = net.CreateTensor<float>(output_shape, output); auto expected = net.CreateTensor<T>(output_shape, output);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5); ExpectTensorNear<T>(*expected, *net.GetOutput("Output"), 1e-5);
} }
} // namespace } // namespace
TEST_F(GatherOpTest, CPUScalarIndex) { TEST_F(GatherOpTest, CPUScalarIndex) {
TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, TestGather<float>({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
{}, {5}, 0, {2}, {10, 11});
TestGather<uint8_t>({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
{}, {5}, 0, {2}, {10, 11}); {}, {5}, 0, {2}, {10, 11});
} }
TEST_F(GatherOpTest, CPURank1Index) { TEST_F(GatherOpTest, CPURank1Index) {
TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, TestGather<float>({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
{3}, {2, 4, 6}, 0, {3, 2}, {4, 5, 8, 9, 12, 13});
TestGather<uint8_t>({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
{3}, {2, 4, 6}, 0, {3, 2}, {4, 5, 8, 9, 12, 13}); {3}, {2, 4, 6}, 0, {3, 2}, {4, 5, 8, 9, 12, 13});
} }
TEST_F(GatherOpTest, CPURank1IndexWithAxis1) { TEST_F(GatherOpTest, CPURank1IndexWithAxis1) {
TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, TestGather<float>({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
{1}, {1}, 1, {10, 1}, {1, 3, 5, 7, 9, 11, 13, 15, 17, 19});
TestGather<uint8_t>({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
{1}, {1}, 1, {10, 1}, {1, 3, 5, 7, 9, 11, 13, 15, 17, 19}); {1}, {1}, 1, {10, 1}, {1, 3, 5, 7, 9, 11, 13, 15, 17, 19});
} }
TEST_F(GatherOpTest, CPURankHighIndex) { TEST_F(GatherOpTest, CPURankHighIndex) {
TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, TestGather<float>({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
{1, 3}, {2, 4, 6}, 0, {1, 3, 2}, {4, 5, 8, 9, 12, 13});
TestGather<uint8_t>({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
{1, 3}, {2, 4, 6}, 0, {1, 3, 2}, {4, 5, 8, 9, 12, 13}); {1, 3}, {2, 4, 6}, 0, {1, 3, 2}, {4, 5, 8, 9, 12, 13});
} }
......
...@@ -233,8 +233,8 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> { ...@@ -233,8 +233,8 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
const index_t height, const index_t height,
const index_t K, const index_t K,
const index_t width, const index_t width,
const bool lhs_bached, const bool lhs_batched,
const bool rhs_bached, const bool rhs_batched,
Tensor *C) { Tensor *C) {
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) { if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) {
...@@ -245,8 +245,8 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> { ...@@ -245,8 +245,8 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
batch, batch,
height, height,
K, K,
true, lhs_batched,
true, rhs_batched,
C); C);
} else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) { } else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) {
gemv_kernel_.Compute(context, gemv_kernel_.Compute(context,
...@@ -256,8 +256,8 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> { ...@@ -256,8 +256,8 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
batch, batch,
width, width,
K, K,
true, lhs_batched,
true, rhs_batched,
C); C);
} else { } else {
#endif // MACE_ENABLE_NEON #endif // MACE_ENABLE_NEON
...@@ -281,11 +281,13 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> { ...@@ -281,11 +281,13 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
for (index_t i = 0; i < batch; ++i) { for (index_t i = 0; i < batch; ++i) {
gemmlowp::MatrixMap<const uint8_t, AOrder> gemmlowp::MatrixMap<const uint8_t, AOrder>
a_matrix(a_ptr_base + static_cast<index_t>(lhs_bached) * i * a_size, a_matrix(a_ptr_base
+ static_cast<index_t>(lhs_batched) * i * a_size,
height, height,
K); K);
gemmlowp::MatrixMap<const uint8_t, BOrder> gemmlowp::MatrixMap<const uint8_t, BOrder>
b_matrix(b_ptr_base + static_cast<index_t>(rhs_bached) * i * b_size, b_matrix(b_ptr_base
+ static_cast<index_t>(rhs_batched) * i * b_size,
K, K,
width); width);
gemmlowp::MatrixMap <uint8_t, gemmlowp::MapOrder::RowMajor> gemmlowp::MatrixMap <uint8_t, gemmlowp::MapOrder::RowMajor>
...@@ -315,8 +317,8 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> { ...@@ -315,8 +317,8 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
const index_t height, const index_t height,
const index_t K, const index_t K,
const index_t width, const index_t width,
const bool lhs_bached, const bool lhs_batched,
const bool rhs_bached, const bool rhs_batched,
Tensor *C) { Tensor *C) {
C->SetScale(A->scale() * B->scale()); C->SetScale(A->scale() * B->scale());
C->SetZeroPoint(0); C->SetZeroPoint(0);
...@@ -330,8 +332,8 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> { ...@@ -330,8 +332,8 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
batch, batch,
height, height,
K, K,
lhs_bached, lhs_batched,
rhs_bached, rhs_batched,
C); C);
} else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) { } else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) {
gemv_kernel_.Compute(context, gemv_kernel_.Compute(context,
...@@ -341,8 +343,8 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> { ...@@ -341,8 +343,8 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
batch, batch,
width, width,
K, K,
lhs_bached, lhs_batched,
rhs_bached, rhs_batched,
C); C);
} else { } else {
#endif // MACE_ENABLE_NEON #endif // MACE_ENABLE_NEON
...@@ -366,12 +368,12 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> { ...@@ -366,12 +368,12 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
for (index_t i = 0; i < batch; ++i) { for (index_t i = 0; i < batch; ++i) {
gemmlowp::MatrixMap<const uint8_t, AOrder> gemmlowp::MatrixMap<const uint8_t, AOrder>
a_matrix a_matrix
(a_ptr_base + static_cast<index_t>(lhs_bached) * i * a_size, (a_ptr_base + static_cast<index_t>(lhs_batched) * i * a_size,
height, height,
K); K);
gemmlowp::MatrixMap<const uint8_t, BOrder> gemmlowp::MatrixMap<const uint8_t, BOrder>
b_matrix b_matrix
(b_ptr_base + static_cast<index_t>(rhs_bached) * i * b_size, (b_ptr_base + static_cast<index_t>(rhs_batched) * i * b_size,
K, K,
width); width);
gemmlowp::MatrixMap <int32_t, gemmlowp::MapOrder::RowMajor> gemmlowp::MatrixMap <int32_t, gemmlowp::MapOrder::RowMajor>
......
...@@ -135,7 +135,8 @@ void Complex(const std::vector<index_t> &batch, ...@@ -135,7 +135,8 @@ void Complex(const std::vector<index_t> &batch,
rhs_batched, rhs_batched,
&expected_output_tensor); &expected_output_tensor);
ExpectTensorNear<float>(expected_output_tensor, *net.GetTensor("Output")); ExpectTensorNear<float>(expected_output_tensor, *net.GetTensor("Output"),
1e-4, 1e-2);
} }
} // namespace } // namespace
......
...@@ -236,6 +236,7 @@ class MaceKeyword(object): ...@@ -236,6 +236,7 @@ class MaceKeyword(object):
mace_step_h_str = 'step_h' mace_step_h_str = 'step_h'
mace_step_w_str = 'step_w' mace_step_w_str = 'step_w'
mace_find_range_every_time = 'find_range_every_time' mace_find_range_every_time = 'find_range_every_time'
mace_non_zero = 'non_zero'
mace_pad_type_str = 'pad_type' mace_pad_type_str = 'pad_type'
...@@ -279,6 +280,7 @@ class TransformerRule(Enum): ...@@ -279,6 +280,7 @@ class TransformerRule(Enum):
FOLD_FC_RESHAPE = 37 FOLD_FC_RESHAPE = 37
TRANSFORM_CHANNEL_SHUFFLE = 38 TRANSFORM_CHANNEL_SHUFFLE = 38
UPDATE_DATA_FORMAT = 39 UPDATE_DATA_FORMAT = 39
QUANTIZE_MATMUL_ONLY = 40
class ConverterInterface(object): class ConverterInterface(object):
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import math import math
import numpy as np import numpy as np
import six import six
...@@ -288,12 +288,11 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -288,12 +288,11 @@ class TensorflowConverter(base_converter.ConverterInterface):
tf_graph_def.ParseFromString(f.read()) tf_graph_def.ParseFromString(f.read())
self._placeholders = {} self._placeholders = {}
self.add_shape_info(tf_graph_def)
print("Run transform_graph: %s" % TFTransformGraphOptions[ print("Run transform_graph: %s" % TFTransformGraphOptions[
option.device]) option.device])
try: try:
print ("output keys: ", option.output_nodes.keys()) print("output keys: ", option.output_nodes.keys())
transformed_graph_def = TransformGraph(tf_graph_def, transformed_graph_def = TransformGraph(tf_graph_def,
option.input_nodes.keys(), option.input_nodes.keys(),
option.output_nodes.keys(), option.output_nodes.keys(),
...@@ -303,6 +302,16 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -303,6 +302,16 @@ class TensorflowConverter(base_converter.ConverterInterface):
print("Failed to transform graph using tf tool: %s" % ex) print("Failed to transform graph using tf tool: %s" % ex)
transformed_graph_def = tf_graph_def transformed_graph_def = tf_graph_def
# To check optimized model, uncomment following code.
# tf.io.write_graph(
# transformed_graph_def,
# ".",
# os.path.basename(src_model_file)[:-3] + "_opt.pb",
# as_text=False
# )
self.add_shape_info(transformed_graph_def)
with tf.Session() as session: with tf.Session() as session:
with session.graph.as_default() as graph: with session.graph.as_default() as graph:
tf.import_graph_def(transformed_graph_def, name='') tf.import_graph_def(transformed_graph_def, name='')
......
...@@ -103,6 +103,8 @@ class Transformer(base_converter.ConverterInterface): ...@@ -103,6 +103,8 @@ class Transformer(base_converter.ConverterInterface):
self.transform_caffe_reshape_and_flatten, self.transform_caffe_reshape_and_flatten,
TransformerRule.TRANSFORM_CHANNEL_SHUFFLE: TransformerRule.TRANSFORM_CHANNEL_SHUFFLE:
self.transform_channel_shuffle, self.transform_channel_shuffle,
TransformerRule.QUANTIZE_MATMUL_ONLY:
self.quantize_matmul_only,
} }
self._option = option self._option = option
...@@ -191,16 +193,23 @@ class Transformer(base_converter.ConverterInterface): ...@@ -191,16 +193,23 @@ class Transformer(base_converter.ConverterInterface):
op = mace_pb2.OperatorDef() op = mace_pb2.OperatorDef()
op.name = self.normalize_op_name(input_node.name) op.name = self.normalize_op_name(input_node.name)
op.type = "Input" op.type = "Input"
data_type_arg = op.arg.add()
data_type_arg.name = MaceKeyword.mace_op_data_type_str
data_type_arg.i = mace_pb2.DT_FLOAT
op.output.extend([input_node.name]) op.output.extend([input_node.name])
output_shape = op.output_shape.add() output_shape = op.output_shape.add()
output_shape.dims.extend(input_node.shape) output_shape.dims.extend(input_node.shape)
if ConverterUtil.data_format( if input_node in self._consumers:
self._consumers[input_node.name][0]) \ if ConverterUtil.data_format(
== DataFormat.NCHW: self._consumers[input_node.name][0]) \
self.transpose_shape(output_shape.dims, [0, 3, 1, 2]) == DataFormat.NCHW:
ConverterUtil.add_data_format_arg(op, DataFormat.NCHW) self.transpose_shape(output_shape.dims,
else: [0, 3, 1, 2])
ConverterUtil.add_data_format_arg(op, DataFormat.NHWC) ConverterUtil.add_data_format_arg(op,
DataFormat.NCHW)
else:
ConverterUtil.add_data_format_arg(op,
DataFormat.NHWC)
self._producer[op.output[0]] = op self._producer[op.output[0]] = op
@staticmethod @staticmethod
...@@ -221,10 +230,32 @@ class Transformer(base_converter.ConverterInterface): ...@@ -221,10 +230,32 @@ class Transformer(base_converter.ConverterInterface):
return name.replace(':', '_') return name.replace(':', '_')
def get_tensor_shape(self, tensor): def get_tensor_shape(self, tensor):
producer = self._producer[tensor] if tensor in self._consts:
for i in six.moves.range(len(producer.output)): return list(self._consts[tensor].dims)
if producer.output[i] == tensor: elif tensor in self._producer:
return list(producer.output_shape[i].dims) producer = self._producer[tensor]
for i in six.moves.range(len(producer.output)):
if producer.output[i] == tensor:
return list(producer.output_shape[i].dims)
else:
return None
def get_tensor_data_type(self, tensor):
if tensor in self._consts:
return self._consts[tensor].data_type
elif tensor in self._producer:
producer = self._producer[tensor]
for i in six.moves.range(len(producer.output)):
if producer.output[i] == tensor:
if i < len(producer.output_type):
return producer.output_type[i]
elif ConverterUtil.get_arg(producer, "T") is not None:
return ConverterUtil.get_arg(producer, "T").i
else:
print("No data type filled: ", producer)
return None
else:
return None
def consumer_count(self, tensor_name): def consumer_count(self, tensor_name):
return len(self._consumers.get(tensor_name, [])) return len(self._consumers.get(tensor_name, []))
...@@ -1374,6 +1405,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1374,6 +1405,7 @@ class Transformer(base_converter.ConverterInterface):
return False return False
def update_data_format(self): def update_data_format(self):
print("update data format")
data_format_flag = DataFormat.NHWC.value data_format_flag = DataFormat.NHWC.value
for input_node in self._option.input_nodes.values(): for input_node in self._option.input_nodes.values():
if input_node.data_format.value == DataFormat.DF_NONE.value: if input_node.data_format.value == DataFormat.DF_NONE.value:
...@@ -1672,7 +1704,8 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1672,7 +1704,8 @@ class Transformer(base_converter.ConverterInterface):
quantize_util.adjust_range(input_node.range[0], quantize_util.adjust_range(input_node.range[0],
input_node.range[1], input_node.range[1],
non_zero=False) non_zero=False)
quantize_info = mace_pb2.QuantizeActivationInfo() quantize_info = \
mace_pb2.QuantizeActivationInfo()
quantize_info.minval = minval quantize_info.minval = minval
quantize_info.maxval = maxval quantize_info.maxval = maxval
quantize_info.scale = scale quantize_info.scale = scale
...@@ -1893,3 +1926,111 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1893,3 +1926,111 @@ class Transformer(base_converter.ConverterInterface):
producer_op.output_shape[0].dims[:] = output_shape producer_op.output_shape[0].dims[:] = output_shape
return True return True
def quantize_matmul_only(self):
"""
This transform rule is only used internally, we are not gonna make
things too complex for users
"""
to_quantize_ops = [MaceOp.MatMul.name]
for op in self._model.op:
if (op.type not in to_quantize_ops or len(op.output) > 1
or ConverterUtil.get_arg(op,
MaceKeyword.mace_op_data_type_str).i != mace_pb2.DT_FLOAT): # noqa
# only support single output
continue
quantized_inputs_names = []
should_quantize = True
for idx, input_tensor in enumerate(op.input):
if self.get_tensor_data_type(input_tensor) \
!= mace_pb2.DT_FLOAT:
should_quantize = False
break
if not should_quantize:
continue
non_zero = self._option.device == DeviceType.CPU.value
for idx, input_tensor in enumerate(op.input):
quantized_inputs_names.append(input_tensor)
if input_tensor in self._consts:
const_tensor = self._consts[input_tensor]
quantized_tensor = quantize_util.quantize(
const_tensor.float_data, non_zero)
del const_tensor.float_data[:]
const_tensor.int32_data.extend(quantized_tensor.data)
const_tensor.data_type = mace_pb2.DT_UINT8
const_tensor.scale = quantized_tensor.scale
const_tensor.zero_point = quantized_tensor.zero
const_tensor.minval = quantized_tensor.minval
const_tensor.maxval = quantized_tensor.maxval
const_tensor.quantized = True
else:
input_shape = self.get_tensor_shape(input_tensor)
quantize_op = self._model.op.add()
quantize_op.name = self.normalize_op_name(
input_tensor) + "_quant"
quantize_op.type = MaceOp.Quantize.name
quantize_op.input.extend([input_tensor])
quantize_output_name = quantize_op.name + '_0'
quantize_op.output.extend([quantize_output_name])
output_shape = quantize_op.output_shape.add()
output_shape.dims.extend(input_shape)
quantize_op.output_type.extend([mace_pb2.DT_UINT8])
data_type_arg = quantize_op.arg.add()
data_type_arg.name = MaceKeyword.mace_op_data_type_str
data_type_arg.i = mace_pb2.DT_UINT8
data_type_arg = quantize_op.arg.add()
data_type_arg.name = MaceKeyword.mace_non_zero
if non_zero:
data_type_arg.i = 1
else:
data_type_arg.i = 0
find_range_arg = quantize_op.arg.add()
find_range_arg.name = \
MaceKeyword.mace_find_range_every_time
find_range_arg.i = 1
quantized_inputs_names[-1] = quantize_output_name
non_zero = False
del op.input[:]
op.input.extend(quantized_inputs_names)
orginal_output_name = op.output[0]
op.output[0] = orginal_output_name + "_quant"
op.output_type.extend([mace_pb2.DT_INT32])
data_type_arg = ConverterUtil.get_arg(op,
MaceKeyword.mace_op_data_type_str) # noqa
if data_type_arg is None:
data_type_arg = op.arg.add()
data_type_arg.name = MaceKeyword.mace_op_data_type_str
data_type_arg.i = mace_pb2.DT_UINT8
dequantize_op = self._model.op.add()
dequantize_op.name = op.name + "_dequant"
dequantize_op.type = MaceOp.Dequantize.name
dequantize_op.input.extend([op.output[0]])
dequantize_op.output.extend([orginal_output_name])
dequantize_op.output_shape.extend(op.output_shape)
dequantize_op.output_type.extend([mace_pb2.DT_FLOAT])
data_type_arg = dequantize_op.arg.add()
data_type_arg.name = MaceKeyword.mace_op_data_type_str
data_type_arg.i = mace_pb2.DT_INT32
quantize_flag_arg = ConverterUtil.get_arg(self._model,
MaceKeyword.mace_quantize_flag_arg_str) # noqa
if quantize_flag_arg is None:
quantize_flag_arg = self._model.arg.add()
quantize_flag_arg.name = MaceKeyword.mace_quantize_flag_arg_str
quantize_flag_arg.i = 1
return True
return False
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册