未验证 提交 a4435d23 编写于 作者: T TFLM-bot 提交者: GitHub

Sync from upstream TF. (#238)

上级 d25c30df
......@@ -51,7 +51,7 @@ inline void Mul(const ArithmeticParams& params,
GetActivationParams(params, &output_activation_min, &output_activation_max);
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingExtendedShapeFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] * input2_data[i], output_activation_min,
......@@ -66,7 +66,7 @@ inline void Mul(const ArithmeticParams& params,
TFLITE_DCHECK_LE(params.quantized_activation_min,
params.quantized_activation_max);
const int flat_size =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
MatchingExtendedShapeFlatSize(input1_shape, input2_shape, output_shape);
MulElementwise(flat_size, params, input1_data, input2_data, output_data);
}
......
......@@ -602,6 +602,58 @@ inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
}
// Flat size calculation, checking if their extended shapes match.
inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
const RuntimeShape& check_shape_0) {
const int shape_dims = shape.DimensionsCount();
const int check_shape_0_dims = check_shape_0.DimensionsCount();
const int min_dims = std::min(shape_dims, check_shape_0_dims);
for (int i = 0; i < min_dims; ++i) {
TFLITE_DCHECK_EQ(shape.Dims(shape_dims - 1 - i),
check_shape_0.Dims(check_shape_0_dims - 1 - i));
}
for (int i = min_dims; i < shape_dims; ++i) {
TFLITE_DCHECK_EQ(shape.Dims(shape_dims - 1 - i), 1);
}
for (int i = min_dims; i < check_shape_0_dims; ++i) {
TFLITE_DCHECK_EQ(check_shape_0.Dims(check_shape_0_dims - 1 - i), 1);
}
return shape.FlatSize();
}
inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
const RuntimeShape& check_shape_0,
const RuntimeShape& check_shape_1) {
const int flat_size = MatchingExtendedShapeFlatSize(shape, check_shape_0);
TFLITE_DCHECK_EQ(MatchingExtendedShapeFlatSize(shape, check_shape_1),
flat_size);
return flat_size;
}
inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
const RuntimeShape& check_shape_0,
const RuntimeShape& check_shape_1,
const RuntimeShape& check_shape_2) {
const int flat_size = MatchingExtendedShapeFlatSize(shape, check_shape_0);
TFLITE_DCHECK_EQ(
MatchingExtendedShapeFlatSize(shape, check_shape_1, check_shape_2),
flat_size);
return flat_size;
}
inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
const RuntimeShape& check_shape_0,
const RuntimeShape& check_shape_1,
const RuntimeShape& check_shape_2,
const RuntimeShape& check_shape_3) {
const int flat_size = MatchingExtendedShapeFlatSize(shape, check_shape_0);
TFLITE_DCHECK_EQ(MatchingExtendedShapeFlatSize(shape, check_shape_1,
check_shape_2, check_shape_3),
flat_size);
return flat_size;
}
// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册