提交 3187d364 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Move ConvertTFQuantOpsToMHLO pass definition to passes.td to be consistent with other passes

PiperOrigin-RevId: 549654509
上级 30e983f6
......@@ -58,6 +58,9 @@ namespace mlir {
namespace stablehlo {
namespace {
#define GEN_PASS_DEF_CONVERTTFQUANTOPSTOMHLO
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h.inc"
FailureOr<IntegerType> GetStorageType(Operation *op,
Type original_output_element_type,
PatternRewriter &rewriter) {
......@@ -638,37 +641,17 @@ class ConvertUniformQuantizedClipByValueOp
}
};
class ConvertTFQuantOpsToMHLOPass
: public PassWrapper<ConvertTFQuantOpsToMHLOPass,
OperationPass<func::FuncOp>> {
class ConvertTFQuantOpsToMHLO
: public impl::ConvertTFQuantOpsToMHLOBase<ConvertTFQuantOpsToMHLO> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertTFQuantOpsToMHLOPass)
StringRef getArgument() const final {
// This is the argument used to refer to the pass in
// the textual format (on the commandline for example).
return "quant-convert-tf-quant-ops-to-mhlo";
}
StringRef getDescription() const final {
// This is a brief description of the pass.
return "Convert TF Quant ops to MHLO quantization";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TF::TensorFlowDialect>();
registry.insert<mhlo::MhloDialect>();
registry.insert<chlo::ChloDialect>();
registry.insert<tf_type::TFTypeDialect>();
registry.insert<quant::QuantizationDialect>();
}
ConvertTFQuantOpsToMHLO() = default;
ConvertTFQuantOpsToMHLO(const ConvertTFQuantOpsToMHLO &) = default;
// Performs conversion of MHLO quant ops to primitive ops.
void runOnOperation() override;
};
static PassRegistration<ConvertTFQuantOpsToMHLOPass> pass;
void ConvertTFQuantOpsToMHLOPass::runOnOperation() {
void ConvertTFQuantOpsToMHLO::runOnOperation() {
MLIRContext *ctx = &getContext();
func::FuncOp func = getOperation();
ConversionTarget target(*ctx);
......@@ -703,7 +686,7 @@ void PopulateLegalizeTfQuantizationPatterns(MLIRContext *context,
std::unique_ptr<OperationPass<func::FuncOp>>
CreateConvertTFQuantOpsToMHLOPass() {
return std::make_unique<ConvertTFQuantOpsToMHLOPass>();
return std::make_unique<ConvertTFQuantOpsToMHLO>();
}
} // namespace stablehlo
......
......@@ -30,10 +30,6 @@ namespace stablehlo {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertMHLOQuantToIntPass(
bool legalize_chlo = true);
#define GEN_PASS_REGISTRATION
#define GEN_PASS_DECL_CONVERTMHLOQUANTTOINT
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h.inc"
// Creates an instance of the ConvertTFQuantOpsToMHLOPass pass, which will
// convert TF uniform quantized ops to the corresponding quantized MHLO ops.
std::unique_ptr<OperationPass<func::FuncOp>>
......@@ -44,6 +40,10 @@ CreateConvertTFQuantOpsToMHLOPass();
void PopulateLegalizeTfQuantizationPatterns(MLIRContext *context,
RewritePatternSet *patterns);
#define GEN_PASS_REGISTRATION
#define GEN_PASS_DECL_CONVERTMHLOQUANTTOINT
#define GEN_PASS_DECL_CONVERTTFQUANTOPSTOMHLO
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h.inc"
} // namespace stablehlo
} // namespace mlir
......
......@@ -35,3 +35,16 @@ def ConvertMHLOQuantToInt : Pass<"convert-mhlo-quant-to-int", "mlir::func::FuncO
"shape::ShapeDialect",
"func::FuncDialect"];
}
def ConvertTFQuantOpsToMHLO : Pass<"quant-convert-tf-quant-ops-to-mhlo", "mlir::func::FuncOp"> {
let summary = "Convert TF Quant ops to MHLO quantizated ops.";
let description = [{
Convert TF Quant ops to MHLO quant ops.
}];
let constructor = "mlir::stablehlo::CreateConvertTFQuantOpsToMHLOPass()";
let dependentDialects = ["TF::TensorFlowDialect", "chlo::ChloDialect",
"mhlo::MhloDialect", "tf_type::TFTypeDialect",
"quant::QuantizationDialect"];
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册