Commit 9d8bc6b2 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Multi-thread the quantized reduce_mean using gemmlowp tasks, with the inner loop optimized using NEON.

PiperOrigin-RevId: 235485786
Parent e9a09aaf
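The change parallelizes the quantized Mean across the output depth (channel) dimension: the new Mean entry point splits [0, output_depth) into contiguous slices, wraps each slice in a gemmlowp::Task, and MeanImpl vectorizes the per-slice accumulation with NEON. As a rough standalone illustration of the partitioning scheme only (using std::thread rather than gemmlowp's worker pool; the helper name and parameters are illustrative, not part of the commit):

#include <algorithm>
#include <thread>
#include <vector>

// Hypothetical sketch: run work(depth_start, depth_end) over [0, output_depth)
// on up to max_threads threads, giving each thread a roughly equal,
// contiguous slice of channels.
template <typename Work>
void ParallelOverDepth(int output_depth, int max_threads, Work work) {
  constexpr int kMinDepthPerThread = 8;  // Matches the 8-wide NEON block.
  int thread_count = std::max(1, output_depth / kMinDepthPerThread);
  thread_count = std::min(thread_count, max_threads);
  std::vector<std::thread> threads;
  int depth_start = 0;
  for (int i = 0; i < thread_count; ++i) {
    // Hand out the remaining channels as evenly as possible.
    const int depth_end =
        depth_start + (output_depth - depth_start) / (thread_count - i);
    threads.emplace_back(work, depth_start, depth_end);
    depth_start = depth_end;
  }
  for (auto& t : threads) t.join();
}

Each slice spans (remaining channels) / (workers left), so slice sizes differ by at most one channel; the same formula is used when building the MeanWorkerTask list in the diff below.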
......@@ -20,6 +20,7 @@ limitations under the License.
#include <sys/types.h>
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <tuple>
......@@ -76,7 +77,6 @@ using reference_ops::Less;
using reference_ops::LessEqual;
using reference_ops::LessEqualWithScaling;
using reference_ops::LessWithScaling;
using reference_ops::Mean;
using reference_ops::ProcessBroadcastShapes;
using reference_ops::RankOneSelect;
using reference_ops::Relu1;
......@@ -1743,6 +1743,221 @@ inline void ShuffledFullyConnected(
gemm_context->workers_pool()->Execute(tasks);
}
inline void MeanImpl(const tflite::MeanParams& op_params,
const RuntimeShape& input_shape, const uint8_t* input_data,
int32 input_zero_point, float input_scale,
const RuntimeShape& output_shape, uint8_t* output_data,
int32 output_zero_point, float output_scale,
int start_depth, int end_depth) {
gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8/MeanImpl");
// The current implementation only supports 4D tensors and simultaneous
// reduction over width and height.
const int output_batch = output_shape.Dims(0);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const float num_elements_in_axis = input_width * input_height;
TFLITE_DCHECK_EQ(op_params.axis_count, 2);
TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
(op_params.axis[0] == 2 && op_params.axis[1] == 1));
TFLITE_DCHECK_EQ(output_height, 1);
TFLITE_DCHECK_EQ(output_width, 1);
const bool ordinary_mean =
(input_zero_point == output_zero_point && input_scale == output_scale);
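// When the input and output quantization parameters differ, the mean is
// requantized with scale = input_scale / output_scale and a bias that folds in
// the input zero point plus a rounding offset.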
float scale, bias;
if (!ordinary_mean) {
scale = input_scale / output_scale;
bias = -input_zero_point * scale + 0.5;
}
#ifdef USE_NEON
const float32x4_t num_elements_dup = vdupq_n_f32(num_elements_in_axis);
// This is only an approximation, as NEON does not provide a division
// instruction; use the reciprocal estimate instead.
const float32x4_t num_elements_reverse = vrecpeq_f32(num_elements_dup);
const float32x4_t kRounding = vdupq_n_f32(0.5);
float32x4_t bias_dup;
float32x4_t output_zero_point_dup;
if (!ordinary_mean) {
bias_dup = vdupq_n_f32(bias);
output_zero_point_dup = vdupq_n_f32(output_zero_point);
}
#endif
for (int out_b = 0; out_b < output_batch; ++out_b) {
int out_d = start_depth;
#ifdef USE_NEON
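// NEON path: process 8 adjacent output channels per iteration by widening the
// uint8 lanes to float32 and accumulating over the height and width axes.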
for (; out_d < end_depth - 8; out_d += 8) {
float32x4_t temp_sum_1 = vdupq_n_f32(0);
float32x4_t temp_sum_2 = vdupq_n_f32(0);
for (int in_h = 0; in_h < input_height; ++in_h) {
for (int in_w = 0; in_w < input_width; ++in_w) {
const uint8_t* input_data_ptr =
input_data + Offset(input_shape, out_b, in_h, in_w, out_d);
uint8x8_t input_data_val = vld1_u8(input_data_ptr);
int16x8_t input_data_val_shift =
vreinterpretq_s16_u16(vmovl_u8(input_data_val));
float32x4_t input_float_1 =
vcvtq_f32_s32(vmovl_s16(vget_high_s16(input_data_val_shift)));
float32x4_t input_float_2 =
vcvtq_f32_s32(vmovl_s16(vget_low_s16(input_data_val_shift)));
temp_sum_1 = vaddq_f32(temp_sum_1, input_float_1);
temp_sum_2 = vaddq_f32(temp_sum_2, input_float_2);
}
}
float32x4_t mean_1 = vmulq_f32(temp_sum_1, num_elements_reverse);
float32x4_t mean_2 = vmulq_f32(temp_sum_2, num_elements_reverse);
if (!ordinary_mean) {
// A fused multiply-add is not used here; break the rescale down into a
// multiply and an add.
mean_1 = vmulq_n_f32(mean_1, scale);
mean_1 = vaddq_f32(mean_1, bias_dup);
mean_2 = vmulq_n_f32(mean_2, scale);
mean_2 = vaddq_f32(mean_2, bias_dup);
mean_1 = vaddq_f32(mean_1, output_zero_point_dup);
mean_2 = vaddq_f32(mean_2, output_zero_point_dup);
}
// Rounding.
mean_1 = vaddq_f32(mean_1, kRounding);
mean_2 = vaddq_f32(mean_2, kRounding);
uint32x4_t casted_mean_1 = vcvtq_u32_f32(mean_1);
uint16x4_t narrow_range_mean_1 = vmovn_u32(casted_mean_1);
uint32x4_t casted_mean_2 = vcvtq_u32_f32(mean_2);
uint16x4_t narrow_range_mean_2 = vmovn_u32(casted_mean_2);
uint16x8_t combined_mean =
vcombine_u16(narrow_range_mean_2, narrow_range_mean_1);
uint8x8_t narrowed_combined_mean = vmovn_u16(combined_mean);
uint8_t* output_data_ptr =
output_data + Offset(output_shape, out_b, 0, 0, out_d);
vst1_u8(output_data_ptr, narrowed_combined_mean);
}
#endif
for (; out_d < end_depth; ++out_d) {
float temp_value = 0;
for (int in_h = 0; in_h < input_height; ++in_h) {
for (int in_w = 0; in_w < input_width; ++in_w) {
temp_value +=
input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
}
}
temp_value = temp_value / num_elements_in_axis;
if (ordinary_mean) {
output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
static_cast<uint8_t>(round(temp_value));
} else {
output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
static_cast<uint8_t>(round(temp_value * scale + bias)) +
output_zero_point;
}
}
}
}
struct MeanWorkerTask : public gemmlowp::Task {
MeanWorkerTask(const tflite::MeanParams& op_params,
const RuntimeShape& input_shape, const uint8_t* input_data,
int32 input_zero_point, float input_scale,
const RuntimeShape& output_shape, uint8_t* output_data,
int32 output_zero_point, float output_scale, int start_height,
int end_height)
: op_params_(op_params),
input_shape_(input_shape),
input_data_(input_data),
input_zero_point_(input_zero_point),
input_scale_(input_scale),
output_shape_(output_shape),
output_data_(output_data),
output_zero_point_(output_zero_point),
output_scale_(output_scale),
start_height_(start_height),
end_height_(end_height) {}
void Run() override {
MeanImpl(op_params_, input_shape_, input_data_, input_zero_point_,
input_scale_, output_shape_, output_data_, output_zero_point_,
output_scale_, start_height_, end_height_);
}
private:
const tflite::MeanParams& op_params_;
const RuntimeShape& input_shape_;
const uint8_t* input_data_;
int32 input_zero_point_;
float input_scale_;
const RuntimeShape& output_shape_;
uint8_t* output_data_;
int32 output_zero_point_;
float output_scale_;
int start_height_;
int end_height_;
};
inline void Mean(const tflite::MeanParams& op_params,
const RuntimeShape& unextended_input_shape,
const uint8_t* input_data, int32 input_zero_point,
float input_scale, const RuntimeShape& unextended_output_shape,
uint8_t* output_data, int32 output_zero_point,
float output_scale, gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8");
// The current implementation only supports 4D tensors and simultaneous
// reduction over width and height.
TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4);
TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int output_depth = output_shape.Dims(3);
TFLITE_DCHECK_EQ(op_params.axis_count, 2);
TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
(op_params.axis[0] == 2 && op_params.axis[1] == 1));
TFLITE_DCHECK_EQ(output_height, 1);
TFLITE_DCHECK_EQ(output_width, 1);
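// Give each worker at least 8 output channels so that the NEON path, which
// processes 8 channels per iteration, has enough work per thread.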
constexpr int kMinDepthPerThread = 8;
int thread_count = output_depth / kMinDepthPerThread;
thread_count = thread_count > 0 ? thread_count : 1;
const int capped_thread_count =
std::min(thread_count, gemm_context->max_num_threads());
if (thread_count == 1) {
MeanImpl(op_params, input_shape, input_data, input_zero_point, input_scale,
output_shape, output_data, output_zero_point, output_scale, 0,
output_depth);
} else {
// Instead of parallelizing over the batch dimension, distribute the work over
// output_depth, since the batch size is typically 1.
std::vector<gemmlowp::Task*> tasks(capped_thread_count);
int depth_start = 0;
for (int i = 0; i < capped_thread_count; ++i) {
// Try to distribute the tasks as evenly as possible.
int depth_end = depth_start +
                (output_depth - depth_start) / (capped_thread_count - i);
tasks[i] = new MeanWorkerTask(op_params, input_shape, input_data,
input_zero_point, input_scale, output_shape,
output_data, output_zero_point,
output_scale, depth_start, depth_end);
depth_start = depth_end;
}
gemm_context->workers_pool()->Execute(tasks);
}
}
template <typename T>
inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
int h, int b, int kheight, int kwidth,
......
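A side note on the arithmetic in MeanImpl above: the averaged value is still in the input's quantized units, so when the quantization parameters differ it is rescaled by input_scale / output_scale and shifted between the two zero points. A minimal, hypothetical helper sketching that computation for a single channel (not part of the commit; explicit clamping is added here for illustration):

#include <cmath>
#include <cstdint>

// Hypothetical sketch of the per-channel requantized mean: average `count`
// uint8 values and requantize from the input scale/zero point to the output
// scale/zero point.
inline uint8_t QuantizedChannelMean(const uint8_t* values, int count,
                                    int32_t input_zero_point, float input_scale,
                                    int32_t output_zero_point,
                                    float output_scale) {
  float sum = 0.0f;
  for (int i = 0; i < count; ++i) sum += values[i];
  const float mean_q = sum / count;  // Mean, still in input quantized units.
  // real_mean = (mean_q - input_zero_point) * input_scale
  // q_out     = real_mean / output_scale + output_zero_point
  const float scale = input_scale / output_scale;
  const float q_out = (mean_q - input_zero_point) * scale + output_zero_point;
  const float clamped = std::fmax(0.0f, std::fmin(255.0f, std::round(q_out)));
  return static_cast<uint8_t>(clamped);
}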
......@@ -17,6 +17,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/gemm_support.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
......@@ -49,6 +51,7 @@ struct OpContext {
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
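// The shared gemmlowp context is reference counted per TfLiteContext; this
// increment is paired with the DecrementUsageCounter call in Free below.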
gemm_support::IncrementUsageCounter(context);
// Creates two temp tensors to store index and axis for internal
// implementation only.
auto* scratch_tensor_index = new int;
......@@ -57,6 +60,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
}
void Free(TfLiteContext* context, void* buffer) {
gemm_support::DecrementUsageCounter(context);
delete reinterpret_cast<int*>(buffer);
}
......@@ -248,6 +252,7 @@ void ResolveAxis(const int* axis_data, int axis_count,
template <KernelType kernel_type>
TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node);
int num_axis = static_cast<int>(NumElements(op_context.axis));
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
......@@ -272,13 +277,15 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
(op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
if (op_context.input->type == kTfLiteUInt8) {
reference_ops::Mean(
gemmlowp::GemmContext* gemm_context =
gemm_support::GetFromContext(context);
optimized_ops::Mean(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
op_context.input->params.zero_point, op_context.input->params.scale,
GetTensorShape(op_context.output),
GetTensorData<uint8_t>(op_context.output),
op_context.output->params.zero_point,
op_context.output->params.scale);
op_context.output->params.scale, gemm_context);
} else {
reference_ops::Mean(op_params, GetTensorShape(input),
GetTensorData<float>(input),
......
......@@ -259,7 +259,7 @@ TEST(ConstFloatMeanOpTest, KeepDims) {
// Uses a set of reduction conditions that trigger the specialized 4D version
// of Mean.
TEST(ConstFloatMeanOpTest, KeepDims_4DMean) {
TEST(ConstFloatMeanOpTest, KeepDims4DMean) {
std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
......@@ -272,7 +272,7 @@ TEST(ConstFloatMeanOpTest, KeepDims_4DMean) {
ElementsAreArray(ArrayFloatNear({6, 7, 18, 19})));
}
TEST(ConstFloatMeanOpTest, KeepDims_4DMean_UInt8) {
TEST(ConstFloatMeanOpTest, KeepDims4DMeanUInt8) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
std::vector<float> data = {0.1, 0.2, 0.3, 0.4, 0.1, 0.2,
0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
......@@ -286,7 +286,24 @@ TEST(ConstFloatMeanOpTest, KeepDims_4DMean_UInt8) {
kQuantizedTolerance)));
}
TEST(ConstFloatMeanOpTest, KeepDims_4DMean_Quantized) {
TEST(ConstFloatMeanOpTest, KeepDims4DMeanLargeDepthUInt8) {
float kQuantizedTolerance = GetTolerance(-5.0, 5.0);
std::vector<float> data = {0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1,
0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9,
0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3,
0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
MeanOpConstModel m({TensorType_UINT8, {1, 2, 2, 9}, -1.0, 1.0},
{TensorType_UINT8, {2}, -1.0, 1.0}, {2}, {1, 2}, true);
m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 9}));
EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(ArrayFloatNear(
{0.35, 0.325, 0.2, 0.35, 0.375, 0.325, 0.225, 0.45, 0.425},
kQuantizedTolerance)));
}
TEST(ConstFloatMeanOpTest, KeepDims4DMeanQuantized) {
float kQuantizedTolerance = GetTolerance(-5.0, 5.0);
std::vector<float> data = {0.1, 0.2, 0.3, 0.4, 0.1, 0.2,
0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
......