提交 cc773ec8 编写于 作者: N nihuini

adapt mlir api changes

上级 5d40a42e
......@@ -18,7 +18,6 @@
#include <set>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Parser.h>
#include <mlir/Pass/PassManager.h>
......
......@@ -14,8 +14,8 @@
#include "ncnn_dialect.h"
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/StandardTypes.h>
namespace mlir {
......
......@@ -16,7 +16,6 @@
#define NCNN_DIALECT_H
#include <mlir/IR/Dialect.h>
#include <mlir/IR/Function.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace mlir {
......
......@@ -19,9 +19,9 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
namespace mlir {
namespace TF {
......
......@@ -19,7 +19,6 @@
#include <mlir/IR/Builders.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/Matchers.h>
#include <mlir/IR/MLIRContext.h>
......@@ -28,7 +27,6 @@
#include <mlir/IR/Operation.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/IR/Types.h>
#include <mlir/IR/Value.h>
......
......@@ -16,9 +16,10 @@
#define TF_DIALECT_H
#include <mlir/Dialect/Traits.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/Interfaces/ControlFlowInterfaces.h>
#include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/LoopLikeInterface.h>
......
......@@ -89,6 +89,16 @@ struct DatasetIterator : ::mlir::SideEffects::Resource::Base<DatasetIterator>
}
};
// Special resource type to track TPU Embedding specific ops, which must execute
// but do not have side effects with one another or with resource variable ops.
struct TPUEmbedding : ::mlir::SideEffects::Resource::Base<TPUEmbedding>
{
StringRef getName() final
{
return "TPUEmbedding";
}
};
} // namespace ResourceEffects
} // namespace TF
} // namespace mlir
......
......@@ -18,8 +18,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
......@@ -71,6 +71,46 @@ public:
}
};
namespace detail {
inline LogicalResult verifySameOperandsAndResultElementTypeResolveRef(
Operation* op)
{
Type element_type;
if (op->getNumResults() > 0)
{
element_type = mlir::TF::GetElementTypeOrSelfResolveRef(op->getResult(0).getType());
}
else if (op->getNumOperands() > 0)
{
element_type = mlir::TF::GetElementTypeOrSelfResolveRef(op->getOperand(0).getType());
}
else
{
// Nothing to check.
return success();
}
// Verify that all result element types are compatible to `element_type`.
for (const auto& result_type : op->getResultTypes())
{
if (mlir::TF::GetElementTypeOrSelfResolveRef(result_type) != element_type)
{
return op->emitOpError(
"requires compatible element types for all operands and results");
}
}
// Verify that all operand element types are compatible to `element_type`.
for (const auto& operand_type : op->getOperandTypes())
{
if (mlir::TF::GetElementTypeOrSelfResolveRef(operand_type) != element_type)
{
return op->emitOpError(
"requires compatible element types for all operands and results");
}
}
return success();
}
} // namespace detail
// Verifies that op has the same operand and result element types (or type
// itself, if scalar) after resolving reference types (i.e., after converting
// reference types to their corresponding TensorFlow or standard types).
......@@ -82,39 +122,22 @@ class SameOperandsAndResultElementTypeResolveRef
public:
static LogicalResult verifyTrait(Operation* op)
{
Type element_type;
if (op->getNumResults() > 0)
{
element_type = mlir::TF::GetElementTypeOrSelfResolveRef(op->getResult(0).getType());
}
else if (op->getNumOperands() > 0)
{
element_type = mlir::TF::GetElementTypeOrSelfResolveRef(op->getOperand(0).getType());
}
else
{
// Nothing to check.
return success();
}
// Verify that all result element types are compatible to `element_type`.
for (const auto& result_type : op->getResultTypes())
{
if (mlir::TF::GetElementTypeOrSelfResolveRef(result_type) != element_type)
{
return op->emitOpError(
"requires compatible element types for all operands and results");
}
}
// Verify that all operand element types are compatible to `element_type`.
for (const auto& operand_type : op->getOperandTypes())
{
if (mlir::TF::GetElementTypeOrSelfResolveRef(operand_type) != element_type)
{
return op->emitOpError(
"requires compatible element types for all operands and results");
}
}
return success();
return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
}
};
// Verifies that op has the same operand and result types after resolving
// reference types (i.e., after converting reference types to their
// corresponding TensorFlow or standard types).
template<typename ConcreteType>
class SameOperandsAndResultTypeResolveRef
: public TraitBase<ConcreteType, SameOperandsAndResultTypeResolveRef>
{
public:
static LogicalResult verifyTrait(Operation* op)
{
if (failed(impl::verifySameOperandsAndResultShape(op))) return failure();
return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
}
};
......
......@@ -17,8 +17,8 @@ limitations under the License.
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
namespace {
......@@ -185,19 +185,19 @@ Type TensorFlowRefType::RemoveRef()
if (isa<FloatRefType>()) return mlir::FloatType::getF32(ctx);
if (isa<DoubleRefType>()) return mlir::FloatType::getF64(ctx);
if (isa<Bfloat16RefType>()) return mlir::FloatType::getBF16(ctx);
if (isa<BoolRefType>()) return mlir::IntegerType::get(1, ctx);
if (isa<Int8RefType>()) return mlir::IntegerType::get(8, ctx);
if (isa<Int16RefType>()) return mlir::IntegerType::get(16, ctx);
if (isa<Int32RefType>()) return mlir::IntegerType::get(32, ctx);
if (isa<Int64RefType>()) return mlir::IntegerType::get(64, ctx);
if (isa<BoolRefType>()) return mlir::IntegerType::get(ctx, 1);
if (isa<Int8RefType>()) return mlir::IntegerType::get(ctx, 8);
if (isa<Int16RefType>()) return mlir::IntegerType::get(ctx, 16);
if (isa<Int32RefType>()) return mlir::IntegerType::get(ctx, 32);
if (isa<Int64RefType>()) return mlir::IntegerType::get(ctx, 64);
if (isa<Uint8RefType>())
return mlir::IntegerType::get(8, IntegerType::Unsigned, ctx);
return mlir::IntegerType::get(ctx, 8, IntegerType::Unsigned);
if (isa<Uint16RefType>())
return mlir::IntegerType::get(16, IntegerType::Unsigned, ctx);
return mlir::IntegerType::get(ctx, 16, IntegerType::Unsigned);
if (isa<Uint32RefType>())
return mlir::IntegerType::get(32, IntegerType::Unsigned, ctx);
return mlir::IntegerType::get(ctx, 32, IntegerType::Unsigned);
if (isa<Uint64RefType>())
return mlir::IntegerType::get(64, IntegerType::Unsigned, ctx);
return mlir::IntegerType::get(ctx, 64, IntegerType::Unsigned);
if (isa<Complex64RefType>())
return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
if (isa<Complex128RefType>())
......
......@@ -18,10 +18,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册