Commit f5ee7354 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Updated for changes in LLVM 7a45aeacf3a2.

- The deprecated CreateCall(Value*, ...) IRBuilder API has been removed (a call-site sketch follows the IrBuilderMixin hunk below).
- Renamed applyPatternsGreedily to applyPatternsAndFoldGreedily in MLIR (a migration sketch follows this header).
- Updated MLIR users after support for optional operands/results was added to ODS (upstream aba1acc89c653b2cc08cccfb754ff16994a05332); a paraphrase of the new accessor follows the first hunk below.
- Other updates to BUILD files for upstream changes.

PiperOrigin-RevId: 306177884
Change-Id: Idae1009ba89caf296758748ab7aa57815d946a0c
Parent 700ff489
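Before the hunks, a minimal sketch of the MLIR rename applied throughout this commit. The pass and helper names below are illustrative, not taken from this commit, and the include set is an assumption about MLIR headers at this revision; the driver's behavior is unchanged, the new name just makes the built-in folding step explicit.

#include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"

// Hypothetical pass body (not from this commit) showing the mechanical
// migration: the driver still greedily applies the patterns and folds
// ops as it goes; only the entry point's name changed.
void RunMyPatterns(mlir::FuncOp func) {
  mlir::OwningRewritePatternList patterns;
  // ... populate `patterns` here ...
  // Before this commit: applyPatternsGreedily(func, patterns);
  mlir::applyPatternsAndFoldGreedily(func, patterns);
}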
@@ -496,7 +496,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
     auto &value = op.getOperand(i);
     // Skip from first variadic operands for now. Else getOperand index
     // used below doesn't match.
-    if (value.isVariadic()) break;
+    if (value.isVariableLength()) break;
     if (!value.name.empty())
       verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
   }
@@ -504,7 +504,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
     auto &value = op.getResult(i);
     // Skip from first variadic results for now. Else getResult index
     // used below doesn't match.
-    if (value.isVariadic()) break;
+    if (value.isVariableLength()) break;
     if (!value.name.empty())
       verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
   }
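The isVariadic() to isVariableLength() switches above and in the emitter hunks below track the upstream ODS change (aba1acc89c65) that added optional operands and results. A paraphrased sketch of the distinction, assuming the upstream mlir::tblgen::NamedTypeConstraint accessors rather than quoting them:

// Paraphrase (assumption, not copied from upstream): a named operand or
// result is "variable length" when it can bind zero-or-more values
// (Variadic) or zero-or-one value (Optional). Either case breaks the
// one-value-per-index assumption behind getOperand(i)/getResult(i)
// above, which is why the emitters now bail at the first such value.
bool NamedTypeConstraint::isVariableLength() const {
  return isOptional() || isVariadic();
}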
@@ -146,7 +146,7 @@ void LegalizeTFToQuant::runOnFunction() {
   auto func = getFunction();
   auto *ctx = func.getContext();
   patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
 }
 }  // namespace
@@ -30,7 +30,7 @@ void IdentifyDilatedConvPass::runOnFunction() {
   patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>,
                   ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(
       &getContext());
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
 }
 }  // namespace
@@ -711,7 +711,7 @@ void Optimize::runOnFunction() {
   TFL::populateWithGenerated(ctx, &patterns);
   patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
                   FuseFullyConnectedAndMul>(ctx);
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
   // Fuse the binary ops with the following ops.
   patterns.insert<
@@ -719,7 +719,7 @@ void Optimize::runOnFunction() {
       FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs,
       FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp>(
       ctx);
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
 }
 }  // namespace
@@ -187,7 +187,7 @@ void OptimizeFunctionalOpsPass::runOnOperation() {
   patterns.insert<FoldIfOp>(&getContext(), &inlined_funcs);
   ModuleOp module = getOperation();
-  applyPatternsGreedily(module, patterns);
+  applyPatternsAndFoldGreedily(module, patterns);
   // Erase inlined functions that don't have any references.
   //
@@ -125,7 +125,7 @@ void PostQuantizePass::runOnFunction() {
   auto func = getFunction();
   auto* ctx = func.getContext();
   TFL::populateWithGenerated(ctx, &patterns);
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
   if (!emit_quant_adaptor_ops_) {
     RemoveQuantizationAdaptorOps(getFunction());
@@ -267,7 +267,7 @@ void PrepareQuantizePass::runOnFunction() {
     // Currently, only activation stats are imported, so narrow_range = false.
     patterns.insert<PrepareQuantStats>(8, false, false, ctx);
   }
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
   SanityCheckAndAdjustment(func);
@@ -619,8 +619,8 @@ void PrepareTFPass::runOnFunction() {
   // This pattern was intended to use TFL QDQs to preserve the quantization
   // parameters from the TF Quant ops, thus this pattern should run with the
-  // first `applyPatternsGreedily` method, which would otherwise remove the
-  // TF FakeQuant ops by constant folding.
+  // first `applyPatternsAndFoldGreedily` method, which would otherwise remove
+  // the TF FakeQuant ops by constant folding.
   patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
   // This pattern will try to identify and optimize for dilated convolution.
@@ -634,7 +634,7 @@ void PrepareTFPass::runOnFunction() {
   // This will allow optimizing any TF_Mul->TF_Conv in the graph
   // and any expanded from FusedBatchNorm. We need to do this
   // before converting TF_Conv to TFL_Conv
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
   // Load the generated pattern again, so new quantization pass-through
   // will be applied.
@@ -646,7 +646,7 @@ void PrepareTFPass::runOnFunction() {
   }
   patterns.insert<TF::ConvertTFEinsumOp, ConvertTFConv2D,
                   ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice>(ctx);
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
 }
 }  // namespace
@@ -88,7 +88,7 @@ void QuantizePass::runOnFunction() {
   TFL::populateWithGenerated(ctx, &patterns);
   patterns.insert<TFLFullQuantization>(
       ctx, enable_numeric_verify, error_tolerance, enable_single_layer_verify);
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
 }
 }  // namespace
@@ -55,7 +55,7 @@ void BatchMatMulToEinsumPass::runOnFunction() {
   patterns.insert<ConvertTFBatchMatMulToEinsumOp<TF::BatchMatMulOp>,
                   ConvertTFBatchMatMulToEinsumOp<TF::BatchMatMulV2Op>>(
       &getContext());
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
 }
 }  // namespace
@@ -45,7 +45,7 @@ struct DecomposeResourceOps
     OwningRewritePatternList patterns;
     mlir::TF::PopulateDecomposeResourceOpsPatterns(&getContext(), &patterns);
-    applyPatternsGreedily(getFunction(), patterns);
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
   }
 };
@@ -364,7 +364,7 @@ void TransformEinsumPass::runOnFunction() {
   auto func = getFunction();
   patterns.insert<ConvertTFEinsumOp>(&getContext());
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
 }
 static PassRegistration<TransformEinsumPass> pass(
@@ -118,7 +118,7 @@ void GpuOpFusionPass::runOnFunction() {
   FuncOp func = getFunction();
   OwningRewritePatternList patterns;
   patterns.insert<ReluToFusedBatchNorm>(&getContext());
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
 }
 }  // namespace
@@ -29,7 +29,7 @@ struct LowerTF : public PassWrapper<LowerTF, FunctionPass> {
     OwningRewritePatternList patterns;
     mlir::TF::PopulateLoweringTFPatterns(&getContext(), &patterns);
-    applyPatternsGreedily(getFunction(), patterns);
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
   }
 };
@@ -38,7 +38,7 @@ struct TFOptimizePass : public PassWrapper<TFOptimizePass, FunctionPass> {
    OwningRewritePatternList patterns;
    auto func = getFunction();
    populateWithGenerated(&getContext(), &patterns);
-   applyPatternsGreedily(func, patterns);
+   applyPatternsAndFoldGreedily(func, patterns);
  }
};
@@ -55,7 +55,7 @@ void UnrollBatchMatMulPass::runOnFunction() {
   patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
                   ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
 }
 }  // namespace
@@ -81,7 +81,7 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
   // Emit an argument for an operand.
   if (auto* operand_cst = arg.dyn_cast<NamedTypeConstraint*>()) {
     // Handle a non-variadic operand.
-    if (!operand_cst->isVariadic()) {
+    if (!operand_cst->isVariableLength()) {
      os << "  auto xla_arg_" << index
         << " = value_map[*xla_op.getODSOperands(" << operand_number++
         << ").begin()];\n";
@@ -108,7 +108,7 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
   // If all operands are variadic, then pass the builder explicitly to the
   // XLA client API call.
-  if (op.getNumOperands() == op.getNumVariadicOperands()) {
+  if (op.getNumOperands() == op.getNumVariableLengthOperands()) {
     os << "lowering_context.builder";
     if (op.getNumArgs() != 0) os << ", ";
   }
@@ -198,7 +198,7 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
 void LegalizeToStandard::runOnFunction() {
   OwningRewritePatternList patterns;
   mlir::xla_hlo::PopulateXlaToStdPatterns(&patterns, &getContext());
-  applyPatternsGreedily(getFunction(), patterns);
+  applyPatternsAndFoldGreedily(getFunction(), patterns);
 }
 static PassRegistration<LegalizeToStandard> legalize_pass(
@@ -87,7 +87,7 @@ struct LhloLegalizeToAffine
     OwningRewritePatternList patterns;
     auto func = getFunction();
     populateLHLOToAffineConversionPattern(func.getContext(), &patterns);
-    applyPatternsGreedily(func, patterns);
+    applyPatternsAndFoldGreedily(func, patterns);
   }
 };
@@ -71,7 +71,7 @@ void LowerComplex::runOnFunction() {
   OwningRewritePatternList patterns;
   mlir::xla::PopulateComplexLoweringPatterns(&getContext(), &patterns);
-  applyPatternsGreedily(getFunction(), patterns);
+  applyPatternsAndFoldGreedily(getFunction(), patterns);
 }
 static PassRegistration<LowerComplex> pass(
@@ -178,7 +178,7 @@ struct LegalizeGeneralDot
     OwningRewritePatternList patterns;
     mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns,
                                                         &getContext());
-    applyPatternsGreedily(getFunction(), patterns);
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
   }
 };
@@ -32,7 +32,7 @@ struct TestUnfuseBatchNormPass
   void runOnOperation() override {
     OwningRewritePatternList patterns;
     PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
-    applyPatternsGreedily(getOperation(), patterns);
+    applyPatternsAndFoldGreedily(getOperation(), patterns);
   }
 };
@@ -79,11 +79,31 @@ class IrBuilderMixin {
     return mixin_builder()->CreateBr(std::forward<Args>(args)...);
   }
+  llvm::CallInst* Call(llvm::FunctionCallee func_callee,
+                       llvm::ArrayRef<llvm::Value*> args = llvm::None,
+                       const llvm::Twine& name = "",
+                       llvm::MDNode* fp_math_tag = nullptr) {
+    return mixin_builder()->CreateCall(func_callee, args, name, fp_math_tag);
+  }
+
+  llvm::CallInst* Call(llvm::FunctionType* func_type, llvm::Value* callee,
+                       llvm::ArrayRef<llvm::Value*> args = llvm::None,
+                       const llvm::Twine& name = "",
+                       llvm::MDNode* fp_math_tag = nullptr) {
+    return mixin_builder()->CreateCall(func_type, callee, args, name,
+                                       fp_math_tag);
+  }
+
+  // DEPRECATED. LLVM is removing getPointerElementType, so calls to this must
+  // be transitioned to one of the other overloads.
   llvm::CallInst* Call(llvm::Value* callee,
                        llvm::ArrayRef<llvm::Value*> args = llvm::None,
                        const llvm::Twine& name = "",
                        llvm::MDNode* fp_math_tag = nullptr) {
-    return mixin_builder()->CreateCall(callee, args, name, fp_math_tag);
+    return mixin_builder()->CreateCall(
+        llvm::cast<llvm::FunctionType>(
+            callee->getType()->getPointerElementType()),
+        callee, args, name, fp_math_tag);
   }
 
   template <class... Args>
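For reference, a hedged sketch of how a call site migrates to the preferred FunctionCallee overload above; the emitter function and the "tanhf" callee are hypothetical, not from this commit. llvm::Module::getOrInsertFunction returns an llvm::FunctionCallee that already carries the function type, so nothing needs getPointerElementType().

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"

// Hypothetical call site: the FunctionCallee overload threads the
// callee's type through explicitly, avoiding the pointee-type recovery
// that the deprecated Value* overload above must now perform itself.
llvm::CallInst* EmitTanhf(llvm::Module& module, llvm::IRBuilder<>& builder,
                          llvm::Value* input) {
  llvm::Type* f32 = builder.getFloatTy();
  llvm::FunctionCallee tanhf = module.getOrInsertFunction("tanhf", f32, f32);
  return builder.CreateCall(tanhf, {input}, "tanh_call");
}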
@@ -2666,6 +2666,20 @@ cc_binary(
     ],
 )
+cc_binary(
+    name = "mlir-linalg-ods-gen",
+    srcs = glob([
+        "tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp",
+    ]),
+    deps = [
+        ":IR",
+        ":Support",
+        "@llvm-project//llvm:config",
+        "@llvm-project//llvm:support",
+        "@llvm-project//llvm:tablegen",
+    ],
+)
 ## OpenMP dialect
 gentbl(
     name = "OpenMPOpsIncGen",