diff --git a/mace/kernels/pooling.h b/mace/kernels/pooling.h index 936b5e36564c25d1122847fe055a463d1ce26de2..f5aa606b2a671f19749c1df98a479311b300d73b 100644 --- a/mace/kernels/pooling.h +++ b/mace/kernels/pooling.h @@ -24,6 +24,10 @@ #include "mace/core/tensor.h" #include "mace/kernels/conv_pool_2d_util.h" +#if defined(MACE_ENABLE_NEON) +#include +#endif + #ifdef MACE_ENABLE_OPENCL #include "mace/core/runtime/opencl/cl2_header.h" #endif // MACE_ENABLE_OPENCL @@ -167,9 +171,9 @@ struct PoolingFunctor: PoolingFunctorBase { } } - MaceStatus operator()(const Tensor *input_tensor, - Tensor *output_tensor, - StatsFuture *future) { + MaceStatus operator()(const Tensor *input_tensor, // NCHW + Tensor *output_tensor, // NCHW + StatsFuture *future) { MACE_UNUSED(future); std::vector output_shape(4); std::vector filter_shape = { @@ -225,6 +229,217 @@ struct PoolingFunctor: PoolingFunctorBase { } }; +template <> +struct PoolingFunctor: PoolingFunctorBase { + PoolingFunctor(const PoolingType pooling_type, + const int *kernels, + const int *strides, + const Padding padding_type, + const std::vector &paddings, + const int *dilations) + : PoolingFunctorBase( + pooling_type, kernels, strides, padding_type, paddings, dilations) { + } + + void MaxPooling(const uint8_t *input, + const index_t *in_shape, + const index_t *out_shape, + const int *filter_hw, + const int *stride_hw, + const int *pad_hw, + uint8_t *output) { +#pragma omp parallel for collapse(3) + for (index_t b = 0; b < out_shape[0]; ++b) { + for (index_t h = 0; h < out_shape[1]; ++h) { + for (index_t w = 0; w < out_shape[2]; ++w) { + const index_t out_height = out_shape[1]; + const index_t out_width = out_shape[2]; + const index_t channels = out_shape[3]; + const index_t in_height = in_shape[1]; + const index_t in_width = in_shape[2]; + const index_t in_h_base = h * stride_hw[0] - pad_hw[0]; + const index_t in_w_base = w * stride_hw[1] - pad_hw[1]; + const index_t in_h_begin = std::max(0, in_h_base); + const index_t in_w_begin = std::max(0, in_w_base); + const index_t in_h_end = + std::min(in_height, in_h_base + filter_hw[0]); + const index_t in_w_end = + std::min(in_width, in_w_base + filter_hw[1]); + + uint8_t *out_ptr = + output + ((b * out_height + h) * out_width + w) * channels; + for (index_t ih = in_h_begin; ih < in_h_end; ++ih) { + for (index_t iw = in_w_begin; iw < in_w_end; ++iw) { + const uint8_t *in_ptr = input + + ((b * in_height + ih) * in_width + iw) * channels; + index_t c = 0; +#if defined(MACE_ENABLE_NEON) + for (; c <= channels - 16; c += 16) { + uint8x16_t out_vec = vld1q_u8(out_ptr + c); + uint8x16_t in_vec = vld1q_u8(in_ptr + c); + out_vec = vmaxq_u8(out_vec, in_vec); + vst1q_u8(out_ptr + c, out_vec); + } + for (; c <= channels - 8; c += 8) { + uint8x8_t out_vec = vld1_u8(out_ptr + c); + uint8x8_t in_vec = vld1_u8(in_ptr + c); + out_vec = vmax_u8(out_vec, in_vec); + vst1_u8(out_ptr + c, out_vec); + } +#endif + for (; c < channels; ++c) { + out_ptr[c] = std::max(out_ptr[c], in_ptr[c]); + } + } + } + } + } + } + } + + void AvgPooling(const uint8_t *input, + const index_t *in_shape, + const index_t *out_shape, + const int *filter_hw, + const int *stride_hw, + const int *pad_hw, + uint8_t *output) { +#pragma omp parallel for collapse(3) + for (index_t b = 0; b < out_shape[0]; ++b) { + for (index_t h = 0; h < out_shape[1]; ++h) { + for (index_t w = 0; w < out_shape[2]; ++w) { + const index_t out_height = out_shape[1]; + const index_t out_width = out_shape[2]; + const index_t channels = out_shape[3]; + const index_t in_height = in_shape[1]; + const index_t in_width = in_shape[2]; + const index_t in_h_base = h * stride_hw[0] - pad_hw[0]; + const index_t in_w_base = w * stride_hw[1] - pad_hw[1]; + const index_t in_h_begin = std::max(0, in_h_base); + const index_t in_w_begin = std::max(0, in_w_base); + const index_t in_h_end = + std::min(in_height, in_h_base + filter_hw[0]); + const index_t in_w_end = + std::min(in_width, in_w_base + filter_hw[1]); + const index_t block_size = + (in_h_end - in_h_begin) * (in_w_end - in_w_begin); + MACE_CHECK(block_size > 0); + + std::vector average_buffer(channels); + uint16_t *avg_buffer = average_buffer.data(); + std::fill_n(avg_buffer, channels, 0); + for (index_t ih = in_h_begin; ih < in_h_end; ++ih) { + for (index_t iw = in_w_begin; iw < in_w_end; ++iw) { + const uint8_t *in_ptr = input + + ((b * in_height + ih) * in_width + iw) * channels; + index_t c = 0; +#if defined(MACE_ENABLE_NEON) + for (; c <= channels - 16; c += 16) { + uint16x8_t avg_vec[2]; + avg_vec[0] = vld1q_u16(avg_buffer + c); + avg_vec[1] = vld1q_u16(avg_buffer + c + 8); + uint8x16_t in_vec = vld1q_u8(in_ptr + c); + avg_vec[0] = vaddw_u8(avg_vec[0], vget_low_u8(in_vec)); + avg_vec[1] = vaddw_u8(avg_vec[1], vget_high_u8(in_vec)); + vst1q_u16(avg_buffer + c, avg_vec[0]); + vst1q_u16(avg_buffer + c + 8, avg_vec[1]); + } + for (; c <= channels - 8; c += 8) { + uint16x8_t avg_vec = vld1q_u16(avg_buffer + c); + uint8x8_t in_vec = vld1_u8(in_ptr + c); + avg_vec = vaddw_u8(avg_vec, in_vec); + vst1q_u16(avg_buffer + c, avg_vec); + } +#endif + for (; c < channels; ++c) { + avg_buffer[c] += in_ptr[c]; + } + } + } + uint8_t *out_ptr = + output + ((b * out_height + h) * out_width + w) * channels; + for (index_t c = 0; c < channels; ++c) { + out_ptr[c] = static_cast( + (avg_buffer[c] + block_size / 2) / block_size); + } + } + } + } + } + + MaceStatus operator()(const Tensor *input_tensor, // NHWC + Tensor *output_tensor, // NHWC + StatsFuture *future) { + MACE_UNUSED(future); + MACE_CHECK(dilations_[0] == 1 && dilations_[1] == 1, + "Quantized pooling does not support dilation > 1 yet."); + MACE_CHECK(input_tensor->scale() == output_tensor->scale(), + "Quantized pooling's input and output scale are not equal."); + MACE_CHECK(input_tensor->zero_point() == output_tensor->zero_point(), + "Quantized pooling's input and output zero_point are not equal"); + std::vector output_shape(4); + std::vector filter_shape = { + input_tensor->dim(3), kernels_[0], kernels_[1], input_tensor->dim(3)}; + + std::vector paddings(2); + if (paddings_.empty()) { + CalcPaddingAndOutputSize(input_tensor->shape().data(), + NHWC, + filter_shape.data(), + OHWI, + dilations_, + strides_, + padding_type_, + output_shape.data(), + paddings.data()); + } else { + paddings = paddings_; + CalcOutputSize(input_tensor->shape().data(), + NHWC, + filter_shape.data(), + OHWI, + paddings_.data(), + dilations_, + strides_, + RoundType::CEIL, + output_shape.data()); + } + MACE_RETURN_IF_ERROR(output_tensor->Resize(output_shape)); + + const index_t out_channels = output_tensor->dim(3); + const index_t in_channels = input_tensor->dim(3); + MACE_CHECK(out_channels == in_channels); + + Tensor::MappingGuard input_guard(input_tensor); + Tensor::MappingGuard output_guard(output_tensor); + const uint8_t *input = input_tensor->data(); + uint8_t *output = output_tensor->mutable_data(); + int pad_hw[2] = {paddings[0] / 2, paddings[1] / 2}; + + if (pooling_type_ == PoolingType::MAX) { + MaxPooling(input, + input_tensor->shape().data(), + output_shape.data(), + kernels_, + strides_, + pad_hw, + output); + } else if (pooling_type_ == PoolingType::AVG) { + AvgPooling(input, + input_tensor->shape().data(), + output_shape.data(), + kernels_, + strides_, + pad_hw, + output); + } else { + MACE_NOT_IMPLEMENTED; + } + + return MACE_SUCCESS; + } +}; + #ifdef MACE_ENABLE_OPENCL template struct PoolingFunctor : PoolingFunctorBase { diff --git a/mace/ops/pooling.cc b/mace/ops/pooling.cc index 0b673b51ecf1a2da0107c3ae00c4c25c07fd4f9b..b16fd2612dc64b4ef393badcefc05c806c855b74 100644 --- a/mace/ops/pooling.cc +++ b/mace/ops/pooling.cc @@ -23,6 +23,11 @@ void Register_Pooling(OperatorRegistryBase *op_registry) { .TypeConstraint("T") .Build(), PoolingOp); + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pooling") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + PoolingOp); #ifdef MACE_ENABLE_OPENCL MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pooling") diff --git a/mace/ops/pooling_benchmark.cc b/mace/ops/pooling_benchmark.cc index dec2b53a7acdb2b40667be124413cb3be708e74c..082ee4d0cfa710ddf38ade8a6cd3c046dc551e74 100644 --- a/mace/ops/pooling_benchmark.cc +++ b/mace/ops/pooling_benchmark.cc @@ -23,7 +23,7 @@ namespace ops { namespace test { namespace { -template +template void Pooling(int iters, int batch, int channels, @@ -39,8 +39,13 @@ void Pooling(int iters, // Add input data if (D == DeviceType::CPU) { - net.AddRandomInput("Input", - {batch, channels, height, width}); + if (DataTypeToEnum::value != DT_UINT8) { + net.AddRandomInput( + "Input", {batch, channels, height, width}); + } else { + net.AddRandomInput( + "Input", {batch, height, width, channels}); + } } else if (D == DeviceType::GPU) { net.AddRandomInput("Input", {batch, height, width, channels}); @@ -57,6 +62,7 @@ void Pooling(int iters, .AddIntsArg("strides", {stride, stride}) .AddIntArg("padding", padding) .AddIntsArg("dilations", {1, 1}) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) .Finalize(net.NewOperatorDef()); } else if (D == DeviceType::GPU) { BufferToImage(&net, "Input", "InputImage", @@ -87,24 +93,25 @@ void Pooling(int iters, } } // namespace -#define MACE_BM_POOLING_MACRO(N, C, H, W, KE, STRIDE, PA, PO, DEVICE) \ +#define MACE_BM_POOLING_MACRO(N, C, H, W, KE, STRIDE, PA, PO, TYPE, DEVICE) \ static void \ MACE_BM_POOLING_##N##_##C##_##H##_##W##_K##KE##S##STRIDE##_##PA##_##PO##_\ - ##DEVICE( \ + ##TYPE##_##DEVICE( \ int iters) { \ const int64_t tot = static_cast(iters) * N * C * H * W; \ mace::testing::MaccProcessed(tot); \ - mace::testing::BytesProcessed(tot *(sizeof(float))); \ - Pooling(iters, N, C, H, W, KE, STRIDE, Padding::PA, \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + Pooling(iters, N, C, H, W, KE, STRIDE, Padding::PA, \ PoolingType::PO); \ } \ MACE_BENCHMARK( \ MACE_BM_POOLING_##N##_##C##_##H##_##W##_K##KE##S##STRIDE##_##PA##_##PO##_\ - ##DEVICE) + ##TYPE##_##DEVICE) #define MACE_BM_POOLING(N, C, H, W, K, S, PA, PO) \ - MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, GPU); \ - MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, CPU); + MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, float, CPU); \ + MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, float, GPU); \ + MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, uint8_t, CPU); MACE_BM_POOLING(1, 3, 129, 129, 2, 2, SAME, MAX); @@ -112,6 +119,7 @@ MACE_BM_POOLING(1, 3, 257, 257, 2, 2, SAME, MAX); MACE_BM_POOLING(1, 3, 513, 513, 2, 2, SAME, MAX); MACE_BM_POOLING(1, 3, 1025, 1025, 2, 2, SAME, MAX); MACE_BM_POOLING(1, 32, 480, 640, 480, 640, VALID, AVG); +MACE_BM_POOLING(1, 1024, 7, 7, 7, 1, VALID, AVG); } // namespace test } // namespace ops diff --git a/mace/ops/pooling_test.cc b/mace/ops/pooling_test.cc index 8e7fbb8327c4206a09ad0b825b6279efc88be7ec..af900f58ecc2a56d368b22ce0338ae723fec7da3 100644 --- a/mace/ops/pooling_test.cc +++ b/mace/ops/pooling_test.cc @@ -18,6 +18,7 @@ #include "mace/kernels/pooling.h" #include "mace/ops/conv_pool_2d_base.h" #include "mace/ops/ops_test_util.h" +#include "mace/kernels/quantize.h" namespace mace { namespace ops { @@ -462,6 +463,178 @@ TEST_F(PoolingOpTest, OPENCLUnAlignedLargeKernelAvgPooling) { AvgPoolingTest({3, 31, 37, 128}, {8, 8}, {8, 8}, Padding::SAME); } +TEST_F(PoolingOpTest, QUANT_MAX_VALID) { + // Construct graph + OpsTestNet net; + + // Add input data + net.AddInputFromArray( + "Input", {1, 4, 4, 2}, + {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, + 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}); + + OpDefBuilder("Pooling", "PoolingTest") + .Input("Input") + .Output("Output") + .AddIntsArg("kernels", {2, 2}) + .AddIntsArg("strides", {2, 2}) + .AddIntArg("padding", Padding::VALID) + .AddIntsArg("dilations", {1, 1}) + .AddIntArg("pooling_type", PoolingType::MAX) + .AddIntArg("T", static_cast(DT_UINT8)) + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(); + + // Check + auto expected = + CreateTensor({1, 2, 2, 2}, {5, 21, 7, 23, 13, 29, 15, 31}); + + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); +} + +TEST_F(PoolingOpTest, QUANT_MAX_SAME) { + // Construct graph + OpsTestNet net; + + // Add input data + net.AddInputFromArray("Input", {1, 3, 3, 1}, + {0, 1, 2, 3, 4, 5, 6, 7, 8}); + + OpDefBuilder("Pooling", "PoolingTest") + .Input("Input") + .Output("Output") + .AddIntsArg("kernels", {2, 2}) + .AddIntsArg("strides", {2, 2}) + .AddIntArg("padding", Padding::SAME) + .AddIntsArg("dilations", {1, 1}) + .AddIntArg("pooling_type", PoolingType::MAX) + .AddIntArg("T", static_cast(DT_UINT8)) + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(); + + // Check + auto expected = CreateTensor({1, 2, 2, 1}, {4, 5, 7, 8}); + + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); +} + +TEST_F(PoolingOpTest, QUANT_AVG_VALID) { + // Construct graph + OpsTestNet net; + + // Add input data + net.AddInputFromArray( + "Input", {1, 4, 4, 2}, + {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, + 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}); + + OpDefBuilder("Pooling", "PoolingTest") + .Input("Input") + .Output("Output") + .AddIntsArg("kernels", {2, 2}) + .AddIntsArg("strides", {2, 2}) + .AddIntArg("padding", Padding::VALID) + .AddIntsArg("dilations", {1, 1}) + .AddIntArg("pooling_type", PoolingType::AVG) + .AddIntArg("T", static_cast(DT_UINT8)) + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(); + + // Check + auto expected = CreateTensor( + {1, 2, 2, 2}, {3, 19, 5, 21, 11, 27, 13, 29}); + + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); +} + +namespace { + +void TestQuant(const index_t batch, + const index_t in_height, + const index_t in_width, + const index_t channels, + const std::vector &kernels, + const std::vector &strides, + enum Padding padding_type, + PoolingType pooling) { + OpsTestNet net; + net.AddRandomInput( + "Input", {batch, in_height, in_width, channels}, false); + net.TransformDataFormat( + "Input", NHWC, "InputNCHW", NCHW); + + OpDefBuilder("Pooling", "PoolingTest") + .Input("InputNCHW") + .Output("OutputNCHW") + .AddIntArg("pooling_type", pooling) + .AddIntsArg("kernels", kernels) + .AddIntsArg("strides", strides) + .AddIntArg("padding", padding_type) + .AddIntsArg("dilations", {1, 1}) + .AddIntArg("T", DT_FLOAT) + .Finalize(net.NewOperatorDef()); + + net.RunOp(CPU); + net.TransformDataFormat( + "OutputNCHW", NCHW, "Output", NHWC); + + OpDefBuilder("Quantize", "QuantizeInput") + .Input("Input") + .Output("QuantizedInput") + .OutputType({DT_UINT8}) + .AddIntArg("T", DT_UINT8) + .AddIntArg("non_zero", true) + .Finalize(net.NewOperatorDef()); + net.RunOp(); + + OpDefBuilder("Pooling", "PoolingTest") + .Input("QuantizedInput") + .Output("QuantizedOutput") + .AddIntsArg("kernels", kernels) + .AddIntsArg("strides", strides) + .AddIntArg("padding", padding_type) + .AddIntsArg("dilations", {1, 1}) + .AddIntArg("pooling_type", pooling) + .AddIntArg("T", DT_UINT8) + .Finalize(net.NewOperatorDef()); + net.Setup(DeviceType::CPU); + Tensor *q_input = net.GetTensor("QuantizedInput"); + Tensor *q_output = net.GetTensor("QuantizedOutput"); + q_output->SetScale(q_input->scale()); + q_output->SetZeroPoint(q_input->zero_point()); + net.Run(); + + OpDefBuilder("Dequantize", "DeQuantizeTest") + .Input("QuantizedOutput") + .Output("DequantizedOutput") + .OutputType({DT_FLOAT}) + .AddIntArg("T", DT_UINT8) + .Finalize(net.NewOperatorDef()); + net.RunOp(); + + // Check + ExpectTensorSimilar(*net.GetOutput("Output"), + *net.GetTensor("DequantizedOutput"), 0.01); +} +} // namespace + +TEST_F(PoolingOpTest, Quant) { + TestQuant(1, 7, 7, 1024, {7, 7}, {1, 1}, Padding::VALID, PoolingType::AVG); + TestQuant(1, 3, 3, 2, {3, 3}, {1, 1}, Padding::SAME, PoolingType::AVG); + TestQuant(1, 7, 7, 1024, {7, 7}, {1, 1}, Padding::VALID, PoolingType::MAX); + TestQuant(1, 7, 7, 1024, {7, 7}, {1, 1}, Padding::SAME, PoolingType::MAX); + TestQuant(1, 7, 7, 2048, {7, 7}, {1, 1}, Padding::SAME, PoolingType::AVG); + TestQuant(3, 15, 15, 128, {4, 4}, {4, 4}, Padding::VALID, PoolingType::AVG); + TestQuant(3, 15, 15, 128, {4, 4}, {4, 4}, Padding::VALID, PoolingType::MAX); + TestQuant(3, 31, 37, 128, {2, 2}, {2, 2}, Padding::VALID, PoolingType::AVG); + TestQuant(3, 31, 37, 128, {2, 2}, {2, 2}, Padding::VALID, PoolingType::MAX); +} } // namespace test } // namespace ops } // namespace mace diff --git a/mace/ops/squeeze.cc b/mace/ops/squeeze.cc index e30a87bdc5d870099d1e270b7424dac7a5974d32..eac886dd82b1c7ae515cf44e72d2bf1a1ce12508 100644 --- a/mace/ops/squeeze.cc +++ b/mace/ops/squeeze.cc @@ -23,6 +23,11 @@ void Register_Squeeze(OperatorRegistryBase *op_registry) { .TypeConstraint("T") .Build(), SqueezeOp); + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Squeeze") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + SqueezeOp); #ifdef MACE_ENABLE_OPENCL MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Squeeze")