提交 a4d5cf46 编写于 作者: 李寅

Merge branch 'pooling' into 'master'

Add quantized pooling and squeeze

See merge request !734
......@@ -24,6 +24,10 @@
#include "mace/core/tensor.h"
#include "mace/kernels/conv_pool_2d_util.h"
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/cl2_header.h"
#endif // MACE_ENABLE_OPENCL
......@@ -167,9 +171,9 @@ struct PoolingFunctor<DeviceType::CPU, float>: PoolingFunctorBase {
}
}
MaceStatus operator()(const Tensor *input_tensor,
Tensor *output_tensor,
StatsFuture *future) {
MaceStatus operator()(const Tensor *input_tensor, // NCHW
Tensor *output_tensor, // NCHW
StatsFuture *future) {
MACE_UNUSED(future);
std::vector<index_t> output_shape(4);
std::vector<index_t> filter_shape = {
......@@ -225,6 +229,217 @@ struct PoolingFunctor<DeviceType::CPU, float>: PoolingFunctorBase {
}
};
template <>
struct PoolingFunctor<DeviceType::CPU, uint8_t>: PoolingFunctorBase {
PoolingFunctor(const PoolingType pooling_type,
const int *kernels,
const int *strides,
const Padding padding_type,
const std::vector<int> &paddings,
const int *dilations)
: PoolingFunctorBase(
pooling_type, kernels, strides, padding_type, paddings, dilations) {
}
void MaxPooling(const uint8_t *input,
const index_t *in_shape,
const index_t *out_shape,
const int *filter_hw,
const int *stride_hw,
const int *pad_hw,
uint8_t *output) {
#pragma omp parallel for collapse(3)
for (index_t b = 0; b < out_shape[0]; ++b) {
for (index_t h = 0; h < out_shape[1]; ++h) {
for (index_t w = 0; w < out_shape[2]; ++w) {
const index_t out_height = out_shape[1];
const index_t out_width = out_shape[2];
const index_t channels = out_shape[3];
const index_t in_height = in_shape[1];
const index_t in_width = in_shape[2];
const index_t in_h_base = h * stride_hw[0] - pad_hw[0];
const index_t in_w_base = w * stride_hw[1] - pad_hw[1];
const index_t in_h_begin = std::max<index_t>(0, in_h_base);
const index_t in_w_begin = std::max<index_t>(0, in_w_base);
const index_t in_h_end =
std::min(in_height, in_h_base + filter_hw[0]);
const index_t in_w_end =
std::min(in_width, in_w_base + filter_hw[1]);
uint8_t *out_ptr =
output + ((b * out_height + h) * out_width + w) * channels;
for (index_t ih = in_h_begin; ih < in_h_end; ++ih) {
for (index_t iw = in_w_begin; iw < in_w_end; ++iw) {
const uint8_t *in_ptr = input +
((b * in_height + ih) * in_width + iw) * channels;
index_t c = 0;
#if defined(MACE_ENABLE_NEON)
for (; c <= channels - 16; c += 16) {
uint8x16_t out_vec = vld1q_u8(out_ptr + c);
uint8x16_t in_vec = vld1q_u8(in_ptr + c);
out_vec = vmaxq_u8(out_vec, in_vec);
vst1q_u8(out_ptr + c, out_vec);
}
for (; c <= channels - 8; c += 8) {
uint8x8_t out_vec = vld1_u8(out_ptr + c);
uint8x8_t in_vec = vld1_u8(in_ptr + c);
out_vec = vmax_u8(out_vec, in_vec);
vst1_u8(out_ptr + c, out_vec);
}
#endif
for (; c < channels; ++c) {
out_ptr[c] = std::max(out_ptr[c], in_ptr[c]);
}
}
}
}
}
}
}
void AvgPooling(const uint8_t *input,
const index_t *in_shape,
const index_t *out_shape,
const int *filter_hw,
const int *stride_hw,
const int *pad_hw,
uint8_t *output) {
#pragma omp parallel for collapse(3)
for (index_t b = 0; b < out_shape[0]; ++b) {
for (index_t h = 0; h < out_shape[1]; ++h) {
for (index_t w = 0; w < out_shape[2]; ++w) {
const index_t out_height = out_shape[1];
const index_t out_width = out_shape[2];
const index_t channels = out_shape[3];
const index_t in_height = in_shape[1];
const index_t in_width = in_shape[2];
const index_t in_h_base = h * stride_hw[0] - pad_hw[0];
const index_t in_w_base = w * stride_hw[1] - pad_hw[1];
const index_t in_h_begin = std::max<index_t>(0, in_h_base);
const index_t in_w_begin = std::max<index_t>(0, in_w_base);
const index_t in_h_end =
std::min(in_height, in_h_base + filter_hw[0]);
const index_t in_w_end =
std::min(in_width, in_w_base + filter_hw[1]);
const index_t block_size =
(in_h_end - in_h_begin) * (in_w_end - in_w_begin);
MACE_CHECK(block_size > 0);
std::vector<uint16_t> average_buffer(channels);
uint16_t *avg_buffer = average_buffer.data();
std::fill_n(avg_buffer, channels, 0);
for (index_t ih = in_h_begin; ih < in_h_end; ++ih) {
for (index_t iw = in_w_begin; iw < in_w_end; ++iw) {
const uint8_t *in_ptr = input +
((b * in_height + ih) * in_width + iw) * channels;
index_t c = 0;
#if defined(MACE_ENABLE_NEON)
for (; c <= channels - 16; c += 16) {
uint16x8_t avg_vec[2];
avg_vec[0] = vld1q_u16(avg_buffer + c);
avg_vec[1] = vld1q_u16(avg_buffer + c + 8);
uint8x16_t in_vec = vld1q_u8(in_ptr + c);
avg_vec[0] = vaddw_u8(avg_vec[0], vget_low_u8(in_vec));
avg_vec[1] = vaddw_u8(avg_vec[1], vget_high_u8(in_vec));
vst1q_u16(avg_buffer + c, avg_vec[0]);
vst1q_u16(avg_buffer + c + 8, avg_vec[1]);
}
for (; c <= channels - 8; c += 8) {
uint16x8_t avg_vec = vld1q_u16(avg_buffer + c);
uint8x8_t in_vec = vld1_u8(in_ptr + c);
avg_vec = vaddw_u8(avg_vec, in_vec);
vst1q_u16(avg_buffer + c, avg_vec);
}
#endif
for (; c < channels; ++c) {
avg_buffer[c] += in_ptr[c];
}
}
}
uint8_t *out_ptr =
output + ((b * out_height + h) * out_width + w) * channels;
for (index_t c = 0; c < channels; ++c) {
out_ptr[c] = static_cast<uint8_t>(
(avg_buffer[c] + block_size / 2) / block_size);
}
}
}
}
}
MaceStatus operator()(const Tensor *input_tensor, // NHWC
Tensor *output_tensor, // NHWC
StatsFuture *future) {
MACE_UNUSED(future);
MACE_CHECK(dilations_[0] == 1 && dilations_[1] == 1,
"Quantized pooling does not support dilation > 1 yet.");
MACE_CHECK(input_tensor->scale() == output_tensor->scale(),
"Quantized pooling's input and output scale are not equal.");
MACE_CHECK(input_tensor->zero_point() == output_tensor->zero_point(),
"Quantized pooling's input and output zero_point are not equal");
std::vector<index_t> output_shape(4);
std::vector<index_t> filter_shape = {
input_tensor->dim(3), kernels_[0], kernels_[1], input_tensor->dim(3)};
std::vector<int> paddings(2);
if (paddings_.empty()) {
CalcPaddingAndOutputSize(input_tensor->shape().data(),
NHWC,
filter_shape.data(),
OHWI,
dilations_,
strides_,
padding_type_,
output_shape.data(),
paddings.data());
} else {
paddings = paddings_;
CalcOutputSize(input_tensor->shape().data(),
NHWC,
filter_shape.data(),
OHWI,
paddings_.data(),
dilations_,
strides_,
RoundType::CEIL,
output_shape.data());
}
MACE_RETURN_IF_ERROR(output_tensor->Resize(output_shape));
const index_t out_channels = output_tensor->dim(3);
const index_t in_channels = input_tensor->dim(3);
MACE_CHECK(out_channels == in_channels);
Tensor::MappingGuard input_guard(input_tensor);
Tensor::MappingGuard output_guard(output_tensor);
const uint8_t *input = input_tensor->data<uint8_t>();
uint8_t *output = output_tensor->mutable_data<uint8_t>();
int pad_hw[2] = {paddings[0] / 2, paddings[1] / 2};
if (pooling_type_ == PoolingType::MAX) {
MaxPooling(input,
input_tensor->shape().data(),
output_shape.data(),
kernels_,
strides_,
pad_hw,
output);
} else if (pooling_type_ == PoolingType::AVG) {
AvgPooling(input,
input_tensor->shape().data(),
output_shape.data(),
kernels_,
strides_,
pad_hw,
output);
} else {
MACE_NOT_IMPLEMENTED;
}
return MACE_SUCCESS;
}
};
#ifdef MACE_ENABLE_OPENCL
template <typename T>
struct PoolingFunctor<DeviceType::GPU, T> : PoolingFunctorBase {
......
......@@ -23,6 +23,11 @@ void Register_Pooling(OperatorRegistryBase *op_registry) {
.TypeConstraint<float>("T")
.Build(),
PoolingOp<DeviceType::CPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pooling")
.Device(DeviceType::CPU)
.TypeConstraint<uint8_t>("T")
.Build(),
PoolingOp<DeviceType::CPU, uint8_t>);
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pooling")
......
......@@ -23,7 +23,7 @@ namespace ops {
namespace test {
namespace {
template<DeviceType D>
template <DeviceType D, typename T>
void Pooling(int iters,
int batch,
int channels,
......@@ -39,8 +39,13 @@ void Pooling(int iters,
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input",
{batch, channels, height, width});
if (DataTypeToEnum<T>::value != DT_UINT8) {
net.AddRandomInput<D, float>(
"Input", {batch, channels, height, width});
} else {
net.AddRandomInput<DeviceType::CPU, uint8_t>(
"Input", {batch, height, width, channels});
}
} else if (D == DeviceType::GPU) {
net.AddRandomInput<D, float>("Input",
{batch, height, width, channels});
......@@ -57,6 +62,7 @@ void Pooling(int iters,
.AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::GPU) {
BufferToImage<D, float>(&net, "Input", "InputImage",
......@@ -87,24 +93,25 @@ void Pooling(int iters,
}
} // namespace
#define MACE_BM_POOLING_MACRO(N, C, H, W, KE, STRIDE, PA, PO, DEVICE) \
#define MACE_BM_POOLING_MACRO(N, C, H, W, KE, STRIDE, PA, PO, TYPE, DEVICE) \
static void \
MACE_BM_POOLING_##N##_##C##_##H##_##W##_K##KE##S##STRIDE##_##PA##_##PO##_\
##DEVICE( \
##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(float))); \
Pooling<DEVICE>(iters, N, C, H, W, KE, STRIDE, Padding::PA, \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
Pooling<DEVICE, TYPE>(iters, N, C, H, W, KE, STRIDE, Padding::PA, \
PoolingType::PO); \
} \
MACE_BENCHMARK( \
MACE_BM_POOLING_##N##_##C##_##H##_##W##_K##KE##S##STRIDE##_##PA##_##PO##_\
##DEVICE)
##TYPE##_##DEVICE)
#define MACE_BM_POOLING(N, C, H, W, K, S, PA, PO) \
MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, GPU); \
MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, CPU);
MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, float, CPU); \
MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, float, GPU); \
MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, uint8_t, CPU);
MACE_BM_POOLING(1, 3, 129, 129, 2, 2, SAME, MAX);
......@@ -112,6 +119,7 @@ MACE_BM_POOLING(1, 3, 257, 257, 2, 2, SAME, MAX);
MACE_BM_POOLING(1, 3, 513, 513, 2, 2, SAME, MAX);
MACE_BM_POOLING(1, 3, 1025, 1025, 2, 2, SAME, MAX);
MACE_BM_POOLING(1, 32, 480, 640, 480, 640, VALID, AVG);
MACE_BM_POOLING(1, 1024, 7, 7, 7, 1, VALID, AVG);
} // namespace test
} // namespace ops
......
......@@ -18,6 +18,7 @@
#include "mace/kernels/pooling.h"
#include "mace/ops/conv_pool_2d_base.h"
#include "mace/ops/ops_test_util.h"
#include "mace/kernels/quantize.h"
namespace mace {
namespace ops {
......@@ -462,6 +463,178 @@ TEST_F(PoolingOpTest, OPENCLUnAlignedLargeKernelAvgPooling) {
AvgPoolingTest<GPU, float>({3, 31, 37, 128}, {8, 8}, {8, 8}, Padding::SAME);
}
TEST_F(PoolingOpTest, QUANT_MAX_VALID) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<DeviceType::CPU, uint8_t>(
"Input", {1, 4, 4, 2},
{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23,
8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31});
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.AddIntsArg("kernels", {2, 2})
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("pooling_type", PoolingType::MAX)
.AddIntArg("T", static_cast<int>(DT_UINT8))
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
// Check
auto expected =
CreateTensor<uint8_t>({1, 2, 2, 2}, {5, 21, 7, 23, 13, 29, 15, 31});
ExpectTensorNear<uint8_t>(*expected, *net.GetOutput("Output"), 1e-5);
}
TEST_F(PoolingOpTest, QUANT_MAX_SAME) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<DeviceType::CPU, uint8_t>("Input", {1, 3, 3, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8});
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.AddIntsArg("kernels", {2, 2})
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("pooling_type", PoolingType::MAX)
.AddIntArg("T", static_cast<int>(DT_UINT8))
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
// Check
auto expected = CreateTensor<uint8_t>({1, 2, 2, 1}, {4, 5, 7, 8});
ExpectTensorNear<uint8_t>(*expected, *net.GetOutput("Output"), 1e-5);
}
TEST_F(PoolingOpTest, QUANT_AVG_VALID) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<DeviceType::CPU, uint8_t>(
"Input", {1, 4, 4, 2},
{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23,
8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31});
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.AddIntsArg("kernels", {2, 2})
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("pooling_type", PoolingType::AVG)
.AddIntArg("T", static_cast<int>(DT_UINT8))
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
// Check
auto expected = CreateTensor<uint8_t>(
{1, 2, 2, 2}, {3, 19, 5, 21, 11, 27, 13, 29});
ExpectTensorNear<uint8_t>(*expected, *net.GetOutput("Output"), 1e-5);
}
namespace {
void TestQuant(const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t channels,
const std::vector<int> &kernels,
const std::vector<int> &strides,
enum Padding padding_type,
PoolingType pooling) {
OpsTestNet net;
net.AddRandomInput<CPU, float>(
"Input", {batch, in_height, in_width, channels}, false);
net.TransformDataFormat<DeviceType::CPU, float>(
"Input", NHWC, "InputNCHW", NCHW);
OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntArg("pooling_type", pooling)
.AddIntsArg("kernels", kernels)
.AddIntsArg("strides", strides)
.AddIntArg("padding", padding_type)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", DT_FLOAT)
.Finalize(net.NewOperatorDef());
net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", NCHW, "Output", NHWC);
OpDefBuilder("Quantize", "QuantizeInput")
.Input("Input")
.Output("QuantizedInput")
.OutputType({DT_UINT8})
.AddIntArg("T", DT_UINT8)
.AddIntArg("non_zero", true)
.Finalize(net.NewOperatorDef());
net.RunOp();
OpDefBuilder("Pooling", "PoolingTest")
.Input("QuantizedInput")
.Output("QuantizedOutput")
.AddIntsArg("kernels", kernels)
.AddIntsArg("strides", strides)
.AddIntArg("padding", padding_type)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("pooling_type", pooling)
.AddIntArg("T", DT_UINT8)
.Finalize(net.NewOperatorDef());
net.Setup(DeviceType::CPU);
Tensor *q_input = net.GetTensor("QuantizedInput");
Tensor *q_output = net.GetTensor("QuantizedOutput");
q_output->SetScale(q_input->scale());
q_output->SetZeroPoint(q_input->zero_point());
net.Run();
OpDefBuilder("Dequantize", "DeQuantizeTest")
.Input("QuantizedOutput")
.Output("DequantizedOutput")
.OutputType({DT_FLOAT})
.AddIntArg("T", DT_UINT8)
.Finalize(net.NewOperatorDef());
net.RunOp();
// Check
ExpectTensorSimilar<float>(*net.GetOutput("Output"),
*net.GetTensor("DequantizedOutput"), 0.01);
}
} // namespace
TEST_F(PoolingOpTest, Quant) {
TestQuant(1, 7, 7, 1024, {7, 7}, {1, 1}, Padding::VALID, PoolingType::AVG);
TestQuant(1, 3, 3, 2, {3, 3}, {1, 1}, Padding::SAME, PoolingType::AVG);
TestQuant(1, 7, 7, 1024, {7, 7}, {1, 1}, Padding::VALID, PoolingType::MAX);
TestQuant(1, 7, 7, 1024, {7, 7}, {1, 1}, Padding::SAME, PoolingType::MAX);
TestQuant(1, 7, 7, 2048, {7, 7}, {1, 1}, Padding::SAME, PoolingType::AVG);
TestQuant(3, 15, 15, 128, {4, 4}, {4, 4}, Padding::VALID, PoolingType::AVG);
TestQuant(3, 15, 15, 128, {4, 4}, {4, 4}, Padding::VALID, PoolingType::MAX);
TestQuant(3, 31, 37, 128, {2, 2}, {2, 2}, Padding::VALID, PoolingType::AVG);
TestQuant(3, 31, 37, 128, {2, 2}, {2, 2}, Padding::VALID, PoolingType::MAX);
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -23,6 +23,11 @@ void Register_Squeeze(OperatorRegistryBase *op_registry) {
.TypeConstraint<float>("T")
.Build(),
SqueezeOp<DeviceType::CPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Squeeze")
.Device(DeviceType::CPU)
.TypeConstraint<uint8_t>("T")
.Build(),
SqueezeOp<DeviceType::CPU, uint8_t>);
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Squeeze")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册