未验证 提交 4b1b9bd4 编写于 作者: T TFLM-bot 提交者: GitHub

Sync from upstream TF. (#291)

上级 c70c70ec
......@@ -21,7 +21,7 @@ limitations under the License.
namespace tflite {
namespace reference_integer_ops {
inline void AveragePool(const PoolParams& params,
inline bool AveragePool(const PoolParams& params,
const RuntimeShape& input_shape,
const int8_t* input_data,
const RuntimeShape& output_shape, int8_t* output_data) {
......@@ -66,6 +66,7 @@ inline void AveragePool(const PoolParams& params,
filter_count++;
}
}
if (filter_count == 0) return false;
// Round to the closest integer value.
acc = acc > 0 ? (acc + filter_count / 2) / filter_count
: (acc - filter_count / 2) / filter_count;
......@@ -77,6 +78,7 @@ inline void AveragePool(const PoolParams& params,
}
}
}
return true;
}
inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
......@@ -136,7 +138,7 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
}
}
inline void AveragePool(const PoolParams& params,
inline bool AveragePool(const PoolParams& params,
const RuntimeShape& input_shape,
const int16_t* input_data,
const RuntimeShape& output_shape,
......@@ -182,6 +184,7 @@ inline void AveragePool(const PoolParams& params,
filter_count++;
}
}
if (filter_count == 0) return false;
// Round to the closest integer value.
acc = acc > 0 ? (acc + filter_count / 2) / filter_count
: (acc - filter_count / 2) / filter_count;
......@@ -193,6 +196,7 @@ inline void AveragePool(const PoolParams& params,
}
}
}
return true;
}
inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
......
......@@ -23,7 +23,7 @@ limitations under the License.
namespace tflite {
namespace reference_ops {
inline void AveragePool(const PoolParams& params,
inline bool AveragePool(const PoolParams& params,
const RuntimeShape& input_shape,
const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
......@@ -66,6 +66,7 @@ inline void AveragePool(const PoolParams& params,
filter_count++;
}
}
if (filter_count == 0) return false;
const float average = total / filter_count;
output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
ActivationFunctionWithMinMax(average, params.float_activation_min,
......@@ -74,9 +75,10 @@ inline void AveragePool(const PoolParams& params,
}
}
}
return true;
}
inline void AveragePool(const PoolParams& params,
inline bool AveragePool(const PoolParams& params,
const RuntimeShape& input_shape,
const uint8_t* input_data,
const RuntimeShape& output_shape,
......@@ -122,6 +124,7 @@ inline void AveragePool(const PoolParams& params,
filter_count++;
}
}
if (filter_count == 0) return false;
acc = (acc + filter_count / 2) / filter_count;
acc = std::max(acc, params.quantized_activation_min);
acc = std::min(acc, params.quantized_activation_max);
......@@ -131,6 +134,7 @@ inline void AveragePool(const PoolParams& params,
}
}
}
return true;
}
inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
......
......@@ -119,6 +119,7 @@ TfLiteStatus GetInputSafe(const TfLiteContext* context, const TfLiteNode* node,
TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
int index) {
TfLiteTensor* tensor = GetMutableInput(context, node, index);
if (tensor == nullptr) return nullptr;
return tensor->is_variable ? tensor : nullptr;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册