未验证 提交 dd3db3d8 编写于 作者: R rsun-bdti 提交者: GitHub

Port TFL kernel Gather to TFL Micro (#97)

* Port TFL kernel Gather to TFL Micro

* Remove TODO comment in micro/kernel/gather.cc
Co-authored-by: NAdvait Jain <advaitjain@users.noreply.github.com>
上级 962c8ada
......@@ -287,6 +287,7 @@ cc_library(
"floor.cc",
"floor_div.cc",
"floor_mod.cc",
"gather.cc",
"gather_nd.cc",
"if.cc",
"l2norm.cc",
......@@ -753,6 +754,21 @@ cc_test(
],
)
cc_test(
name = "gather_test",
srcs = [
"gather_test.cc",
],
deps = [
":kernel_runner",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:micro_utils",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/testing:micro_test",
],
)
cc_test(
name = "gather_nd_test",
srcs = [
......
......@@ -12,26 +12,90 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace gather {
namespace {
constexpr int kInputTensor = 0;
constexpr int kInputPositions = 1;
constexpr int kOutputTensor = 0;
template <typename InputT, typename CoordsT = int32_t>
TfLiteStatus Gather(const TfLiteGatherParams* params,
                    const TfLiteEvalTensor* input,
                    const TfLiteEvalTensor* coords, TfLiteEvalTensor* output) {
  // Gathers slices of `input` along `axis` at the indices given by `coords`,
  // honoring `batch_dims` leading batch dimensions shared by input and coords.
  //
  // params: operator attributes (axis, batch_dims; negative values count
  //         from the end, as resolved below).
  // input:  source tensor with element type InputT.
  // coords: index tensor with element type CoordsT (int32 by default).
  // output: destination tensor; its shape is assumed to already be the
  //         gathered shape (computed in Prepare).
  const InputT* input_data = tflite::micro::GetTensorData<InputT>(input);
  const CoordsT* coords_data = tflite::micro::GetTensorData<CoordsT>(coords);
  InputT* output_data = tflite::micro::GetTensorData<InputT>(output);
  const TfLiteIntArray* input_dims = input->dims;
  const int input_dims_size = input_dims->size;
  int axis = params->axis;
  if (axis < 0) {
    // Negative axis counts from the last dimension of input.
    axis += input_dims_size;
  }
  TFLITE_DCHECK_GE(axis, 0);
  TFLITE_DCHECK_LT(axis, input_dims_size);

  int batch_dims = params->batch_dims;
  // batch_dims should be in range: [-rank(coords), rank(coords)].
  // Negative batch_dims is added with rank of coords.
  const TfLiteIntArray* coords_dims = coords->dims;
  const int coords_dims_size = coords_dims->size;
  if (batch_dims < 0) {
    batch_dims += coords_dims_size;
  }
  TFLITE_DCHECK_GE(batch_dims, 0);
  TFLITE_DCHECK_LT(batch_dims, input_dims_size);
  TFLITE_DCHECK_LE(batch_dims, coords_dims_size);
  TFLITE_DCHECK_GE(axis, batch_dims);
  for (int i = 0; i < batch_dims; ++i) {
    // Each batch dimension must match between input and coords.
    TFLITE_DCHECK_EQ(input_dims->data[i], coords_dims->data[i]);
  }
  const int axis_size = input_dims->data[axis];

  // Collapse the tensors into flat [batch, outer, axis/coord, inner] extents.
  int batch_size = 1;
  for (int i = 0; i < batch_dims; ++i) {
    batch_size *= input_dims->data[i];
  }
  int outer_size = 1;
  for (int i = batch_dims; i < axis; ++i) {
    outer_size *= input_dims->data[i];
  }
  int inner_size = 1;
  for (int i = axis + 1; i < input_dims_size; ++i) {
    inner_size *= input_dims->data[i];
  }
  int coord_size = 1;
  for (int i = batch_dims; i < coords_dims_size; ++i) {
    coord_size *= coords_dims->data[i];
  }

  for (int batch = 0; batch < batch_size; ++batch) {
    for (int outer = 0; outer < outer_size; ++outer) {
      for (int coord = 0; coord < coord_size; ++coord) {
        // Fix: validate the index that is actually used in the copy below.
        // The previous code checked coords_data[coord], i.e. only the first
        // batch's indices, so an out-of-range index in a later batch slipped
        // past the DCHECKs and read out of bounds.
        const CoordsT coord_value = coords_data[batch * coord_size + coord];
        TFLITE_DCHECK_GE(coord_value, 0);
        TFLITE_DCHECK_LT(coord_value, axis_size);
        // Copy one contiguous inner slice per gathered index.
        std::memcpy(output_data +
                        (((batch * outer_size) + outer) * coord_size + coord) *
                            inner_size,
                    input_data + (((batch * outer_size) + outer) * axis_size +
                                  coord_value) *
                                     inner_size,
                    sizeof(InputT) * inner_size);
      }
    }
  }
  return kTfLiteOk;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
......@@ -40,22 +104,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* positions;
const TfLiteTensor* coords;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputPositions, &positions));
GetInputSafe(context, node, kInputPositions, &coords));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (positions->type) {
case kTfLiteInt64:
switch (coords->type) {
case kTfLiteInt32:
break;
default:
TF_LITE_KERNEL_LOG(context,
"Positions of type '%s' are not supported by gather.",
TfLiteTypeGetName(positions->type));
TfLiteTypeGetName(coords->type));
return kTfLiteError;
break;
}
// Assign to output the input type.
......@@ -64,21 +127,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check conditions for different types.
switch (input->type) {
case kTfLiteFloat32:
case kTfLiteUInt8:
case kTfLiteInt8:
case kTfLiteInt16:
case kTfLiteInt64:
case kTfLiteInt32:
case kTfLiteBool:
break;
case kTfLiteString: {
// Only 1D input is supported.
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
} break;
default:
TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
break;
}
int axis = params->axis;
......@@ -87,126 +142,81 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));
const int num_dimensions =
NumDimensions(input) + NumDimensions(positions) - 1;
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
int batch_dims = params->batch_dims;
// batch_dims should be in range: [-rank(coords), rank(coords)].
// Negative batch_dims is added with rank of coords.
if (batch_dims < 0) {
batch_dims += NumDimensions(coords);
}
TF_LITE_ENSURE(context, batch_dims <= axis);
TF_LITE_ENSURE(context, 0 <= batch_dims && batch_dims < NumDimensions(input));
TF_LITE_ENSURE(context, batch_dims <= NumDimensions(coords));
for (int i = 0; i < batch_dims; ++i) {
TF_LITE_ENSURE_EQ(context, input->dims->data[i], coords->dims->data[i]);
}
// GATHER updates the output tensor dimensions, but TfLiteTensor in the
// MicroInterpreter is a temporary allocation. We must therefore relocate the
// dims from the FlatBuffer to the persistant storage arena.
TfLiteEvalTensor* output_eval =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_OK(context, tflite::micro::CreateWritableTensorDimsWithCopy(
context, output, output_eval));
TfLiteIntArray* output_shape = output->dims;
output_shape->size =
NumDimensions(input) + NumDimensions(coords) - 1 - batch_dims;
int output_index = 0;
for (int i = 0; i < axis; ++i) {
output_shape->data[output_index++] = input->dims->data[i];
}
for (int i = 0; i < positions->dims->size; ++i) {
output_shape->data[output_index++] = positions->dims->data[i];
for (int i = batch_dims; i < coords->dims->size; ++i) {
output_shape->data[output_index++] = coords->dims->data[i];
}
for (int i = axis + 1; i < input->dims->size; ++i) {
output_shape->data[output_index++] = input->dims->data[i];
}
return context->ResizeTensor(context, output, output_shape);
}
template <typename InputT, typename PositionsT>
TfLiteStatus Gather(const TfLiteGatherParams& params, const TfLiteTensor* input,
const TfLiteTensor* positions, TfLiteTensor* output) {
tflite::GatherParams op_params;
op_params.axis = params.axis;
optimized_ops::Gather(op_params, GetTensorShape(input),
GetTensorData<InputT>(input), GetTensorShape(positions),
GetTensorData<PositionsT>(positions),
GetTensorShape(output), GetTensorData<InputT>(output));
return kTfLiteOk;
}
template <typename PositionT>
TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input,
const TfLiteTensor* positions,
TfLiteTensor* output) {
DynamicBuffer buffer;
const PositionT* indexes = GetTensorData<PositionT>(positions);
const PositionT num_strings = GetStringCount(input);
const int num_indexes = NumElements(positions);
for (int i = 0; i < num_indexes; ++i) {
const PositionT pos = indexes[i];
TF_LITE_ENSURE(context, pos < num_strings);
const auto string_ref = GetString(input, pos);
buffer.AddString(string_ref.str, string_ref.len);
}
buffer.WriteToTensor(output, /*new_shape=*/nullptr);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // Invoke-time entry point: dispatches to the typed Gather() helper based on
  // the input tensor's element type. Prepare() has already restricted coords
  // to int32 and input to the types handled here, so unsupported combinations
  // are reported again defensively.
  //
  // NOTE(review): this span in SOURCE interleaved the removed TFL (tensor-
  // based, string/int64-capable) implementation with the new TFLM one; this
  // body is the coherent micro version (int32 coords; float32/int8 inputs).
  const auto* params =
      reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  const TfLiteEvalTensor* coords =
      tflite::micro::GetEvalInput(context, node, kInputPositions);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  if (coords->type == kTfLiteInt32) {
    switch (input->type) {
      case kTfLiteFloat32:
        return Gather<float, int32_t>(params, input, coords, output);
      case kTfLiteInt8:
        return Gather<int8_t, int32_t>(params, input, coords, output);
      default:
        TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
                           TfLiteTypeGetName(input->type));
        return kTfLiteError;
    }
  }
  return kTfLiteOk;
}
} // namespace gather
TfLiteRegistration* Register_GATHER() {
static TfLiteRegistration r = {nullptr, nullptr, gather::Prepare,
gather::Eval};
return &r;
} // namespace
TfLiteRegistration Register_GATHER() {
  // Build the registration explicitly: zero-initialize every hook, then wire
  // up only the two callbacks this kernel implements. All other fields
  // (init, free, profiling_string, builtin_code, custom_name, version) stay
  // null/zero, matching the previous aggregate initializer.
  TfLiteRegistration registration = {};
  registration.prepare = Prepare;
  registration.invoke = Eval;
  return registration;
}
} // namespace builtin
} // namespace ops
} // namespace tflite
此差异已折叠。
......@@ -45,6 +45,7 @@ TfLiteRegistration Register_EXPAND_DIMS();
TfLiteRegistration Register_FILL();
TfLiteRegistration Register_FLOOR_DIV();
TfLiteRegistration Register_FLOOR_MOD();
TfLiteRegistration Register_GATHER();
TfLiteRegistration Register_GATHER_ND();
TfLiteRegistration Register_IF();
TfLiteRegistration Register_L2_POOL_2D();
......
......@@ -255,6 +255,11 @@ class MicroMutableOpResolver : public MicroOpResolver {
ParseFullyConnected);
}
TfLiteStatus AddGather() {
return AddBuiltin(BuiltinOperator_GATHER, tflite::Register_GATHER(),
ParseGather);
}
TfLiteStatus AddGatherNd() {
return AddBuiltin(BuiltinOperator_GATHER_ND, tflite::Register_GATHER_ND(),
ParseGatherNd);
......
......@@ -282,6 +282,7 @@ tensorflow/lite/micro/kernels/floor_test.cc \
tensorflow/lite/micro/kernels/floor_div_test.cc \
tensorflow/lite/micro/kernels/floor_mod_test.cc \
tensorflow/lite/micro/kernels/fully_connected_test.cc \
tensorflow/lite/micro/kernels/gather_test.cc \
tensorflow/lite/micro/kernels/gather_nd_test.cc \
tensorflow/lite/micro/kernels/hard_swish_test.cc \
tensorflow/lite/micro/kernels/l2norm_test.cc \
......@@ -351,6 +352,7 @@ tensorflow/lite/micro/kernels/floor_div.cc \
tensorflow/lite/micro/kernels/floor_mod.cc \
tensorflow/lite/micro/kernels/fully_connected.cc \
tensorflow/lite/micro/kernels/fully_connected_common.cc \
tensorflow/lite/micro/kernels/gather.cc \
tensorflow/lite/micro/kernels/gather_nd.cc \
tensorflow/lite/micro/kernels/hard_swish.cc \
tensorflow/lite/micro/kernels/kernel_runner.cc \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册