/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/pooling.h"

#include "Include/arm_nnfunctions.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/pooling.h"
#include "tensorflow/lite/micro/micro_log.h"

namespace tflite {

namespace {

struct OpData {
  OpDataPooling reference_op_data;

  // Index of the scratch buffer used by the optimized kernels, or -1 if none.
  int buffer_idx;
};

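// Translates the TFLM shapes and pooling params into the CMSIS-NN structs
// shared by the average- and max-pooling kernels. These kernels process a
// single batch, so n is fixed to 1. If Prepare requested a scratch buffer,
// its arena pointer is wired into the CMSIS-NN context here.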
void PopulateCommonParams(
    TfLiteContext* const context, cmsis_nn_dims* const input_dims,
    cmsis_nn_dims* const output_dims, cmsis_nn_pool_params* const pool_params,
    cmsis_nn_context* const ctx, cmsis_nn_dims* const filter_dims,
    const OpData& data, const RuntimeShape& input_shape,
    const RuntimeShape& output_shape, const TfLitePoolParams* params) {
  const int depth = MatchingDim(input_shape, 3, output_shape, 3);

  input_dims->n = 1;
  input_dims->h = input_shape.Dims(1);
  input_dims->w = input_shape.Dims(2);
  input_dims->c = depth;

  output_dims->n = 1;
  output_dims->h = output_shape.Dims(1);
  output_dims->w = output_shape.Dims(2);
  output_dims->c = depth;

  pool_params->stride.h = params->stride_height;
  pool_params->stride.w = params->stride_width;
  pool_params->padding.h = data.reference_op_data.padding.height;
  pool_params->padding.w = data.reference_op_data.padding.width;
  pool_params->activation.min = data.reference_op_data.activation_min;
  pool_params->activation.max = data.reference_op_data.activation_max;

  filter_dims->n = 1;
  filter_dims->h = params->filter_height;
  filter_dims->w = params->filter_width;
  filter_dims->c = 1;
  ctx->buf = nullptr;
  ctx->size = 0;
  if (data.buffer_idx > -1) {
    ctx->buf = context->GetScratchBuffer(context, data.buffer_idx);
  }
}

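// Quantized average pooling via CMSIS-NN. Dispatches to arm_avgpool_s8 or
// arm_avgpool_s16; the return status is only checked with TFLITE_DCHECK_EQ
// rather than propagated as a TfLiteStatus.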
void AverageEvalQuantized(TfLiteContext* context, const TfLiteNode* node,
                          const TfLitePoolParams* params, const OpData& data,
                          const TfLiteEvalTensor* input,
                          TfLiteEvalTensor* output) {
  TFLITE_DCHECK((input->type == kTfLiteInt8) || (input->type == kTfLiteInt16));

  RuntimeShape input_shape = micro::GetTensorShape(input);
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);

  RuntimeShape output_shape = micro::GetTensorShape(output);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

  cmsis_nn_dims input_dims;
  cmsis_nn_dims output_dims;
  cmsis_nn_pool_params pool_params;
  cmsis_nn_dims filter_dims;
  cmsis_nn_context ctx;

  PopulateCommonParams(context, &input_dims, &output_dims, &pool_params, &ctx,
                       &filter_dims, data, input_shape, output_shape, params);

  if (input->type == kTfLiteInt8) {
    TFLITE_DCHECK_EQ(
        arm_avgpool_s8(&ctx, &pool_params, &input_dims,
                       micro::GetTensorData<int8_t>(input), &filter_dims,
                       &output_dims, micro::GetTensorData<int8_t>(output)),
        ARM_CMSIS_NN_SUCCESS);
  } else {
    TFLITE_DCHECK_EQ(
        arm_avgpool_s16(&ctx, &pool_params, &input_dims,
                        micro::GetTensorData<int16_t>(input), &filter_dims,
                        &output_dims, micro::GetTensorData<int16_t>(output)),
        ARM_CMSIS_NN_SUCCESS);
  }
}

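// Quantized max pooling via CMSIS-NN, dispatching to arm_max_pool_s8 or
// arm_max_pool_s16 based on the input type.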
TfLiteStatus MaxEvalQuantized(TfLiteContext* context, const TfLiteNode* node,
                              const TfLitePoolParams* params,
                              const OpData& data, const TfLiteEvalTensor* input,
                              TfLiteEvalTensor* output) {
  TFLITE_DCHECK((input->type == kTfLiteInt8) || (input->type == kTfLiteInt16));

  RuntimeShape input_shape = micro::GetTensorShape(input);
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);

  RuntimeShape output_shape = micro::GetTensorShape(output);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

  cmsis_nn_dims input_dims;
  cmsis_nn_dims output_dims;
  cmsis_nn_pool_params pool_params;
  cmsis_nn_dims filter_dims;
  cmsis_nn_context ctx;

  PopulateCommonParams(context, &input_dims, &output_dims, &pool_params, &ctx,
                       &filter_dims, data, input_shape, output_shape, params);

  if (input->type == kTfLiteInt8) {
    TFLITE_DCHECK_EQ(
        arm_max_pool_s8(&ctx, &pool_params, &input_dims,
                        micro::GetTensorData<int8_t>(input), &filter_dims,
                        &output_dims, micro::GetTensorData<int8_t>(output)),
        ARM_CMSIS_NN_SUCCESS);
  } else {
    TFLITE_DCHECK_EQ(
        arm_max_pool_s16(&ctx, &pool_params, &input_dims,
                         micro::GetTensorData<int16_t>(input), &filter_dims,
                         &output_dims, micro::GetTensorData<int16_t>(output)),
        ARM_CMSIS_NN_SUCCESS);
  }

  return kTfLiteOk;
}

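// Allocates the per-node OpData in the persistent arena.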
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

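// Prepare for max pooling: only the checks shared with the reference kernel
// are needed, since CMSIS-NN max pooling uses no scratch memory.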
TfLiteStatus MaxPrepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_STATUS(PoolingPrepare(context, node));
  // Reset the scratch buffer index; max pooling never requests one.
  static_cast<OpData*>(node->user_data)->buffer_idx = -1;
  return kTfLiteOk;
}

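// Prepare for average pooling: on top of the shared checks, CMSIS-NN may
// need a scratch buffer whose size depends on the output width and depth;
// if so, it is requested from the arena here and consumed during Eval.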
TfLiteStatus AveragePrepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_STATUS(PoolingPrepare(context, node));

  MicroContext* micro_context = GetMicroContext(context);

  TfLiteTensor* input =
      micro_context->AllocateTempInputTensor(node, kPoolingInputTensor);
  TfLiteTensor* output =
      micro_context->AllocateTempOutputTensor(node, kPoolingOutputTensor);

  if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
    RuntimeShape input_shape = GetTensorShape(input);
    TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);

    RuntimeShape output_shape = GetTensorShape(output);
    TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

    const int depth = MatchingDim(input_shape, 3, output_shape, 3);
    const int output_width = output_shape.Dims(2);

    const int32_t buffer_size =
        input->type == kTfLiteInt16
            ? arm_avgpool_s16_get_buffer_size(output_width, depth)
            : arm_avgpool_s8_get_buffer_size(output_width, depth);

    auto* data = static_cast<OpData*>(node->user_data);
    if (buffer_size > 0) {
      TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
          context, buffer_size, &data->buffer_idx));
    } else {
      data->buffer_idx = -1;
    }
  }

  micro_context->DeallocateTempTfLiteTensor(output);
  micro_context->DeallocateTempTfLiteTensor(input);
  return kTfLiteOk;
}

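// Generic entry point: dispatches on the input type between the float
// reference kernel and the quantized CMSIS-NN path.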
TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  const TfLiteEvalTensor* input =
      micro::GetEvalInput(context, node, kPoolingInputTensor);
  TfLiteEvalTensor* output =
      micro::GetEvalOutput(context, node, kPoolingOutputTensor);

  // Inputs and outputs share the same type, guaranteed by the converter.
  if (input->type == kTfLiteFloat32) {
    AveragePoolingEvalFloat(context, node, params, &data.reference_op_data,
                            input, output);
  } else if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
    AverageEvalQuantized(context, node, params, data, input, output);
  } else {
    MicroPrintf("Input type %s is not currently supported",
                TfLiteTypeGetName(input->type));
    return kTfLiteError;
  }

  return kTfLiteOk;
}

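// Type-specialized entry point used by the int8-only registration; skips the
// runtime type dispatch done in AverageEval.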
TfLiteStatus AverageEvalInt8(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  const TfLiteEvalTensor* input =
      micro::GetEvalInput(context, node, kPoolingInputTensor);
  TFLITE_DCHECK(input->type == kTfLiteInt8);
  TfLiteEvalTensor* output =
      micro::GetEvalOutput(context, node, kPoolingOutputTensor);

  AverageEvalQuantized(context, node, params, data, input, output);

  return kTfLiteOk;
}

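// Type-specialized entry point used by the int16-only registration.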
TfLiteStatus AverageEvalInt16(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  const TfLiteEvalTensor* input =
      micro::GetEvalInput(context, node, kPoolingInputTensor);
  TFLITE_DCHECK(input->type == kTfLiteInt16);
  TfLiteEvalTensor* output =
      micro::GetEvalOutput(context, node, kPoolingOutputTensor);

  AverageEvalQuantized(context, node, params, data, input, output);

  return kTfLiteOk;
}
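
// Generic max-pooling entry point, mirroring the dispatch in AverageEval.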
TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  const TfLiteEvalTensor* input =
      micro::GetEvalInput(context, node, kPoolingInputTensor);
  TfLiteEvalTensor* output =
      micro::GetEvalOutput(context, node, kPoolingOutputTensor);

  if (input->type == kTfLiteFloat32) {
    MaxPoolingEvalFloat(context, node, params, &data.reference_op_data, input,
                        output);
  } else if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
    MaxEvalQuantized(context, node, params, data, input, output);
  } else {
    MicroPrintf("Input type %s is not currently supported",
                TfLiteTypeGetName(input->type));
    return kTfLiteError;
  }

  return kTfLiteOk;
}

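// Type-specialized int8 max-pooling entry point.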
TfLiteStatus MaxEvalInt8(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  const TfLiteEvalTensor* input =
      micro::GetEvalInput(context, node, kPoolingInputTensor);
  TFLITE_DCHECK(input->type == kTfLiteInt8);
  TfLiteEvalTensor* output =
      micro::GetEvalOutput(context, node, kPoolingOutputTensor);

  MaxEvalQuantized(context, node, params, data, input, output);
  return kTfLiteOk;
}

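// Type-specialized int16 max-pooling entry point.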
TfLiteStatus MaxEvalInt16(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  const TfLiteEvalTensor* input =
      micro::GetEvalInput(context, node, kPoolingInputTensor);
  TFLITE_DCHECK(input->type == kTfLiteInt16);
  TfLiteEvalTensor* output =
      micro::GetEvalOutput(context, node, kPoolingOutputTensor);

  MaxEvalQuantized(context, node, params, data, input, output);
  return kTfLiteOk;
}

}  // namespace

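// The *_INT8/*_INT16 registrations bind the type-specialized Eval functions,
// avoiding per-invoke type dispatch when tensor types are known up front.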
TFLMRegistration Register_AVERAGE_POOL_2D_INT8() {
  return tflite::micro::RegisterOp(Init, AveragePrepare, AverageEvalInt8);
}

TFLMRegistration Register_AVERAGE_POOL_2D_INT16() {
  return tflite::micro::RegisterOp(Init, AveragePrepare, AverageEvalInt16);
}

TFLMRegistration Register_AVERAGE_POOL_2D() {
  return tflite::micro::RegisterOp(Init, AveragePrepare, AverageEval);
}

TFLMRegistration Register_MAX_POOL_2D_INT8() {
  return tflite::micro::RegisterOp(Init, MaxPrepare, MaxEvalInt8);
}

TFLMRegistration Register_MAX_POOL_2D_INT16() {
  return tflite::micro::RegisterOp(Init, MaxPrepare, MaxEvalInt16);
}

TFLMRegistration Register_MAX_POOL_2D() {
  return tflite::micro::RegisterOp(Init, MaxPrepare, MaxEval);
}

}  // namespace tflite