提交 098eebff 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Integrate LLVM at llvm/llvm-project@e75bc5c791e0

Updates LLVM usage to match
[e75bc5c791e0](https://github.com/llvm/llvm-project/commit/e75bc5c791e0)

PiperOrigin-RevId: 327449456
Change-Id: I847530d2d7325bd7bfeef29a3899d8cbcb3257e0
上级 4e0438a9
......@@ -56,19 +56,9 @@ class MhloDialect : public Dialect {
void printType(Type type, DialectAsmPrinter &os) const override;
};
namespace HLOTypes {
enum Kind {
Token = Type::FIRST_XLA_HLO_TYPE,
};
} // namespace HLOTypes
class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
public:
using Base::Base;
static TokenType get(MLIRContext *context) {
return Base::get(context, HLOTypes::Token);
}
};
// Shape derivation function that computes the shape of the result based on
......
......@@ -706,10 +706,8 @@ struct ConvertTFBroadcastTo : public RewritePattern {
shape_type.getDimSize(0) <= 5)))
return failure();
if (!((element_type.getKind() == mlir::StandardTypes::F32) ||
(element_type.getKind() == mlir::StandardTypes::BF16) ||
(element_type.getKind() == mlir::StandardTypes::Integer &&
element_type.cast<mlir::IntegerType>().getWidth() == 32)))
if (!(element_type.isa<BFloat16Type, Float32Type>() ||
element_type.isInteger(32)))
return failure();
auto status_or_const_op =
......
......@@ -20,11 +20,6 @@ limitations under the License.
void init_types(py::module& m) {
// Type
py::class_<mlir::Type> Type(m, "Type");
Type.def("getKind", &mlir::Type::getKind);
// Type Enums
py::enum_<mlir::StandardTypes::Kind>(Type, "StandardTypes_Kind")
.value("BF16", mlir::StandardTypes::BF16);
// Type Sub-classes
py::class_<mlir::FunctionType, mlir::Type>(m, "FunctionType")
......
......@@ -74,12 +74,9 @@ struct FuncAttrStorage : public AttributeStorage {
// Get or create a shape attribute.
ShapeAttr ShapeAttr::get(mlir::MLIRContext* context,
llvm::Optional<ArrayRef<int64_t>> shape) {
if (shape)
return Base::get(context, AttrKind::SHAPE, *shape,
/*unranked=*/false);
if (shape) return Base::get(context, *shape, /*unranked=*/false);
return Base::get(context, AttrKind::SHAPE, ArrayRef<int64_t>(),
/*unranked=*/true);
return Base::get(context, ArrayRef<int64_t>(), /*unranked=*/true);
}
llvm::Optional<ArrayRef<int64_t>> ShapeAttr::getValue() const {
......@@ -112,12 +109,12 @@ bool ShapeAttr::hasStaticShape() const {
FuncAttr FuncAttr::get(mlir::MLIRContext* context, llvm::StringRef name,
DictionaryAttr attr) {
auto symbol = SymbolRefAttr::get(name, context);
return Base::get(context, AttrKind::FUNC, symbol, attr);
return Base::get(context, symbol, attr);
}
FuncAttr FuncAttr::get(mlir::MLIRContext* context, SymbolRefAttr symbol,
DictionaryAttr attr) {
return Base::get(context, AttrKind::FUNC, symbol, attr);
return Base::get(context, symbol, attr);
}
SymbolRefAttr FuncAttr::GetName() const {
......
......@@ -24,19 +24,6 @@ limitations under the License.
namespace mlir {
namespace TF {
namespace AttrKind {
// List of supported custom TensorFlow Attribute kinds, necessary for
// isa/dyn_cast.
enum Kind {
FIRST_USED_TENSORFLOW_ATTR = Attribute::FIRST_TENSORFLOW_ATTR,
SHAPE = FIRST_USED_TENSORFLOW_ATTR,
FUNC,
LAST_USED_TENSORFLOW_ATTR,
};
} // namespace AttrKind
namespace detail {
struct ShapeAttrStorage;
......
......@@ -45,31 +45,16 @@ class TensorFlowExecutorDialect : public Dialect {
void printType(Type type, DialectAsmPrinter &os) const override;
};
namespace TFTypes {
enum Kind {
Control = Type::FIRST_TENSORFLOW_EXECUTOR_TYPE,
Token,
};
} // namespace TFTypes
// The Control type is a token-like value that models control dependencies from
// TensorFlow graphs.
class ControlType : public Type::TypeBase<ControlType, Type, TypeStorage> {
public:
using Base::Base;
static ControlType get(MLIRContext *context) {
return Base::get(context, TFTypes::Control);
}
};
class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
public:
using Base::Base;
static TokenType get(MLIRContext *context) {
return Base::get(context, TFTypes::Token);
}
};
// Declares the operations for this dialect using the generated header.
......
......@@ -358,16 +358,12 @@ Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser,
void TensorFlowDialect::printAttribute(Attribute attr,
DialectAsmPrinter &os) const {
switch (attr.getKind()) {
case AttrKind::SHAPE:
PrintShapeAttr(attr.cast<ShapeAttr>(), os);
break;
case AttrKind::FUNC:
PrintFuncAttr(attr.cast<FuncAttr>(), os);
break;
default:
llvm_unreachable("unexpected tensorflow attribute kind");
}
if (auto shape_attr = attr.dyn_cast<ShapeAttr>())
PrintShapeAttr(shape_attr, os);
else if (auto func_attr = attr.dyn_cast<FuncAttr>())
PrintFuncAttr(func_attr, os);
else
llvm_unreachable("unexpected tensorflow attribute type");
}
// Parses a type registered to this dialect.
......@@ -376,32 +372,18 @@ Type TensorFlowDialect::parseType(DialectAsmParser &parser) const {
if (parser.parseKeyword(&data)) return Type();
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
auto typeKind = llvm::StringSwitch<unsigned>(data)
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
.Case(name, TensorFlowTypes::enumerant)
if (data == name) return tftype##Type::get(getContext());
// Custom TensorFlow types are handled separately at the end as they do partial
// match.
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
.StartsWith("resource", TensorFlowTypes::RESOURCE)
.StartsWith("variant", TensorFlowTypes::VARIANT)
.Default(0);
switch (typeKind) {
default:
return (emitError(loc, "unknown TensorFlow type: " + data), nullptr);
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
case TensorFlowTypes::enumerant: \
return tftype##Type::get(getContext());
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
case TensorFlowTypes::RESOURCE:
return ParseResourceType(parser, loc);
case TensorFlowTypes::VARIANT:
return ParseVariantType(parser, loc);
}
if (data.startswith("resource")) return ParseResourceType(parser, loc);
if (data.startswith("variant")) return ParseVariantType(parser, loc);
return (emitError(loc, "unknown TensorFlow type: " + data), nullptr);
}
// Prints a type registered to this dialect.
......
......@@ -67,16 +67,6 @@ using ResultShapeRange = iterator_range<ResultShapeIterator>;
// TensorFlow types
//===----------------------------------------------------------------------===//
namespace TensorFlowTypes {
// List of supported TensorFlowType kinds, necessary for isa/dyn_cast.
enum Kind {
FIRST_USED_TENSORFLOW_TYPE = Type::FIRST_TENSORFLOW_TYPE,
#define HANDLE_TF_TYPE(tftype, enumerant, name) enumerant,
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
LAST_USED_TENSORFLOW_TYPE,
};
} // namespace TensorFlowTypes
// The base class in the TensorFlow type hierarchy.
class TensorFlowType : public Type {
public:
......@@ -102,10 +92,7 @@ static inline bool IsValidTFTensorType(Type type) {
namespace detail {
// Common implementation of TensorFlow types. The template argument indicates
// the concrete derived class per CRTP. Concrete classes must implement the
// following:
// - `static unsigned getTypeKind()` that returns the (fixed) kind of the
// type.
// the concrete derived class per CRTP.
template <typename Derived>
class TensorFlowTypeImpl
: public Type::TypeBase<Derived, TensorFlowType, TypeStorage> {
......@@ -113,11 +100,6 @@ class TensorFlowTypeImpl
using Base = typename Type::TypeBase<Derived, TensorFlowType, TypeStorage>;
using TFBase = TensorFlowTypeImpl<Derived>;
using Base::Base;
// Get the unique'ed type in the given context.
static Derived get(MLIRContext* context) {
return Base::get(context, Derived::getTypeKind());
}
};
} // namespace detail
......@@ -173,7 +155,6 @@ static inline Type GetElementTypeOrSelfResolveRef(Type type) {
class tftype##Type : public detail::TensorFlowTypeImpl<tftype##Type> { \
public: \
using TFBase::TFBase; \
static unsigned getTypeKind() { return TensorFlowTypes::enumerant; } \
};
// Custom TensorFlow types are defined separately.
......@@ -211,8 +192,6 @@ class TypeWithSubtypeStorage : public TypeStorage {
// opaque and their interpretation depends on the actual underlying type.
// The template argument indicates the concrete derived class per CRTP. Concrete
// classes must implement the following:
// - `static unsigned getTypeKind()` that returns the (fixed) kind of the
// type.
// - `static std::string getTypeName()` that returns the name of the type for
// verification logging.
template <typename Derived>
......@@ -224,12 +203,12 @@ class TypeWithSubtypeImpl
using Base::Base;
static Derived get(ArrayRef<TensorType> subtypes, MLIRContext* context) {
return Base::get(context, Derived::getTypeKind(), subtypes);
return Base::get(context, subtypes);
}
static Derived getChecked(ArrayRef<TensorType> subtypes, MLIRContext* context,
Location loc) {
return Base::getChecked(loc, Derived::getTypeKind(), subtypes);
return Base::getChecked(loc, subtypes);
}
static Derived get(MLIRContext* context) { return get({}, context); }
......@@ -279,7 +258,6 @@ static inline Type GetDefaultTypeOf(TensorFlowTypeWithSubtype type) {
class ResourceType : public detail::TypeWithSubtypeImpl<ResourceType> {
public:
using TFBase::TFBase;
static unsigned getTypeKind() { return TensorFlowTypes::RESOURCE; }
static std::string getTypeName() { return "ResourceType"; }
};
......@@ -291,7 +269,6 @@ class ResourceType : public detail::TypeWithSubtypeImpl<ResourceType> {
class VariantType : public detail::TypeWithSubtypeImpl<VariantType> {
public:
using TFBase::TFBase;
static unsigned getTypeKind() { return TensorFlowTypes::VARIANT; }
static std::string getTypeName() { return "VariantType"; }
};
......
......@@ -22,6 +22,7 @@ limitations under the License.
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
......@@ -368,65 +369,36 @@ Status ConvertAttributes(
name = mangling_util::DemangleAttributeName(name);
}
AttrValue value;
switch (attr.getKind()) {
case mlir::StandardAttributes::SymbolRef: {
TF_RETURN_IF_ERROR(
ConvertAttribute(attr.cast<mlir::FlatSymbolRefAttr>(), &value));
func_call_attrs[string(name)] = value;
continue;
}
case mlir::StandardAttributes::Integer:
if (auto boolAttr = attr.dyn_cast<mlir::BoolAttr>()) {
TF_RETURN_IF_ERROR(ConvertAttribute(boolAttr, &value));
} else {
TF_RETURN_IF_ERROR(
ConvertAttribute(attr.cast<mlir::IntegerAttr>(), &value));
}
break;
case mlir::StandardAttributes::Float:
TF_RETURN_IF_ERROR(
ConvertAttribute(attr.cast<mlir::FloatAttr>(), &value));
break;
case mlir::StandardAttributes::String:
TF_RETURN_IF_ERROR(
ConvertAttribute(attr.cast<mlir::StringAttr>(), &value));
break;
case mlir::StandardAttributes::Array:
TF_RETURN_IF_ERROR(
ConvertAttribute(attr.cast<mlir::ArrayAttr>(), &value));
break;
case mlir::StandardAttributes::DenseIntOrFPElements:
case mlir::StandardAttributes::DenseStringElements:
case mlir::StandardAttributes::OpaqueElements:
TF_RETURN_IF_ERROR(
ConvertAttribute(attr.cast<mlir::ElementsAttr>(), &value));
break;
case mlir::StandardAttributes::Type:
TF_RETURN_IF_ERROR(
ConvertAttribute(attr.cast<mlir::TypeAttr>(), &value));
break;
case mlir::StandardAttributes::Unit:
TF_RETURN_IF_ERROR(
ConvertAttribute(attr.cast<mlir::UnitAttr>(), &value));
break;
case static_cast<unsigned>(mlir::TF::AttrKind::SHAPE):
TF_RETURN_IF_ERROR(
ConvertAttribute(attr.cast<mlir::TF::ShapeAttr>(), &value));
break;
case static_cast<unsigned>(mlir::TF::AttrKind::FUNC): {
TF_RETURN_IF_ERROR(
ConvertAttribute(attr.cast<mlir::TF::FuncAttr>(), &value));
func_call_attrs[string(name)] = value;
continue;
}
// AffineMap kind is not implemented.
case mlir::StandardAttributes::AffineMap:
return errors::Unimplemented("AffineMap attribute (needed for '",
name_strref, "') unimplemented");
default:
return errors::Unimplemented("Unhandled attribute kind for attribute '",
name_strref, '\'');
if (auto symbol_ref = attr.dyn_cast<mlir::SymbolRefAttr>()) {
TF_RETURN_IF_ERROR(
ConvertAttribute(symbol_ref.cast<mlir::FlatSymbolRefAttr>(), &value));
func_call_attrs[string(name)] = value;
continue;
}
if (auto func_attr = attr.dyn_cast<mlir::TF::FuncAttr>()) {
TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, &value));
func_call_attrs[string(name)] = value;
continue;
}
if (attr.isa<mlir::AffineMapAttr>()) {
// AffineMapAttr is not implemented.
return errors::Unimplemented("AffineMap attribute (needed for '",
name_strref, "') unimplemented");
}
TF_RETURN_IF_ERROR(
llvm::TypeSwitch<mlir::Attribute, Status>(attr)
.Case<mlir::BoolAttr, mlir::IntegerAttr, mlir::FloatAttr,
mlir::StringAttr, mlir::ArrayAttr, mlir::ElementsAttr,
mlir::TypeAttr, mlir::UnitAttr, mlir::TF::ShapeAttr>(
[&](auto derived_attr) {
return ConvertAttribute(derived_attr, &value);
})
.Default([&](mlir::Attribute) {
return errors::Unimplemented(
"Unhandled attribute kind for attribute '", name_strref,
'\'');
}));
// According to the NodeDef proto definition, an attribute name from the
// input TensorFlow GraphDef shouldn't contain '.'. If it does appear in
// the attribute from MLIR, it is treated as an attribute from function
......
......@@ -30,22 +30,12 @@ namespace mlir {
namespace kernel_gen {
namespace tf_framework {
namespace TFFrameworkTypes {
enum Kind {
OpKernelContextType = Type::FIRST_TF_FRAMEWORK_TYPE,
};
} // namespace TFFrameworkTypes
/// OpKernelContextType corresponds to C++ class OpKernelContext defined in
/// tensorflow/core/framework/op_kernel.h
class OpKernelContextType
: public Type::TypeBase<OpKernelContextType, Type, TypeStorage> {
public:
using Base::Base;
static OpKernelContextType get(MLIRContext *context) {
return Base::get(context, TFFrameworkTypes::Kind::OpKernelContextType);
}
};
#define GET_OP_CLASSES
......
......@@ -699,8 +699,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
)
# Check out LLVM and MLIR from llvm-project.
LLVM_COMMIT = "bf36e902953a4bf8ac0aae5a498445951fbc3882"
LLVM_SHA256 = "ae3f8eeb10b0b3f01196339b4a6083385b625f2feb422d965037375a9659afc9"
LLVM_COMMIT = "e75bc5c791e0e8dbe79f7453e55af9e8d03c9cc0"
LLVM_SHA256 = "9c22f59d50853329cd0105ecb95256ad345313372ddda593030cd81b7c72e657"
LLVM_URLS = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
"https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
......
......@@ -24,14 +24,6 @@ exports_files([
"run_lit.sh",
])
cc_library(
name = "DialectSymbolRegistry",
# strip_include_prefix does not apply to textual_hdrs.
hdrs = ["include/mlir/IR/DialectSymbolRegistry.def"],
strip_include_prefix = "include/mlir/IR",
textual_hdrs = ["include/mlir/IR/DialectSymbolRegistry.def"],
)
[
gentbl(
name = name + "IncGen",
......@@ -75,7 +67,6 @@ cc_library(
includes = ["include"],
deps = [
":CallOpInterfacesIncGen",
":DialectSymbolRegistry",
":InferTypeOpInterfaceIncGen",
":OpAsmInterfaceIncGen",
":RegionKindInterfaceIncGen",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册