diff --git a/mace/ops/eltwise.cc b/mace/ops/eltwise.cc index d345427a2a3a6258d90f3a55c71f1b8d8004419b..9323e7b522904e76e267ca222ef9423702d03709 100644 --- a/mace/ops/eltwise.cc +++ b/mace/ops/eltwise.cc @@ -926,7 +926,9 @@ class EltwiseOp : public Operation { const Tensor *input1, Tensor *output) { bool swapped = false; - if (input0->size() < input1->size()) { + if (input0->dim_size() < input1->dim_size() + || (input0->dim_size() == input1->dim_size() + && input0->size() < input1->size())) { std::swap(input0, input1); swapped = true; } diff --git a/mace/ops/gather_test.cc b/mace/ops/gather_test.cc index 32f849c7abf69318fe2fdcb9dcacb97bc437aec0..2c0f474ca7aa9437328f0319ba6538c11f538d3d 100644 --- a/mace/ops/gather_test.cc +++ b/mace/ops/gather_test.cc @@ -23,53 +23,67 @@ namespace test { class GatherOpTest : public OpsTestBase {}; namespace { +template void TestGather(const std::vector &weight_shape, - const std::vector &weight, + const std::vector &weight, const std::vector &input_shape, const std::vector &input, const int axis, const std::vector &output_shape, - const std::vector &output) { + const std::vector &output) { OpsTestNet net; - net.AddInputFromArray("Params", weight_shape, weight); + net.AddInputFromArray("Params", weight_shape, weight); net.AddInputFromArray("Indices", input_shape, input); OpDefBuilder("Gather", "GatherTest") .Input("Params") .Input("Indices") + .AddIntArg("T", DataTypeToEnum::v()) .AddIntArg("axis", axis) .Output("Output") .Finalize(net.NewOperatorDef()); // Run net.RunOp(CPU); - auto expected = net.CreateTensor(output_shape, output); + auto expected = net.CreateTensor(output_shape, output); - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); } } // namespace TEST_F(GatherOpTest, CPUScalarIndex) { - TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, + {}, {5}, 0, {2}, {10, 11}); + TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, {}, {5}, 0, {2}, {10, 11}); } TEST_F(GatherOpTest, CPURank1Index) { - TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, + {3}, {2, 4, 6}, 0, {3, 2}, {4, 5, 8, 9, 12, 13}); + TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, {3}, {2, 4, 6}, 0, {3, 2}, {4, 5, 8, 9, 12, 13}); } TEST_F(GatherOpTest, CPURank1IndexWithAxis1) { - TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, + {1}, {1}, 1, {10, 1}, {1, 3, 5, 7, 9, 11, 13, 15, 17, 19}); + TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, {1}, {1}, 1, {10, 1}, {1, 3, 5, 7, 9, 11, 13, 15, 17, 19}); } TEST_F(GatherOpTest, CPURankHighIndex) { - TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, + {1, 3}, {2, 4, 6}, 0, {1, 3, 2}, {4, 5, 8, 9, 12, 13}); + TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, {1, 3}, {2, 4, 6}, 0, {1, 3, 2}, {4, 5, 8, 9, 12, 13}); } diff --git a/mace/ops/matmul.cc b/mace/ops/matmul.cc index 07a65d79459166281a6f13cb4d58817a69d0f3ac..907aec7a4e7de79ee24cd01c67d1e1115565a5d8 100644 --- a/mace/ops/matmul.cc +++ b/mace/ops/matmul.cc @@ -233,8 +233,8 @@ class MatMulFixpointImpl { const index_t height, const index_t K, const index_t width, - const bool lhs_bached, - const bool rhs_bached, + const bool lhs_batched, + const bool rhs_batched, Tensor *C) { #if defined(MACE_ENABLE_NEON) if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) { @@ -245,8 +245,8 @@ class MatMulFixpointImpl { batch, height, K, - true, - true, + lhs_batched, + rhs_batched, C); } else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) { gemv_kernel_.Compute(context, @@ -256,8 +256,8 @@ class MatMulFixpointImpl { batch, width, K, - true, - true, + lhs_batched, + rhs_batched, C); } else { #endif // MACE_ENABLE_NEON @@ -281,11 +281,13 @@ class MatMulFixpointImpl { for (index_t i = 0; i < batch; ++i) { gemmlowp::MatrixMap - a_matrix(a_ptr_base + static_cast(lhs_bached) * i * a_size, + a_matrix(a_ptr_base + + static_cast(lhs_batched) * i * a_size, height, K); gemmlowp::MatrixMap - b_matrix(b_ptr_base + static_cast(rhs_bached) * i * b_size, + b_matrix(b_ptr_base + + static_cast(rhs_batched) * i * b_size, K, width); gemmlowp::MatrixMap @@ -315,8 +317,8 @@ class MatMulFixpointImpl { const index_t height, const index_t K, const index_t width, - const bool lhs_bached, - const bool rhs_bached, + const bool lhs_batched, + const bool rhs_batched, Tensor *C) { C->SetScale(A->scale() * B->scale()); C->SetZeroPoint(0); @@ -330,8 +332,8 @@ class MatMulFixpointImpl { batch, height, K, - lhs_bached, - rhs_bached, + lhs_batched, + rhs_batched, C); } else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) { gemv_kernel_.Compute(context, @@ -341,8 +343,8 @@ class MatMulFixpointImpl { batch, width, K, - lhs_bached, - rhs_bached, + lhs_batched, + rhs_batched, C); } else { #endif // MACE_ENABLE_NEON @@ -366,12 +368,12 @@ class MatMulFixpointImpl { for (index_t i = 0; i < batch; ++i) { gemmlowp::MatrixMap a_matrix - (a_ptr_base + static_cast(lhs_bached) * i * a_size, + (a_ptr_base + static_cast(lhs_batched) * i * a_size, height, K); gemmlowp::MatrixMap b_matrix - (b_ptr_base + static_cast(rhs_bached) * i * b_size, + (b_ptr_base + static_cast(rhs_batched) * i * b_size, K, width); gemmlowp::MatrixMap diff --git a/mace/ops/matmul_test.cc b/mace/ops/matmul_test.cc index 741393ffea45d435b52156f38b3a3ddc4d0e5b84..f88ac39435e328ad2a4ada6b3c41a73558fdb791 100644 --- a/mace/ops/matmul_test.cc +++ b/mace/ops/matmul_test.cc @@ -135,7 +135,8 @@ void Complex(const std::vector &batch, rhs_batched, &expected_output_tensor); - ExpectTensorNear(expected_output_tensor, *net.GetTensor("Output")); + ExpectTensorNear(expected_output_tensor, *net.GetTensor("Output"), + 1e-4, 1e-2); } } // namespace diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 14610ff2b9b9bbe89de4977722cd2668e90ae147..03b1e7c3eff6f99dab1e1a36c1dd25edd0ba75a5 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -236,6 +236,7 @@ class MaceKeyword(object): mace_step_h_str = 'step_h' mace_step_w_str = 'step_w' mace_find_range_every_time = 'find_range_every_time' + mace_non_zero = 'non_zero' mace_pad_type_str = 'pad_type' @@ -279,6 +280,7 @@ class TransformerRule(Enum): FOLD_FC_RESHAPE = 37 TRANSFORM_CHANNEL_SHUFFLE = 38 UPDATE_DATA_FORMAT = 39 + QUANTIZE_MATMUL_ONLY = 40 class ConverterInterface(object): diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index d5be0ee87562f5074fc0cc571b7cfad901edeae1..eddb8d8685972a8dbc05070f444653288446657a 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import os import math import numpy as np import six @@ -288,12 +288,11 @@ class TensorflowConverter(base_converter.ConverterInterface): tf_graph_def.ParseFromString(f.read()) self._placeholders = {} - self.add_shape_info(tf_graph_def) print("Run transform_graph: %s" % TFTransformGraphOptions[ option.device]) try: - print ("output keys: ", option.output_nodes.keys()) + print("output keys: ", option.output_nodes.keys()) transformed_graph_def = TransformGraph(tf_graph_def, option.input_nodes.keys(), option.output_nodes.keys(), @@ -303,6 +302,16 @@ class TensorflowConverter(base_converter.ConverterInterface): print("Failed to transform graph using tf tool: %s" % ex) transformed_graph_def = tf_graph_def + # To check optimized model, uncomment following code. + # tf.io.write_graph( + # transformed_graph_def, + # ".", + # os.path.basename(src_model_file)[:-3] + "_opt.pb", + # as_text=False + # ) + + self.add_shape_info(transformed_graph_def) + with tf.Session() as session: with session.graph.as_default() as graph: tf.import_graph_def(transformed_graph_def, name='') diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 2147c6e1cd28c220b01f519b7b38ff21e904fbcd..9138b260beebab17cc12c158727ee33b4276c363 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -103,6 +103,8 @@ class Transformer(base_converter.ConverterInterface): self.transform_caffe_reshape_and_flatten, TransformerRule.TRANSFORM_CHANNEL_SHUFFLE: self.transform_channel_shuffle, + TransformerRule.QUANTIZE_MATMUL_ONLY: + self.quantize_matmul_only, } self._option = option @@ -191,16 +193,23 @@ class Transformer(base_converter.ConverterInterface): op = mace_pb2.OperatorDef() op.name = self.normalize_op_name(input_node.name) op.type = "Input" + data_type_arg = op.arg.add() + data_type_arg.name = MaceKeyword.mace_op_data_type_str + data_type_arg.i = mace_pb2.DT_FLOAT op.output.extend([input_node.name]) output_shape = op.output_shape.add() output_shape.dims.extend(input_node.shape) - if ConverterUtil.data_format( - self._consumers[input_node.name][0]) \ - == DataFormat.NCHW: - self.transpose_shape(output_shape.dims, [0, 3, 1, 2]) - ConverterUtil.add_data_format_arg(op, DataFormat.NCHW) - else: - ConverterUtil.add_data_format_arg(op, DataFormat.NHWC) + if input_node in self._consumers: + if ConverterUtil.data_format( + self._consumers[input_node.name][0]) \ + == DataFormat.NCHW: + self.transpose_shape(output_shape.dims, + [0, 3, 1, 2]) + ConverterUtil.add_data_format_arg(op, + DataFormat.NCHW) + else: + ConverterUtil.add_data_format_arg(op, + DataFormat.NHWC) self._producer[op.output[0]] = op @staticmethod @@ -221,10 +230,32 @@ class Transformer(base_converter.ConverterInterface): return name.replace(':', '_') def get_tensor_shape(self, tensor): - producer = self._producer[tensor] - for i in six.moves.range(len(producer.output)): - if producer.output[i] == tensor: - return list(producer.output_shape[i].dims) + if tensor in self._consts: + return list(self._consts[tensor].dims) + elif tensor in self._producer: + producer = self._producer[tensor] + for i in six.moves.range(len(producer.output)): + if producer.output[i] == tensor: + return list(producer.output_shape[i].dims) + else: + return None + + def get_tensor_data_type(self, tensor): + if tensor in self._consts: + return self._consts[tensor].data_type + elif tensor in self._producer: + producer = self._producer[tensor] + for i in six.moves.range(len(producer.output)): + if producer.output[i] == tensor: + if i < len(producer.output_type): + return producer.output_type[i] + elif ConverterUtil.get_arg(producer, "T") is not None: + return ConverterUtil.get_arg(producer, "T").i + else: + print("No data type filled: ", producer) + return None + else: + return None def consumer_count(self, tensor_name): return len(self._consumers.get(tensor_name, [])) @@ -1374,6 +1405,7 @@ class Transformer(base_converter.ConverterInterface): return False def update_data_format(self): + print("update data format") data_format_flag = DataFormat.NHWC.value for input_node in self._option.input_nodes.values(): if input_node.data_format.value == DataFormat.DF_NONE.value: @@ -1672,7 +1704,8 @@ class Transformer(base_converter.ConverterInterface): quantize_util.adjust_range(input_node.range[0], input_node.range[1], non_zero=False) - quantize_info = mace_pb2.QuantizeActivationInfo() + quantize_info = \ + mace_pb2.QuantizeActivationInfo() quantize_info.minval = minval quantize_info.maxval = maxval quantize_info.scale = scale @@ -1893,3 +1926,111 @@ class Transformer(base_converter.ConverterInterface): producer_op.output_shape[0].dims[:] = output_shape return True + + def quantize_matmul_only(self): + """ + This transform rule is only used internally, we are not gonna make + things too complex for users + """ + to_quantize_ops = [MaceOp.MatMul.name] + for op in self._model.op: + if (op.type not in to_quantize_ops or len(op.output) > 1 + or ConverterUtil.get_arg(op, + MaceKeyword.mace_op_data_type_str).i != mace_pb2.DT_FLOAT): # noqa + # only support single output + continue + + quantized_inputs_names = [] + should_quantize = True + for idx, input_tensor in enumerate(op.input): + if self.get_tensor_data_type(input_tensor) \ + != mace_pb2.DT_FLOAT: + should_quantize = False + break + + if not should_quantize: + continue + + non_zero = self._option.device == DeviceType.CPU.value + + for idx, input_tensor in enumerate(op.input): + quantized_inputs_names.append(input_tensor) + + if input_tensor in self._consts: + const_tensor = self._consts[input_tensor] + quantized_tensor = quantize_util.quantize( + const_tensor.float_data, non_zero) + del const_tensor.float_data[:] + const_tensor.int32_data.extend(quantized_tensor.data) + const_tensor.data_type = mace_pb2.DT_UINT8 + const_tensor.scale = quantized_tensor.scale + const_tensor.zero_point = quantized_tensor.zero + const_tensor.minval = quantized_tensor.minval + const_tensor.maxval = quantized_tensor.maxval + const_tensor.quantized = True + else: + input_shape = self.get_tensor_shape(input_tensor) + quantize_op = self._model.op.add() + quantize_op.name = self.normalize_op_name( + input_tensor) + "_quant" + quantize_op.type = MaceOp.Quantize.name + quantize_op.input.extend([input_tensor]) + quantize_output_name = quantize_op.name + '_0' + quantize_op.output.extend([quantize_output_name]) + output_shape = quantize_op.output_shape.add() + output_shape.dims.extend(input_shape) + quantize_op.output_type.extend([mace_pb2.DT_UINT8]) + data_type_arg = quantize_op.arg.add() + data_type_arg.name = MaceKeyword.mace_op_data_type_str + data_type_arg.i = mace_pb2.DT_UINT8 + + data_type_arg = quantize_op.arg.add() + data_type_arg.name = MaceKeyword.mace_non_zero + if non_zero: + data_type_arg.i = 1 + else: + data_type_arg.i = 0 + + find_range_arg = quantize_op.arg.add() + find_range_arg.name = \ + MaceKeyword.mace_find_range_every_time + find_range_arg.i = 1 + + quantized_inputs_names[-1] = quantize_output_name + + non_zero = False + + del op.input[:] + op.input.extend(quantized_inputs_names) + + orginal_output_name = op.output[0] + op.output[0] = orginal_output_name + "_quant" + op.output_type.extend([mace_pb2.DT_INT32]) + data_type_arg = ConverterUtil.get_arg(op, + MaceKeyword.mace_op_data_type_str) # noqa + if data_type_arg is None: + data_type_arg = op.arg.add() + data_type_arg.name = MaceKeyword.mace_op_data_type_str + data_type_arg.i = mace_pb2.DT_UINT8 + + dequantize_op = self._model.op.add() + dequantize_op.name = op.name + "_dequant" + dequantize_op.type = MaceOp.Dequantize.name + dequantize_op.input.extend([op.output[0]]) + dequantize_op.output.extend([orginal_output_name]) + dequantize_op.output_shape.extend(op.output_shape) + dequantize_op.output_type.extend([mace_pb2.DT_FLOAT]) + data_type_arg = dequantize_op.arg.add() + data_type_arg.name = MaceKeyword.mace_op_data_type_str + data_type_arg.i = mace_pb2.DT_INT32 + + quantize_flag_arg = ConverterUtil.get_arg(self._model, + MaceKeyword.mace_quantize_flag_arg_str) # noqa + if quantize_flag_arg is None: + quantize_flag_arg = self._model.arg.add() + quantize_flag_arg.name = MaceKeyword.mace_quantize_flag_arg_str + quantize_flag_arg.i = 1 + + return True + + return False