Commit 8142bcde authored by Robert Suderman, committed by TensorFlower Gardener

[tosa] Support zero scale in tfl-to-tosa

Some flatbuffers contain a zero scale when a quantized tensor's values are all
zero. That is not a meaningful quantization, but it currently triggers a hard
failure in the pipeline. Instead of rejecting such models, clamp zero scales to
the minimum positive float value.

PiperOrigin-RevId: 564556649
Parent: b1bf0998
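The gist of the change, as a standalone sketch (the helper ClampScales below is hypothetical and for illustration only; the actual patch inlines this loop in GetQuantizedType, as the diff that follows shows):

#include <limits>
#include <vector>

// Illustrative sketch, not the committed code: scales of exactly zero are
// replaced with the smallest positive normalized float (~1.18e-38) instead of
// being rejected, since some flatbuffers store a zero scale for all-zero
// quantized tensors.
std::vector<double> ClampScales(const std::vector<float>& raw_scales) {
  std::vector<double> scales;
  scales.reserve(raw_scales.size());
  for (float scale : raw_scales) {
    scales.push_back(scale == 0.0f ? std::numeric_limits<float>::min()
                                   : static_cast<double>(scale));
  }
  return scales;
}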
@@ -20,6 +20,7 @@ limitations under the License.
 #include <climits>
 #include <cstdint>
 #include <iostream>
+#include <limits>
 #include <memory>
 #include <optional>
 #include <set>
@@ -199,25 +200,26 @@ StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
   uint32_t flags =
       is_signed ? mlir::quant::QuantizationFlags::FlagValue::Signed : 0;
-  // Rejects if quantized tensors have zero scales.
+  // Zero scales we make the minimum fp value, this is because some flatbuffers
+  // contain zero scale for zero values.
+  llvm::SmallVector<double> scales;
   for (float scale : quant_params.scale) {
     if (scale == 0) {
-      return errors::InvalidArgument(
-          "Quantized tensors must have non-zero scales");
+      scales.push_back(std::numeric_limits<float>::min());
+      continue;
     }
+    scales.push_back(scale);
   }
-  // Scale size can't be zero as it is checked before.
   if (quant_params.scale.size() != 1) {
-    llvm::SmallVector<double, 4> scales(quant_params.scale.begin(),
-                                        quant_params.scale.end());
     return mlir::quant::UniformQuantizedPerAxisType::get(
         flags, storage_type, builder.getF32Type(), scales,
         quant_params.zero_point, quant_params.quantized_dimension, storage_min,
         storage_max);
   }
   return mlir::quant::UniformQuantizedType::get(
-      flags, storage_type, builder.getF32Type(), quant_params.scale.at(0),
+      flags, storage_type, builder.getF32Type(), scales[0],
       quant_params.zero_point.at(0), storage_min, storage_max);
 }
@@ -37,11 +37,14 @@ limitations under the License.
 #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
 #include "mlir/IR/Builders.h" // from @llvm-project
 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
 #include "mlir/IR/PatternMatch.h" // from @llvm-project
 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
 #include "mlir/Support/LogicalResult.h" // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
 #include "tensorflow/compiler/mlir/tosa/transforms/passes.h"

 #define PASS_NAME "tosa-lower-complex-types"
@@ -121,8 +124,6 @@ void LowerComplexTypes::runOnOperation() {
   ComplexTypeConverter converter;
   ConversionTarget target(getContext());

-  target.addIllegalOp<mlir::UnrealizedConversionCastOp>();
-
   // Operations are legal if they don't contain any illegal type.
   target.markUnknownOpDynamicallyLegal([](Operation* op) {
     if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
@@ -152,6 +153,12 @@ void LowerComplexTypes::runOnOperation() {
   if (failed(applyFullConversion(func, target, std::move(patterns)))) {
     signalPassFailure();
   }
+
+  // We need to run folders post rewrite to cleanup conversion casts.
+  RewritePatternSet emptyRewriters(ctx);
+  if (failed(applyPatternsAndFoldGreedily(func, std::move(emptyRewriters)))) {
+    signalPassFailure();
+  }
 }

 }  // anonymous namespace
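Why the empty-pattern rewrite above is enough to remove the casts (my reading of the change, not stated in the patch): the greedy driver runs op folders and erases trivially dead ops even when handed zero patterns, and unrealized_conversion_cast folds away cancelling round-trip cast chains, so the temporary casts introduced during conversion disappear. A minimal sketch of the idiom, assuming it runs inside a pass where func is the func::FuncOp being processed and ctx is the pass's MLIRContext*:

// Post-conversion cleanup idiom (sketch; `func` and `ctx` stand in for the
// pass's own values).
mlir::RewritePatternSet emptyRewriters(ctx);
// With no patterns registered, applyPatternsAndFoldGreedily still folds ops
// and removes dead ones, which eliminates round-tripping
// unrealized_conversion_cast pairs.
if (mlir::failed(
        mlir::applyPatternsAndFoldGreedily(func, std::move(emptyRewriters))))
  signalPassFailure();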