diff --git a/mace/kernels/concat.h b/mace/kernels/concat.h
index 14bf38cdce8d0627c2441649e93611bf9ae096b1..0cb28861648b2221c41e6225afe0743ab36e5d9b 100644
--- a/mace/kernels/concat.h
+++ b/mace/kernels/concat.h
@@ -23,6 +23,7 @@
 #include "mace/core/types.h"
 #include "mace/kernels/kernel.h"
 #include "mace/public/mace.h"
+#include "mace/utils/quantize.h"
 
 namespace mace {
 namespace kernels {
@@ -48,9 +49,6 @@ struct ConcatFunctor : OpKernel {
     outer_sizes[0] = input0->size() / inner_size;
     for (size_t i = 1; i < inputs_count; ++i) {
       const Tensor *input = input_list[i];
-      MACE_CHECK(input->scale() == output->scale()
-                     && input->zero_point() == output->zero_point(),
-                 "Inputs and output must have the same scale and zero_point.");
       MACE_CHECK(input->dim_size() == input0->dim_size(),
                  "Ranks of all input tensors must be same.");
       for (int j = 0; j < input->dim_size(); ++j) {
@@ -91,6 +89,76 @@ struct ConcatFunctor : OpKernel {
   int32_t axis_;
 };
 
+template<>
+struct ConcatFunctor<DeviceType::CPU, uint8_t> : OpKernel {
+  ConcatFunctor(OpKernelContext *context, const int32_t axis)
+      : OpKernel(context), axis_(axis) {}
+
+  MaceStatus operator()(const std::vector<const Tensor *> &input_list,
+                        Tensor *output,
+                        StatsFuture *future) {
+    MACE_UNUSED(future);
+    MACE_CHECK(output->scale() != 0);
+    const Tensor *input0 = input_list.front();
+    const size_t inputs_count = input_list.size();
+
+    std::vector<index_t> output_shape(input0->shape());
+    index_t inner_size = 1;
+    for (int i = 0; i < axis_; ++i) {
+      inner_size *= output_shape[i];
+    }
+    std::vector<index_t> outer_sizes(inputs_count, 0);
+    outer_sizes[0] = input0->size() / inner_size;
+    for (size_t i = 1; i < inputs_count; ++i) {
+      const Tensor *input = input_list[i];
+      MACE_CHECK(input->dim_size() == input0->dim_size(),
+                 "Ranks of all input tensors must be same.");
+      for (int j = 0; j < input->dim_size(); ++j) {
+        if (j == axis_) {
+          continue;
+        }
+        MACE_CHECK(input->dim(j) == input0->dim(j),
+                   "Dimensions of inputs should equal except axis.");
+      }
+      outer_sizes[i] = input->size() / inner_size;
+      output_shape[axis_] += input->dim(axis_);
+    }
+    MACE_RETURN_IF_ERROR(output->Resize(output_shape));
+
+    auto output_ptr = output->mutable_data<uint8_t>();
+
+    std::vector<const uint8_t *> input_ptrs(input_list.size(), nullptr);
+    for (size_t i = 0; i < inputs_count; ++i) {
+      input_ptrs[i] = input_list[i]->data<uint8_t>();
+    }
+
+    for (int inner_idx = 0; inner_idx < inner_size; ++inner_idx) {
+      for (size_t i = 0; i < inputs_count; ++i) {
+        if (input_list[i]->zero_point() == output->zero_point()
+            && input_list[i]->scale() == output->scale()) {
+          memcpy(output_ptr, input_ptrs[i], outer_sizes[i] * sizeof(uint8_t));
+          output_ptr += outer_sizes[i];
+          input_ptrs[i] += outer_sizes[i];
+        } else {
+          const float scale = input_list[i]->scale() / output->scale();
+          const float offset =
+              -input_list[i]->zero_point() * scale + output->zero_point();
+          for (index_t k = 0; k < outer_sizes[i]; ++k) {
+            float out = (*input_ptrs[i]) * scale + offset;
+            *output_ptr = Saturate<uint8_t>(roundf(out));
+            ++output_ptr;
+            ++input_ptrs[i];
+          }
+        }
+      }
+    }
+
+    return MACE_SUCCESS;
+  }
+
+  int32_t axis_;
+};
+
 #ifdef MACE_ENABLE_OPENCL
 class OpenCLConcatKernel {
  public:
diff --git a/mace/kernels/pooling.h b/mace/kernels/pooling.h
index 590ea11758c8f8cbc1591c348784a01afb342db7..000fa269bcf7c97f3a3ac6f38abac25ce3330015 100644
--- a/mace/kernels/pooling.h
+++ b/mace/kernels/pooling.h
@@ -277,6 +277,7 @@ struct PoolingFunctor<DeviceType::CPU, uint8_t>: PoolingFunctorBase {
             uint8_t *out_ptr = output +
                 ((b * out_height + h) * out_width + w) * channels;
+            std::fill_n(out_ptr, channels, 0);
             for (index_t ih = in_h_begin; ih < in_h_end; ++ih) {
               for (index_t iw = in_w_begin; iw < in_w_end; ++iw) {
                 const uint8_t *in_ptr = input +
diff --git a/mace/ops/concat_benchmark.cc b/mace/ops/concat_benchmark.cc
index faf784c55d1aecc188572b96516d57596a7d8eba..5375cb6d487c6d8211aaca90bd8fc3cd23de56ec 100644
--- a/mace/ops/concat_benchmark.cc
+++ b/mace/ops/concat_benchmark.cc
@@ -22,7 +22,7 @@ namespace test {
 
 namespace {
 template <DeviceType D, typename T>
-void ConcatHelper(int iters, int concat_dim, int dim1) {
+void ConcatHelper(int iters, int concat_dim, int dim0, int dim1) {
   mace::testing::StopTiming();
 
   OpsTestNet net;
@@ -31,37 +31,49 @@ void ConcatHelper(int iters, int concat_dim, int dim1) {
       .Input("Input1")
       .AddIntArg("axis", concat_dim)
       .Output("Output")
+      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
       .Finalize(net.NewOperatorDef());
 
   // Add input data
-  const int kDim0 = 100;
-  net.AddRandomInput<D, T>("Input0", {kDim0, dim1});
-  net.AddRandomInput<D, T>("Input1", {kDim0, dim1});
+  net.AddRandomInput<D, T>("Input0", {dim0, dim1});
+  net.AddRandomInput<D, T>("Input1", {dim0, dim1});
+
+  net.Setup(D);
+  if (DataTypeToEnum<T>::value == DT_UINT8) {
+    net.GetTensor("Input0")->SetScale(0.1);
+    net.GetTensor("Input1")->SetScale(0.2);
+    net.GetTensor("Output")->SetScale(0.3);
+  }
 
   // Warm-up
-  for (int i = 0; i < 5; ++i) {
-    net.RunOp(D);
+  for (int i = 0; i < 2; ++i) {
+    net.Run();
   }
 
-  const int64_t tot = static_cast<int64_t>(iters) * kDim0 * dim1 * 2;
+  const int64_t tot = static_cast<int64_t>(iters) * dim0 * dim1 * 2;
   mace::testing::MaccProcessed(tot);
   testing::BytesProcessed(tot * sizeof(T));
   mace::testing::StartTiming();
   while (iters--) {
-    net.RunOp(D);
+    net.Run();
   }
 }
 }  // namespace
 
-#define MACE_BM_CONCAT_CPU_MACRO(DIM0, DIM1)                  \
-  static void MACE_BM_CONCAT_CPU_##DIM0##_##DIM1(int iters) { \
-    ConcatHelper<DeviceType::CPU, float>(iters, DIM0, DIM1);  \
-  }                                                           \
-  MACE_BENCHMARK(MACE_BM_CONCAT_CPU_##DIM0##_##DIM1)
+#define MACE_BM_CONCAT_CPU_MACRO(AXIS, DIM0, DIM1, TYPE)                       \
+  static void MACE_BM_CONCAT_CPU_##AXIS##_##DIM0##_##DIM1##_##TYPE(int iters) {\
+    ConcatHelper<DeviceType::CPU, TYPE>(iters, AXIS, DIM0, DIM1);              \
+  }                                                                            \
+  MACE_BENCHMARK(MACE_BM_CONCAT_CPU_##AXIS##_##DIM0##_##DIM1##_##TYPE)
+
+#define MACE_BM_CONCAT_CPU(AXIS, DIM0, DIM1)           \
+  MACE_BM_CONCAT_CPU_MACRO(AXIS, DIM0, DIM1, float);   \
+  MACE_BM_CONCAT_CPU_MACRO(AXIS, DIM0, DIM1, uint8_t); \
 
-MACE_BM_CONCAT_CPU_MACRO(0, 1000);
-MACE_BM_CONCAT_CPU_MACRO(0, 100000);
-MACE_BM_CONCAT_CPU_MACRO(1, 1000);
-MACE_BM_CONCAT_CPU_MACRO(1, 100000);
+MACE_BM_CONCAT_CPU(0, 100, 1000);
+MACE_BM_CONCAT_CPU(0, 100, 100000);
+MACE_BM_CONCAT_CPU(1, 100, 1000);
+MACE_BM_CONCAT_CPU(1, 100, 100000);
+MACE_BM_CONCAT_CPU(1, 1225, 128);
 
 namespace {
 template <DeviceType D, typename T>
diff --git a/mace/ops/concat_test.cc b/mace/ops/concat_test.cc
index f8b6b42a7824d3ee2824ca60fffd585e0daf864c..671b8f617e700ba8e25805e1c802f04c692cf657 100644
--- a/mace/ops/concat_test.cc
+++ b/mace/ops/concat_test.cc
@@ -154,6 +154,95 @@ TEST_F(ConcatOpTest, CPURandom) {
   }
 }
 
+TEST_F(ConcatOpTest, QuantizedCPURandom) {
+  static unsigned int seed = time(NULL);
+  int dim = 4;
+  int num_inputs = 2 + rand_r(&seed) % 10;
+  int axis = rand_r(&seed) % dim;
+  // Construct graph
+  OpsTestNet net;
+
+  std::vector<index_t> shape_data;
+  GenerateRandomIntTypeData<index_t>({dim}, &shape_data, 1, 50);
+  std::vector<std::vector<index_t>> input_shapes(num_inputs, shape_data);
+  std::vector<std::vector<float>> inputs(num_inputs, std::vector<float>());
+  std::vector<const float *> input_ptrs(num_inputs, nullptr);
+  index_t concat_axis_size = 0;
+  for (int i = 0; i < num_inputs; ++i) {
+    input_shapes[i][axis] = 1 + rand_r(&seed) % dim;
+    concat_axis_size += input_shapes[i][axis];
+    GenerateRandomRealTypeData(input_shapes[i], &inputs[i]);
+    input_ptrs[i] = inputs[i].data();
+    net.AddInputFromArray<DeviceType::CPU, float>(MakeString("Input", i),
+                                                  input_shapes[i], inputs[i]);
+  }
+  std::vector<index_t> output_shape = input_shapes[0];
+  output_shape[axis] = concat_axis_size;
+  net.AddRandomInput<DeviceType::CPU, float>(
+      "Output", output_shape, true, true);
+
+  auto builder = OpDefBuilder("Concat", "ConcatTest");
+  for (int i = 0; i < num_inputs; ++i) {
+    builder = builder.Input(MakeString("Input", i));
+  }
+  builder.AddIntArg("axis", axis)
+      .Output("Output")
+      .Finalize(net.NewOperatorDef());
+
+  // Run
+  net.RunOp();
+
+  for (int i = 0; i < num_inputs; ++i) {
+    OpDefBuilder("Quantize", MakeString("QuantizeInput", i))
+        .Input(MakeString("Input", i))
+        .Output(MakeString("QuantizedInput", i))
+        .OutputType({DT_UINT8})
+        .AddIntArg("T", DT_UINT8)
+        .AddIntArg("non_zero", true)
+        .Finalize(net.NewOperatorDef());
+    net.RunOp();
+  }
+
+  OpDefBuilder("Quantize", "QuantizeOutput")
+      .Input("Output")
+      .Output("ExpectedQuantizedOutput")
+      .OutputType({DT_UINT8})
+      .AddIntArg("T", DT_UINT8)
+      .AddIntArg("non_zero", true)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  net.AddRandomInput<DeviceType::CPU, uint8_t>(
+      "QuantizedOutput", output_shape, true, true);
+  auto q_builder = OpDefBuilder("Concat", "QuantizedConcatTest");
+  for (int i = 0; i < num_inputs; ++i) {
+    q_builder = q_builder.Input(MakeString("QuantizedInput", i));
+  }
+  q_builder.AddIntArg("axis", axis)
+      .Output("QuantizedOutput")
+      .AddIntArg("T", static_cast<int>(DT_UINT8))
+      .Finalize(net.NewOperatorDef());
+
+  net.Setup(DeviceType::CPU);
+  Tensor *eq_output = net.GetTensor("ExpectedQuantizedOutput");
+  Tensor *q_output = net.GetTensor("QuantizedOutput");
+  q_output->SetScale(eq_output->scale());
+  q_output->SetZeroPoint(eq_output->zero_point());
+  net.Run();
+
+  OpDefBuilder("Dequantize", "DeQuantizeTest")
+      .Input("QuantizedOutput")
+      .Output("DequantizedOutput")
+      .OutputType({DT_FLOAT})
+      .AddIntArg("T", DT_UINT8)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  // Check
+  ExpectTensorSimilar<float>(*net.GetOutput("Output"),
+                             *net.GetTensor("DequantizedOutput"), 0.01);
+}
+
 namespace {
 template <typename T>
 void OpenclRandomTest(const std::vector<std::vector<index_t>> &shapes,
diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h
index 296fa3b9f2811a6d9987ec141b8130a4ba2cb151..a3b8c4d9168168eed9d7c44edaf22b4081ab1fd6 100644
--- a/mace/ops/ops_test_util.h
+++ b/mace/ops/ops_test_util.h
@@ -34,6 +34,7 @@
 #include "mace/kernels/opencl/common.h"
 #include "mace/ops/ops_register.h"
 #include "mace/utils/utils.h"
+#include "mace/utils/quantize.h"
 
 namespace mace {
 namespace ops {
@@ -200,6 +201,11 @@ class OpsTestNet {
         }
         return half_float::half_cast<half>(positive ? std::abs(d) : d);
       });
+    } else if (DataTypeToEnum<T>::value == DT_UINT8) {
+      std::generate(input_data, input_data + input->size(),
+                    [&gen, &nd] {
+                      return Saturate<uint8_t>(roundf((nd(gen) + 1) * 128));
+                    });
     } else {
       std::generate(input_data, input_data + input->size(),
                     [&gen, &nd, positive, truncate,
diff --git a/mace/ops/pooling_test.cc b/mace/ops/pooling_test.cc
index 2f02d729ed45aa6a160af5d42d09bcc915650481..c22e9b133b3c57a023be8cedfc318e7404dd6155 100644
--- a/mace/ops/pooling_test.cc
+++ b/mace/ops/pooling_test.cc
@@ -578,11 +578,14 @@ void TestQuant(const index_t batch,
                enum Padding padding_type,
                PoolingType pooling) {
   OpsTestNet net;
+  std::vector<index_t> input_shape{batch, in_height, in_width, channels};
   net.AddRandomInput<DeviceType::CPU, float>(
-      "Input", {batch, in_height, in_width, channels}, false);
+      "Input", input_shape, false);
   net.TransformDataFormat<DeviceType::CPU, float>(
       "Input", NHWC, "InputNCHW", NCHW);
+  net.AddRandomInput<DeviceType::CPU, float>(
+      "OutputNCHW", input_shape, true, true);
   OpDefBuilder("Pooling", "PoolingTest")
       .Input("InputNCHW")
       .Output("OutputNCHW")
@@ -607,6 +610,7 @@ void TestQuant(const index_t batch,
       .Finalize(net.NewOperatorDef());
   net.RunOp();
 
+  net.AddRandomInput<DeviceType::CPU, uint8_t>("QuantizedOutput", input_shape);
 OpDefBuilder("Pooling", "PoolingTest")
       .Input("QuantizedInput")
       .Output("QuantizedOutput")
diff --git a/mace/python/tools/converter.py b/mace/python/tools/converter.py
index bcbb743918454a1639c2b9eef125fcfa500bb4a2..111030c8a7288e14e9cda71b88ad0aabba68335a 100644
--- a/mace/python/tools/converter.py
+++ b/mace/python/tools/converter.py
@@ -106,6 +106,7 @@ def main(unused_args):
     option.winograd = FLAGS.winograd
     option.quantize = FLAGS.quantize
     option.quantize_range_file = FLAGS.quantize_range_file
+    option.change_concat_ranges = FLAGS.change_concat_ranges
     option.cl_mem_type = FLAGS.cl_mem_type
 
     input_node_names = FLAGS.input_node.split(',')
@@ -324,6 +325,13 @@ def parse_args():
         type=str,
         default="",
         help="file path of quantize range for each tensor")
+    parser.add_argument(
+        "--change_concat_ranges",
+        type=str2bool,
+        nargs='?',
+        const=False,
+        default=False,
+        help="change ranges to use memcpy for quantized concat")
     parser.add_argument(
         "--cl_mem_type",
         type=str,
diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py
index b7974a245113537cb604fcfcd1c62f864863406f..64cbc3e65301276eb987e93eb03191df9e751542 100644
--- a/mace/python/tools/converter_tool/base_converter.py
+++ b/mace/python/tools/converter_tool/base_converter.py
@@ -266,6 +266,7 @@ class ConverterOption(object):
         self._winograd = 0
         self._quantize = False
         self._quantize_range_file = ""
+        self._change_concat_ranges = False
         self._transformer_option = None
         self._cl_mem_type = ""
 
@@ -293,6 +294,10 @@ class ConverterOption(object):
     def quantize(self):
         return self._quantize
 
+    @property
+    def change_concat_ranges(self):
+        return self._change_concat_ranges
+
     @property
     def quantize_range_file(self):
         return self._quantize_range_file
@@ -341,6 +346,10 @@ class ConverterOption(object):
     def quantize_range_file(self, quantize_range_file):
         self._quantize_range_file = quantize_range_file
 
+    @change_concat_ranges.setter
+    def change_concat_ranges(self, change_concat_ranges):
+        self._change_concat_ranges = change_concat_ranges
+
     @transformer_option.setter
     def transformer_option(self, transformer_option):
         self._transformer_option = transformer_option
diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py
index 982ab5083c474ac0e6358841cdc49be4f3a4a701..e159a9446d67f6613dceb5e092b25e1e1e297e6a 100644
--- a/mace/python/tools/converter_tool/transformer.py
+++ b/mace/python/tools/converter_tool/transformer.py
@@ -1852,7 +1852,9 @@ class Transformer(base_converter.ConverterInterface):
                 self.copy_quantize_info(op, producer_op.quantize_info[0])
                 self._quantize_activation_info[op.output[0]] = \
                     op.quantize_info[0]
-            elif op.type == MaceOp.Concat.name:
+            elif (op.type == MaceOp.Concat.name
+                  and (not op.quantize_info
+                       or self._option.change_concat_ranges)):
                 if op.quantize_info:
                     maxval = op.quantize_info[0].maxval
                     minval = op.quantize_info[0].minval
@@ -1866,12 +1868,13 @@ class Transformer(base_converter.ConverterInterface):
                 quantize_info = \
                     self.add_quantize_info(op, minval, maxval)
                 self._quantize_activation_info[op.output[0]] = quantize_info
-                for i in range(len(op.input)):
-                    producer_op = self._producer[op.input[i]]
-                    del producer_op.quantize_info[:]
-                    self.copy_quantize_info(producer_op, quantize_info)
-                    self._quantize_activation_info[producer_op.output[0]] = \
-                        producer_op.quantize_info[0]
+                if self._option.change_concat_ranges:
+                    for i in range(len(op.input)):
+                        producer_op = self._producer[op.input[i]]
+                        del producer_op.quantize_info[:]
+                        self.copy_quantize_info(producer_op, quantize_info)
+                        self._quantize_activation_info[producer_op.output[0]] \
+                            = producer_op.quantize_info[0]
             elif op.type == MaceOp.Softmax.name:
                 del op.quantize_info[:]
                 quantize_info = \
diff --git a/tools/converter.py b/tools/converter.py
index 0e4e789b3622fc9a9a0758cdf5304fe80aae124c..203c3dc1f2ea6332e1055d278291f257c746baef 100644
--- a/tools/converter.py
+++ b/tools/converter.py
@@ -200,6 +200,7 @@ class YAMLKeyword(object):
     winograd = 'winograd'
     quantize = 'quantize'
     quantize_range_file = 'quantize_range_file'
+    change_concat_ranges = 'change_concat_ranges'
     validation_inputs_data = 'validation_inputs_data'
     validation_threshold = 'validation_threshold'
     graph_optimize_options = 'graph_optimize_options'  # internal use for now
@@ -518,7 +519,8 @@ def format_model_config(flags):
                     YAMLKeyword.nnlib_graph_mode,
                     YAMLKeyword.obfuscate,
                     YAMLKeyword.winograd,
-                    YAMLKeyword.quantize]:
+                    YAMLKeyword.quantize,
+                    YAMLKeyword.change_concat_ranges]:
            value = model_config.get(key, "")
            if value == "":
                model_config[key] = 0
@@ -771,6 +773,7 @@ def convert_model(configs, cl_mem_type):
             model_config[YAMLKeyword.winograd],
             model_config[YAMLKeyword.quantize],
             model_config.get(YAMLKeyword.quantize_range_file, ""),
+            model_config[YAMLKeyword.change_concat_ranges],
             model_config[YAMLKeyword.obfuscate],
             configs[YAMLKeyword.model_graph_format],
             data_type,
diff --git a/tools/sh_commands.py b/tools/sh_commands.py
index a5e0dfa7da9e4163cf87c803001cd76563addb64..304da31b7d83d6be2826e515111395b1c81b923e 100644
--- a/tools/sh_commands.py
+++ b/tools/sh_commands.py
@@ -557,6 +557,7 @@ def gen_model_code(model_codegen_dir,
                    winograd,
                    quantize,
                    quantize_range_file,
+                   change_concat_ranges,
                    obfuscate,
                    model_graph_format,
                    data_type,
@@ -587,6 +588,7 @@ def gen_model_code(model_codegen_dir,
              "--winograd=%s" % winograd,
              "--quantize=%s" % quantize,
              "--quantize_range_file=%s" % quantize_range_file,
+             "--change_concat_ranges=%s" % change_concat_ranges,
              "--obfuscate=%s" % obfuscate,
              "--output_dir=%s" % model_codegen_dir,
              "--model_graph_format=%s" % model_graph_format,