diff --git a/mace/kernels/matmul.cc b/mace/kernels/matmul.cc
index 4723e6557cd20d41e7d60d61666a26109d515bb2..8ef93a29aa7ce2b2320d9ec1f776519c9df5f49c 100644
--- a/mace/kernels/matmul.cc
+++ b/mace/kernels/matmul.cc
@@ -68,7 +68,10 @@ class MatMulOpBase : public Operation {
 };
 
 template <DeviceType D, class T>
-class MatMulOp : public MatMulOpBase {
+class MatMulOp;
+
+template <>
+class MatMulOp<DeviceType::CPU, float> : public MatMulOpBase {
  public:
   explicit MatMulOp(OpConstructContext *context) : MatMulOpBase(context) {}
 
@@ -107,9 +110,9 @@ class MatMulOp : public MatMulOpBase {
     Tensor::MappingGuard guarda(A);
     Tensor::MappingGuard guardb(B);
     Tensor::MappingGuard guardc(C);
-    const T *a_ptr_base = A->data<T>();
-    const T *b_ptr_base = B->data<T>();
-    T *c_ptr_base = C->mutable_data<T>();
+    const float *a_ptr_base = A->data<float>();
+    const float *b_ptr_base = B->data<float>();
+    float *c_ptr_base = C->mutable_data<float>();
 
     const index_t height_a = A->dim(rank - 2);
     const index_t width_a = A->dim(rank - 1);
@@ -147,6 +150,100 @@ class MatMulOp : public MatMulOpBase {
   SGemm sgemm_;
 };
 
+template <gemmlowp::MapOrder AOrder, gemmlowp::MapOrder BOrder,
+          typename OutputType>
+class MatMulFixpointImpl;
+
+template <gemmlowp::MapOrder AOrder, gemmlowp::MapOrder BOrder>
+class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
+ public:
+  void operator()(OpContext *context,
+                  const Tensor *A,
+                  const Tensor *B,
+                  const index_t height,
+                  const index_t K,
+                  const index_t width,
+                  Tensor *C) {
+    auto gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext();
+    MACE_CHECK_NOTNULL(gemm_context);
+
+    Tensor::MappingGuard guarda(A);
+    Tensor::MappingGuard guardb(B);
+    Tensor::MappingGuard guardc(C);
+    auto a_ptr_base = A->data<uint8_t>();
+    auto b_ptr_base = B->data<uint8_t>();
+    auto c_ptr_base = C->mutable_data<uint8_t>();
+    index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
+                                    std::multiplies<index_t>());
+    index_t a_size = height * K;
+    index_t b_size = K * width;
+    index_t c_size = height * width;
+
+    const auto &output_pipeline = GemmlowpOutputPipeline::MakeNoBias(
+        A->scale(), B->scale(), C->scale(), C->zero_point());
+
+    for (index_t i = 0; i < batch; ++i) {
+      gemmlowp::MatrixMap<const uint8_t, AOrder>
+          a_matrix(a_ptr_base + i * a_size, height, K);
+      gemmlowp::MatrixMap<const uint8_t, BOrder>
+          b_matrix(b_ptr_base + i * b_size, K, width);
+      gemmlowp::MatrixMap<uint8_t, gemmlowp::MapOrder::RowMajor>
+          c_matrix(c_ptr_base + i * c_size, height, width);
+
+      using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
+      gemmlowp::GemmWithOutputPipeline<uint8_t, uint8_t, BitDepthParams>(
+          gemm_context, a_matrix, b_matrix, &c_matrix, -A->zero_point(),
+          -B->zero_point(), output_pipeline);
+    }
+  }
+};
+
+template <gemmlowp::MapOrder AOrder, gemmlowp::MapOrder BOrder>
+class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
+ public:
+  void operator()(OpContext *context,
+                  const Tensor *A,
+                  const Tensor *B,
+                  const index_t height,
+                  const index_t K,
+                  const index_t width,
+                  Tensor *C) {
+    auto gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext();
+    MACE_CHECK_NOTNULL(gemm_context);
+
+    Tensor::MappingGuard guarda(A);
+    Tensor::MappingGuard guardb(B);
+    Tensor::MappingGuard guardc(C);
+    auto a_ptr_base = A->data<uint8_t>();
+    auto b_ptr_base = B->data<uint8_t>();
+    auto c_ptr_base = C->mutable_data<int32_t>();
+    index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
+                                    std::multiplies<index_t>());
+    index_t a_size = height * K;
+    index_t b_size = K * width;
+    index_t c_size = height * width;
+
+    const auto output_pipeline = std::make_tuple();
+
+    for (index_t i = 0; i < batch; ++i) {
+      gemmlowp::MatrixMap<const uint8_t, AOrder>
+          a_matrix(a_ptr_base + i * a_size, height, K);
+      gemmlowp::MatrixMap<const uint8_t, BOrder>
+          b_matrix(b_ptr_base + i * b_size, K, width);
+      gemmlowp::MatrixMap<int32_t, gemmlowp::MapOrder::RowMajor>
+          c_matrix(c_ptr_base + i * c_size, height, width);
+
+      using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
+      gemmlowp::GemmWithOutputPipeline<uint8_t, int32_t, BitDepthParams>(
+          gemm_context, a_matrix, b_matrix, &c_matrix, -A->zero_point(),
+          -B->zero_point(), output_pipeline);
+    }
+
+    C->SetScale(A->scale() * B->scale());
+    C->SetZeroPoint(0);
+  }
+};
+
 template <>
 class MatMulOp<DeviceType::CPU, uint8_t>: public MatMulOpBase {
  public:
@@ -182,69 +279,37 @@ class MatMulOp<DeviceType::CPU, uint8_t>: public MatMulOpBase {
     constexpr gemmlowp::MapOrder kRowMajor = gemmlowp::MapOrder::RowMajor;
     constexpr gemmlowp::MapOrder kColMajor = gemmlowp::MapOrder::ColMajor;
 
-#define MATMUL_IMPL(AOrder, BOrder) \
-    MatMulImpl<AOrder, BOrder>(context, A, B, height, K, width, C);
+#define MATMUL_FIXPOINT_IMPL(AOrder, BOrder, OutType) \
+    MatMulFixpointImpl<AOrder, BOrder, OutType>()( \
+        context, A, B, height, K, width, C);
+
+#define MATMUL_FIXPOINT_IMPL_TRANSPOSE_OR_NOT(OutType) \
+    if (transpose_a_) { \
+      if (transpose_b_) { \
+        MATMUL_FIXPOINT_IMPL(kColMajor, kColMajor, OutType); \
+      } else { \
+        MATMUL_FIXPOINT_IMPL(kColMajor, kRowMajor, OutType); \
+      } \
+    } else { \
+      if (transpose_b_) { \
+        MATMUL_FIXPOINT_IMPL(kRowMajor, kColMajor, OutType); \
+      } else { \
+        MATMUL_FIXPOINT_IMPL(kRowMajor, kRowMajor, OutType); \
+      } \
+    }
 
-    if (transpose_a_) {
-      if (transpose_b_) {
-        MATMUL_IMPL(kColMajor, kColMajor);
-      } else {
-        MATMUL_IMPL(kColMajor, kRowMajor);
-      }
+    if (!operator_def_->output_type().empty()
+        && operator_def_->output_type()[0] == DT_INT32) {
+      MATMUL_FIXPOINT_IMPL_TRANSPOSE_OR_NOT(int32_t);
     } else {
-      if (transpose_b_) {
-        MATMUL_IMPL(kRowMajor, kColMajor);
-      } else {
-        MATMUL_IMPL(kRowMajor, kRowMajor);
-      }
+      MATMUL_FIXPOINT_IMPL_TRANSPOSE_OR_NOT(uint8_t);
     }
 
-#undef MATMUL_IMPL
+#undef MATMUL_FIXPOINT_IMPL_TRANSPOSE_OR_NOT
+#undef MATMUL_FIXPOINT_IMPL
 
     return MaceStatus::MACE_SUCCESS;
   }
-
- private:
-  template <gemmlowp::MapOrder AOrder, gemmlowp::MapOrder BOrder>
-  void MatMulImpl(OpContext *context,
-                  const Tensor *A,
-                  const Tensor *B,
-                  const index_t height,
-                  const index_t K,
-                  const index_t width,
-                  Tensor *C) {
-    auto gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext();
-    MACE_CHECK_NOTNULL(gemm_context);
-
-    Tensor::MappingGuard guarda(A);
-    Tensor::MappingGuard guardb(B);
-    Tensor::MappingGuard guardc(C);
-    auto a_ptr_base = A->data<uint8_t>();
-    auto b_ptr_base = B->data<uint8_t>();
-    auto c_ptr_base = C->mutable_data<uint8_t>();
-    index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
-                                    std::multiplies<index_t>());
-    index_t a_size = height * K;
-    index_t b_size = K * width;
-    index_t c_size = height * width;
-
-    const auto &output_pipeline = GemmlowpOutputPipeline::MakeNoBias(
-        A->scale(), B->scale(), C->scale(), C->zero_point());
-
-    for (index_t i = 0; i < batch; ++i) {
-      gemmlowp::MatrixMap<const uint8_t, AOrder>
-          a_matrix(a_ptr_base + i * a_size, height, K);
-      gemmlowp::MatrixMap<const uint8_t, BOrder>
-          b_matrix(b_ptr_base + i * b_size, K, width);
-      gemmlowp::MatrixMap<uint8_t, gemmlowp::MapOrder::RowMajor>
-          c_matrix(c_ptr_base + i * c_size, height, width);
-
-      using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
-      gemmlowp::GemmWithOutputPipeline<uint8_t, uint8_t, BitDepthParams>(
-          gemm_context, a_matrix, b_matrix, &c_matrix, -A->zero_point(),
-          -B->zero_point(), output_pipeline);
-    }
-  }
 };
 
 #ifdef MACE_ENABLE_OPENCL
diff --git a/mace/kernels/quantize.cc b/mace/kernels/quantize.cc
index 2f2b8fc263f9f5ddeaf5aeb394f81283dddbc245..2fd9e7c313a54af1121feb28503a8490812d553b 100644
--- a/mace/kernels/quantize.cc
+++ b/mace/kernels/quantize.cc
@@ -72,8 +72,8 @@ class QuantizeOp : public Operation {
 template <DeviceType D, class T>
 class DequantizeOp;
 
-template <>
-class DequantizeOp<DeviceType::CPU, uint8_t> : public Operation {
+template <class T>
+class DequantizeOp<DeviceType::CPU, T> : public Operation {
  public:
   explicit DequantizeOp(OpConstructContext *context) : Operation(context) {}
 
@@ -85,9 +85,9 @@ class DequantizeOp<DeviceType::CPU, uint8_t> : public Operation {
     MACE_RETURN_IF_ERROR(output->ResizeLike(input));
     Tensor::MappingGuard input_guard(input);
     Tensor::MappingGuard output_guard(output);
-    const uint8_t *input_data = input->data<uint8_t>();
+    const T *input_data = input->data<T>();
     float *output_data = output->mutable_data<float>();
-    Dequantize(input_data,
+    Dequantize<T>(input_data,
                input->size(),
               input->scale(),
               input->zero_point(),
@@ -104,6 +104,8 @@ void RegisterQuantize(OpRegistryBase *op_registry) {
 void RegisterDequantize(OpRegistryBase *op_registry) {
   MACE_REGISTER_OP(op_registry, "Dequantize", DequantizeOp,
                    DeviceType::CPU, uint8_t);
+  MACE_REGISTER_OP(op_registry, "Dequantize", DequantizeOp,
+                   DeviceType::CPU, int32_t);
 }
 }  // namespace kernels
 }  // namespace mace
diff --git a/mace/ops/matmul_test.cc b/mace/ops/matmul_test.cc
index 83958c75ef27bc29de4fc21626d25029e696bb28..e31d8616291ebbd8b29d1c6fa748493ea7b05f78 100644
--- a/mace/ops/matmul_test.cc
+++ b/mace/ops/matmul_test.cc
@@ -214,12 +214,12 @@ TEST_F(MatMulOpTest, OPENCLHalfUnAlignedWithBatch) {
 }
 
 namespace {
-void Quant(const std::vector<index_t> &batch,
-           const index_t height,
-           const index_t channels,
-           const index_t out_width,
-           const bool transpose_a,
-           const bool transpose_b) {
+void QuantOutputUint8(const std::vector<index_t> &batch,
+                      const index_t height,
+                      const index_t channels,
+                      const index_t out_width,
+                      const bool transpose_a,
+                      const bool transpose_b) {
   // Construct graph
   OpsTestNet net;
 
@@ -281,6 +281,7 @@ void Quant(const std::vector<index_t> &batch,
       .AddIntArg("transpose_b", transpose_b ? 1 : 0)
       .Output("QuantizedOutput")
       .AddIntArg("T", DT_UINT8)
+      .OutputType({DT_UINT8})
       .Finalize(net.NewOperatorDef());
   net.Setup(DeviceType::CPU);
   Tensor *eq_output = net.GetTensor("ExpectedQuantizedOutput");
@@ -301,26 +302,121 @@ void Quant(const std::vector<index_t> &batch,
   ExpectTensorSimilar<float>(*net.GetOutput("Output"),
                              *net.GetTensor("DequantizedOutput"), 0.01);
 }
+
+void QuantOutputInt32(const std::vector<index_t> &batch,
+                      const index_t height,
+                      const index_t channels,
+                      const index_t out_width,
+                      const bool transpose_a,
+                      const bool transpose_b) {
+  // Construct graph
+  OpsTestNet net;
+
+  // Add input data
+  index_t batch_count = std::accumulate(batch.begin(), batch.end(), 1,
+                                        std::multiplies<index_t>());
+  if (transpose_a) {
+    net.AddRandomInput<CPU, float>("A", {batch_count, channels, height});
+  } else {
+    net.AddRandomInput<CPU, float>("A", {batch_count, height, channels});
+  }
+  if (transpose_b) {
+    net.AddRandomInput<CPU, float>("B", {batch_count, out_width, channels});
+  } else {
+    net.AddRandomInput<CPU, float>("B", {batch_count, channels, out_width});
+  }
+
+  OpDefBuilder("MatMul", "MatMulTest")
+      .Input("A")
+      .AddIntArg("transpose_a", transpose_a ? 1 : 0)
+      .Input("B")
+      .AddIntArg("transpose_b", transpose_b ? 1 : 0)
+      .Output("Output")
+      .AddIntArg("T", DT_FLOAT)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp(CPU);
+
+  OpDefBuilder("Quantize", "QuantizeA")
+      .Input("A")
+      .Output("QuantizedA")
+      .OutputType({DT_UINT8})
+      .AddIntArg("T", DT_UINT8)
+      .AddIntArg("non_zero", true)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  OpDefBuilder("Quantize", "QuantizeB")
+      .Input("B")
+      .Output("QuantizedB")
+      .OutputType({DT_UINT8})
+      .AddIntArg("T", DT_UINT8)
+      .AddIntArg("non_zero", true)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  OpDefBuilder("MatMul", "QuantizeMatMulTest")
+      .Input("QuantizedA")
+      .AddIntArg("transpose_a", transpose_a ? 1 : 0)
+      .Input("QuantizedB")
+      .AddIntArg("transpose_b", transpose_b ? 1 : 0)
1 : 0) + .Output("QuantizedOutput") + .AddIntArg("T", DT_UINT8) + .OutputType({DT_INT32}) + .Finalize(net.NewOperatorDef()); + net.RunOp(); + + OpDefBuilder("Dequantize", "DeQuantizeTest") + .Input("QuantizedOutput") + .Output("DequantizedOutput") + .OutputType({DT_FLOAT}) + .AddIntArg("T", DT_INT32) + .Finalize(net.NewOperatorDef()); + net.RunOp(); + + // Check + ExpectTensorSimilar(*net.GetOutput("Output"), + *net.GetTensor("DequantizedOutput"), 0.01); +} } // namespace -TEST_F(MatMulOpTest, Quant) { - Quant({1}, 64, 128, 32, false, false); - Quant({1}, 64, 32, 128, false, false); - Quant({2, 3}, 64, 32, 128, false, false); - Quant({1}, 64, 128, 32, false, true); - Quant({1}, 64, 32, 128, false, true); - Quant({2, 3}, 64, 32, 128, false, true); - Quant({1}, 64, 128, 32, true, false); - Quant({1}, 64, 32, 128, true, false); - Quant({2, 3}, 64, 32, 128, true, false); - Quant({1}, 64, 128, 32, true, true); - Quant({1}, 64, 32, 128, true, true); - Quant({2, 3}, 64, 32, 128, true, true); +TEST_F(MatMulOpTest, QuantOutputUint8) { + QuantOutputUint8({1}, 64, 128, 32, false, false); + QuantOutputUint8({1}, 64, 32, 128, false, false); + QuantOutputUint8({2, 3}, 64, 32, 128, false, false); + QuantOutputUint8({1}, 64, 128, 32, false, true); + QuantOutputUint8({1}, 64, 32, 128, false, true); + QuantOutputUint8({2, 3}, 64, 32, 128, false, true); + QuantOutputUint8({1}, 64, 128, 32, true, false); + QuantOutputUint8({1}, 64, 32, 128, true, false); + QuantOutputUint8({2, 3}, 64, 32, 128, true, false); + QuantOutputUint8({1}, 64, 128, 32, true, true); + QuantOutputUint8({1}, 64, 32, 128, true, true); + QuantOutputUint8({2, 3}, 64, 32, 128, true, true); + // UnAligned + QuantOutputUint8({2}, 3, 3, 3, false, false); + QuantOutputUint8({16}, 31, 61, 67, false, true); + QuantOutputUint8({31}, 31, 61, 67, true, false); + QuantOutputUint8({2, 3}, 31, 61, 67, true, true); +} + +TEST_F(MatMulOpTest, QuantOutputInt32) { + QuantOutputInt32({1}, 64, 128, 32, false, false); + QuantOutputInt32({1}, 64, 32, 128, false, false); + QuantOutputInt32({2, 3}, 64, 32, 128, false, false); + QuantOutputInt32({1}, 64, 128, 32, false, true); + QuantOutputInt32({1}, 64, 32, 128, false, true); + QuantOutputInt32({2, 3}, 64, 32, 128, false, true); + QuantOutputInt32({1}, 64, 128, 32, true, false); + QuantOutputInt32({1}, 64, 32, 128, true, false); + QuantOutputInt32({2, 3}, 64, 32, 128, true, false); + QuantOutputInt32({1}, 64, 128, 32, true, true); + QuantOutputInt32({1}, 64, 32, 128, true, true); + QuantOutputInt32({2, 3}, 64, 32, 128, true, true); // UnAligned - Quant({2}, 3, 3, 3, false, false); - Quant({16}, 31, 61, 67, false, true); - Quant({31}, 31, 61, 67, true, false); - Quant({2, 3}, 31, 61, 67, true, true); + QuantOutputInt32({2}, 3, 3, 3, false, false); + QuantOutputInt32({16}, 31, 61, 67, false, true); + QuantOutputInt32({31}, 31, 61, 67, true, false); + QuantOutputInt32({2, 3}, 31, 61, 67, true, true); } // TODO(liyin): test transpose after implementing gpu runtime diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 8ca694dc3999b509bfef4ec93457e2683f152d5a..e7dff9bdd1c648a08d842fef9630d82ce4a5c26f 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -219,6 +219,7 @@ class TransformerRule(Enum): ADD_OPENCL_INFORMATIONS = 31 FOLD_DECONV_AND_BN = 32 FOLD_SQRDIFF_MEAN = 33 + TRANSPOSE_MATMUL_WEIGHT = 34 class ConverterInterface(object): diff --git 
index fe5c02ce940e723fe62818652dcddda5e64f7131..8d3a3b6471707e696d64988e2e71d63a168cf190 100644
--- a/mace/python/tools/converter_tool/transformer.py
+++ b/mace/python/tools/converter_tool/transformer.py
@@ -80,6 +80,8 @@ class Transformer(base_converter.ConverterInterface):
             TransformerRule.FOLD_ACTIVATION: self.fold_activation,
             TransformerRule.FOLD_SQRDIFF_MEAN: self.fold_squared_diff_mean,
             TransformerRule.TRANSPOSE_FILTERS: self.transpose_filters,
+            TransformerRule.TRANSPOSE_MATMUL_WEIGHT:
+                self.transpose_matmul_weight,
             TransformerRule.TRANSPOSE_DATA_FORMAT: self.transpose_data_format,
             TransformerRule.ADD_IN_OUT_TENSOR_INFO:
                 self.add_in_out_tensor_info,
@@ -1258,24 +1260,24 @@ class Transformer(base_converter.ConverterInterface):
         if self._option.device != DeviceType.CPU.value:
             return False
         net = self._model
-        transpose_arg_names = [MaceKeyword.mace_transpose_a_str,
-                               MaceKeyword.mace_transpose_b_str]
         for op in net.op:
             if op.type == MaceOp.MatMul.name:  # noqa
-                for i in range(len(op.input)):
-                    input = op.input[i]
-                    if input in self._consts \
-                            and len(self._consts[input].dims) == 2:
-                        arg = ConverterUtil.get_arg(op, transpose_arg_names[i])
-                        if arg is not None and arg.i == 1:
-                            six.print_('convert matmul')
-                            filter = self._consts[input]
-                            filter_data = np.array(filter.float_data).reshape(
-                                filter.dims)
-                            filter_data = filter_data.transpose(1, 0)
-                            filter.float_data[:] = filter_data.flat
-                            filter.dims[:] = filter_data.shape
-                            arg.i = 0
+                rhs = op.input[1]
+                if rhs in self._consts and len(self._consts[rhs].dims) == 2:
+                    arg = ConverterUtil.get_arg(op, MaceKeyword.mace_transpose_b_str)  # noqa
+                    six.print_('transpose matmul weight')
+                    if arg is None:
+                        arg = op.arg.add()
+                        arg.name = MaceKeyword.mace_transpose_b_str
+                        arg.i = 0
+                    if arg.i == 0:
+                        filter = self._consts[rhs]
+                        filter_data = np.array(filter.float_data).reshape(
+                            filter.dims)
+                        filter_data = filter_data.transpose(1, 0)
+                        filter.float_data[:] = filter_data.flat
+                        filter.dims[:] = filter_data.shape
+                        arg.i = 1
 
     def transpose_filters(self):
         net = self._model
@@ -1373,8 +1375,6 @@ class Transformer(base_converter.ConverterInterface):
                 filter.dims[:] = filter_data.shape
                 transposed_deconv_filter.add(op.input[1])
 
-        self.transpose_matmul_weight()
-
         return False
 
     def buffer_transform(self, op, input_idx, input_type):
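---

Note on the int32 output path above: the int32_t specialization of MatMulFixpointImpl passes an empty output pipeline (std::make_tuple()) to gemmlowp, so requantization is skipped and the raw int32 accumulators are returned. Those accumulators carry scale A->scale() * B->scale() and zero point 0, which is exactly what the kernel records via C->SetScale() / C->SetZeroPoint() and what the new Dequantize int32_t registration consumes. A minimal standalone sketch of that arithmetic (hypothetical values; not MACE or gemmlowp code):

#include <cstdint>
#include <cstdio>

// With q = round(r / s) + zp, the exact real product of two quantized values
// is (qa - zp_a) * (qb - zp_b) * sa * sb. Summing the integer factor in an
// int32 accumulator therefore yields a tensor whose scale is sa * sb and
// whose zero point is 0; dequantization is a single multiply.
int main() {
  const float sa = 0.02f, sb = 0.05f;    // hypothetical input scales
  const int32_t zp_a = 128, zp_b = 128;  // hypothetical zero points
  const uint8_t qa[2] = {130, 140};      // one quantized row of A
  const uint8_t qb[2] = {120, 150};      // one quantized column of B

  int32_t acc = 0;  // raw gemm accumulator, as returned by the int32 path
  for (int k = 0; k < 2; ++k) {
    acc += (static_cast<int32_t>(qa[k]) - zp_a) *
           (static_cast<int32_t>(qb[k]) - zp_b);
  }

  // Dequantize with scale sa * sb and zero point 0.
  const float real = static_cast<float>(acc) * (sa * sb);
  std::printf("acc = %d, real = %f\n", acc, real);  // acc = 248, real = 0.248
  return 0;
}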