提交 c827dfe8 编写于 作者: S Stephan Herhut 提交者: TensorFlower Gardener

Add a basic same_shape inference and transform to kernel generator.

This adds back functionality to the unranked case that was already present (albeit manual) in the ranked case.

PiperOrigin-RevId: 339889861
Change-Id: I6426ac5553c75bebd829135104d1b7802ff2a254
上级 556c1a4a
// RUN: kernel-gen-opt %s -allow-unregistered-dialect -propagate-tf-abi-knowledge-to-kernels | FileCheck %s
// RUN: kernel-gen-opt %s -allow-unregistered-dialect -propagate-tf-abi-knowledge-to-kernels | FileCheck %s --check-prefixes=CHECK,ABI
// RUN: kernel-gen-opt %s -allow-unregistered-dialect -propagate-shape-knowledge-to-kernels | FileCheck %s --check-prefixes=CHECK,SHAPE
// The input is taken from what is actually used in kernel generator lowering
// for unary operations. This could be minimized but then we would not be
......@@ -59,18 +60,21 @@ module attributes {gpu.container_module} {
// CHECK-LABEL: @__nv_fabsf
llvm.func @__nv_fabsf(!llvm.float) -> !llvm.float
// CHECK-LABEL: @abs_kernel
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<float>, %[[ARG1:.*]]: !llvm.ptr<float> {llvm.align = 16 : index},
// CHECK-SAME: %[[ARG2:.*]]: !llvm.i64, %[[ARG3:.*]]: !llvm.i64, %[[ARG4:.*]]: !llvm.i64, %[[ARG5:.*]]: !llvm.ptr<float>, %[[ARG6:.*]]: !llvm.ptr<float> {llvm.align = 16 : index, llvm.noalias = true},
// CHECK-SAME: %[[ARG7:.*]]: !llvm.i64, %[[ARG8:.*]]: !llvm.i64, %[[ARG9:.*]]: !llvm.i64
// ABI-SAME: %[[ARG0:.*]]: !llvm.ptr<float>, %[[ARG1:.*]]: !llvm.ptr<float> {llvm.align = 16 : index},
// ABI-SAME: %[[ARG2:.*]]: !llvm.i64, %[[ARG3:.*]]: !llvm.i64, %[[ARG4:.*]]: !llvm.i64, %[[ARG5:.*]]: !llvm.ptr<float>, %[[ARG6:.*]]: !llvm.ptr<float> {llvm.align = 16 : index, llvm.noalias = true},
// ABI-SAME: %[[ARG7:.*]]: !llvm.i64, %[[ARG8:.*]]: !llvm.i64, %[[ARG9:.*]]: !llvm.i64
// SHAPE-SAME: %[[ARG0:.*]]: !llvm.ptr<float>, %[[ARG1:.*]]: !llvm.ptr<float>, %[[ARG2:.*]]: !llvm.i64, %[[ARG3:.*]]: !llvm.i64, %[[ARG4:.*]]: !llvm.i64, %[[ARG5:.*]]: !llvm.ptr<float>, %[[ARG6:.*]]: !llvm.ptr<float>, %[[ARG7:.*]]: !llvm.i64, %[[ARG8:.*]]: !llvm.i64, %[[ARG9:.*]]: !llvm.i64
llvm.func @abs_kernel(%arg0: !llvm.ptr<float>, %arg1: !llvm.ptr<float>, %arg2: !llvm.i64, %arg3: !llvm.i64, %arg4: !llvm.i64, %arg5: !llvm.ptr<float>, %arg6: !llvm.ptr<float>, %arg7: !llvm.i64, %arg8: !llvm.i64, %arg9: !llvm.i64) attributes {gpu.kernel} {
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index)
// ABI: %[[ZERO:.*]] = llvm.mlir.constant(0 : index)
// CHECK: llvm.mlir.undef
%0 = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: llvm.insertvalue %[[ARG1]]
// ABI-NEXT: llvm.insertvalue %[[ARG1]]
// SHAPE-NEXT: llvm.insertvalue %[[ARG0]]
%1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: llvm.insertvalue %[[ARG1]]
%2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: llvm.insertvalue %[[ZERO]]
// ABI-NEXT: llvm.insertvalue %[[ZERO]]
// SHAPE-NEXT: llvm.insertvalue %[[ARG2]]
%3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: llvm.insertvalue %[[ARG3]]
%4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
......@@ -78,15 +82,19 @@ module attributes {gpu.container_module} {
%5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: llvm.mlir.undef
%6 = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: llvm.insertvalue %[[ARG6]]
// ABI-NEXT: llvm.insertvalue %[[ARG6]]
// SHAPE-NEXT: llvm.insertvalue %[[ARG5]]
%7 = llvm.insertvalue %arg5, %6[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: llvm.insertvalue %[[ARG6]]
%8 = llvm.insertvalue %arg6, %7[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: llvm.insertvalue %[[ZERO]]
// ABI-NEXT: llvm.insertvalue %[[ZERO]]
// SHAPE-NEXT: llvm.insertvalue %[[ARG7]]
%9 = llvm.insertvalue %arg7, %8[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: llvm.insertvalue %[[ARG8]]
// ABI-NEXT: llvm.insertvalue %[[ARG8]]
// SHAPE-NEXT: llvm.insertvalue %[[ARG3]]
%10 = llvm.insertvalue %arg8, %9[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: llvm.insertvalue %[[ARG9]]
// ABI-NEXT: llvm.insertvalue %[[ARG9]]
// SHAPE-NEXT: llvm.insertvalue %[[ARG4]]
%11 = llvm.insertvalue %arg9, %10[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
%12 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
%13 = llvm.sext %12 : !llvm.i32 to !llvm.i64
......
......@@ -79,6 +79,7 @@ cc_library(
"materialize_broadcasts_pass.cc",
"parallel_loops_to_sequential.cc",
"propagate_tf_abi_knowledge_pass.cc",
"same_shape_propagation.cc",
"shape_to_descriptors_pass.cc",
"tensorflow_abi_knowledge_propagation.cc",
"tf_kernel_to_llvm_pass.cc",
......
......@@ -75,6 +75,9 @@ std::unique_ptr<FunctionPass> CreateUnfuseBatchNormPass();
// Pass to propagate tensorflow runtime ABI knowledge across kernel boundaries.
std::unique_ptr<FunctionPass> CreatePropagateTfAbiKnowledgeToKernels();
// Creates the pass that propagates shape equalities across kernel boundaries.
// The pass is registered under the `propagate-shape-knowledge-to-kernels`
// flag.
std::unique_ptr<FunctionPass> CreatePropagateShapeKnowledgeToKernels();
} // namespace transforms
#define GEN_PASS_REGISTRATION
......
......@@ -86,4 +86,11 @@ def PropagateTfAbiKnowledgeToKernels
let summary = "Pass to propagate tensorflow ABI knowledge to kernels";
let constructor = "transforms::CreatePropagateTfAbiKnowledgeToKernels()";
}
// Function pass exposed on the command line as
// `-propagate-shape-knowledge-to-kernels`. Instances are created through
// transforms::CreatePropagateShapeKnowledgeToKernels().
def PropagateShapeKnowledgeToKernels
: FunctionPass<"propagate-shape-knowledge-to-kernels"> {
let summary = "Pass to propagate shape information into kernels";
let constructor = "transforms::CreatePropagateShapeKnowledgeToKernels()";
}
#endif // TF_KERNEL_GEN_PASSES
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file contains the analysis and transformation to rewrite kernel
// functions such that they use a single set of arguments for the strides and
// sizes of operands with equal shapes.
#include <memory>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
namespace {
using mlir::ArrayRef;
using mlir::SmallVector;
using mlir::Value;
/// Represents either an ssa value or an integer constant. Used to unify
/// operands for operations that take both ssa values and attributes.
///
/// The active alternative is tracked by `is_constant`; the accessors assert
/// that the requested alternative is the active one.
struct ValueOrConst {
  /// Wraps the ssa value `v`.
  explicit ValueOrConst(Value v) : value_or_constant(v), is_constant(false) {}
  /// Wraps the integer constant `c`.
  explicit ValueOrConst(int64_t c) : value_or_constant(c), is_constant(true) {}

  /// Returns the wrapped ssa value. Must only be called if !isConstant().
  Value value() const {
    assert(!is_constant);
    return value_or_constant.value;
  }

  /// Returns the wrapped constant. Must only be called if isConstant().
  int64_t constant() const {
    assert(is_constant);
    return value_or_constant.constant;
  }

  /// Returns true if this wraps a constant rather than an ssa value.
  bool isConstant() const { return is_constant; }

 private:
  union ValueOrConstStorage {
    explicit ValueOrConstStorage(Value v) : value(v) {}
    // Takes int64_t (not size_t) so that the constant is stored with the same
    // signed type it is constructed from and read back as.
    explicit ValueOrConstStorage(int64_t c) : constant(c) {}

    Value value;
    int64_t constant;
  } value_or_constant;

  // Discriminator for `value_or_constant`.
  bool is_constant;
};
/// Hashes a ValueOrConst: constants use their integer value as the hash code,
/// ssa values defer to mlir::hash_value.
llvm::hash_code hash_value(ValueOrConst value) {
  if (value.isConstant()) {
    return static_cast<llvm::hash_code>(value.constant());
  }
  return mlir::hash_value(value.value());
}
/// Two ValueOrConst are equal iff they hold the same alternative and the held
/// constants, respectively ssa values, compare equal.
bool operator==(ValueOrConst lhs, ValueOrConst rhs) {
  // A constant never equals an ssa value.
  if (lhs.isConstant() != rhs.isConstant()) return false;
  if (lhs.isConstant()) return lhs.constant() == rhs.constant();
  return lhs.value() == rhs.value();
}
/// Represents a shape, as either a single ssa value that represents the entire
/// shape vector or as a vector of ssa values representing scalars.
struct ShapeValue {
explicit ShapeValue(Value vector)
: shape({ValueOrConst{vector}}), is_vector(true) {}
explicit ShapeValue(ValueOrConst vector) : shape({vector}), is_vector(true) {
assert(!vector.isConstant());
}
template <typename T>
explicit ShapeValue(T values)
: shape(values.begin(), values.end()), is_vector(false) {}
ValueOrConst vector() const {
assert(is_vector);
return shape.front();
}
ArrayRef<ValueOrConst> scalars() const {
assert(!is_vector);
return llvm::makeArrayRef(shape);
}
bool isVector() const { return is_vector; }
private:
SmallVector<ValueOrConst, 4> shape;
bool is_vector;
};
/// Hashes a ShapeValue: the vector form hashes its single value, the scalar
/// form hashes the sequence of its extents.
///
/// Takes the shape by const reference so that hashing does not copy the
/// underlying SmallVector of extents.
llvm::hash_code hash_value(const ShapeValue &shape) {
  return shape.isVector() ? hash_value(shape.vector())
                          : hash_value(shape.scalars());
}
bool operator==(ShapeValue lhs, ShapeValue rhs) {
if (lhs.isVector()) {
return rhs.isVector() && lhs.vector() == rhs.vector();
} else {
return !rhs.isVector() && lhs.scalars() == rhs.scalars();
}
}
} // namespace
namespace llvm {
template <>
struct DenseMapInfo<ShapeValue> {
static ShapeValue getEmptyKey() {
return ShapeValue(DenseMapInfo<mlir::Value>::getEmptyKey());
}
static ShapeValue getTombstoneKey() {
return ShapeValue(DenseMapInfo<mlir::Value>::getTombstoneKey());
}
static unsigned getHashValue(ShapeValue shape) { return hash_value(shape); }
static bool isEqual(ShapeValue LHS, ShapeValue RHS) { return LHS == RHS; }
};
} // namespace llvm
namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
// A basic shape equality inference. This should be superceeded by a proper
// inference once available. Until then, we just build this out to the needs of
// the kernel generator project.
class ShapeEqualityKnowledge {
public:
/// Checks all operations for potential shape equality of their respective
/// results.
void build(FuncOp function) {
function.walk([&](Operation *op) {
if (auto reshape = dyn_cast<lmhlo::ReshapeMemRefCastOp>(op)) {
registerAssociation(ShapeValue{reshape.operand()}, reshape.result());
return;
}
if (auto alloc = dyn_cast<AllocOp>(op)) {
SmallVector<ValueOrConst, 4> shape;
ShapedType type = alloc.getResult().getType().cast<ShapedType>();
fillShapeFromAllocLike(alloc.getDynamicSizes(), type, shape);
registerAssociation(ShapeValue{shape}, alloc.getResult());
return;
}
if (auto alloc = dyn_cast<tf_framework::TFAllocOp>(op)) {
// Construct a symbol representing the allocated shape.
SmallVector<ValueOrConst, 4> shape;
ShapedType type = alloc.getResult().getType().cast<ShapedType>();
fillShapeFromAllocLike(alloc.dyn_sizes(), type, shape);
registerAssociation(ShapeValue{shape}, alloc.getResult());
return;
}
});
}
/// Checks whether `one` and `other` are known to have the same shape and
/// strides.
bool haveSameShape(Value one, Value other) {
return equal_shapes_.isEquivalent(one.getAsOpaquePointer(),
other.getAsOpaquePointer());
}
private:
static void fillShapeFromAllocLike(mlir::OperandRange operands,
ShapedType type,
SmallVectorImpl<ValueOrConst> &shape) {
assert(type.hasRank());
auto dynamic_sizes = operands.begin();
for (auto extent : type.getShape()) {
shape.push_back(ShapedType::isDynamic(extent)
? ValueOrConst{*(dynamic_sizes++)}
: ValueOrConst{extent});
}
}
/// Registers the value `value` to have the shape represented by `shape`. If
/// `shape` has been registered before, place `value` into the same
/// equivalence class. Otherwise register `value` as an equivalence class of
/// its own.
void registerAssociation(ShapeValue shape, Value value) {
auto insert_symbolic = symbolic_shapes_.insert({shape, value});
if (insert_symbolic.second) {
equal_shapes_.insert(value.getAsOpaquePointer());
// We have seen this symbolic shape for the first time. Try to match it
// with a vector or shape we already know and alias classes if possible.
// This could be based on shape dialect if we weren't late in the
// lowering.
tryEvaluateShapeToRoot(shape, value);
} else {
equal_shapes_.unionSets(
insert_symbolic.first->second.getAsOpaquePointer(),
value.getAsOpaquePointer());
}
}
/// Follows the definition chains of the ShapeValue `shape` to identify cases
/// where `shape` is derived from some other value's shape. In such case, the
/// equivalence classes of that other value and `value` are unioned.
void tryEvaluateShapeToRoot(ShapeValue shape, Value value) {
// Just some pattern matching for common cases here.
if (!shape.isVector()) {
// Patterns that revolve around scalars.
// Check whether the scalars are all dim operations for some other memref.
// TODO(herhut): Use pattern match infra here.
Value candidate;
for (auto extent : llvm::enumerate(shape.scalars())) {
if (extent.value().isConstant()) {
candidate = {};
break;
}
if (auto dimOp = extent.value().value().getDefiningOp<mlir::DimOp>()) {
auto dimIndex = dimOp.getConstantIndex();
if (!dimIndex.hasValue() || (dimIndex.getValue() != extent.index())) {
candidate = {};
break;
}
if (candidate && candidate != dimOp.memrefOrTensor()) {
candidate = {};
break;
}
candidate = dimOp.memrefOrTensor();
}
}
if (candidate) {
equal_shapes_.unionSets(candidate.getAsOpaquePointer(),
value.getAsOpaquePointer());
}
} else {
// Patterns that revovlve around vector representation.
}
}
// These are values with identical shapes (or rather their opaque pointers).
llvm::EquivalenceClasses<void *> equal_shapes_;
// A map from a value that encodes a shape to a value that has this shape.
llvm::DenseMap<ShapeValue, Value> symbolic_shapes_;
};
/// For arguments to kernels that have the same shape, use the stride and
/// shape information of the left-most argument inside of the kernel function.
/// That way, llvm can CSE index computations on same-shaped inputs.
///
/// Every memref operand of a launch is assumed to map to `3 + 2 * rank`
/// consecutive kernel arguments (base, aligned, offset, then sizes and
/// strides); every other operand maps to a single kernel argument.
struct PropagateShapeKnowledgeToKernels
    : public PropagateShapeKnowledgeToKernelsBase<
          PropagateShapeKnowledgeToKernels> {
  void runOnFunction() override {
    ShapeEqualityKnowledge knowledge;
    knowledge.build(getFunction());

    getFunction().walk([&](gpu::LaunchFuncOp launch) {
      auto module = launch.getParentOfType<ModuleOp>();
      auto kernel = module.lookupSymbol<LLVM::LLVMFuncOp>(launch.kernel());
      // Nothing to rewrite if the kernel is unknown or has no body.
      if (!kernel || kernel.isExternal()) return;

      // Memref operands seen so far, paired with the position of their first
      // argument in the kernel signature.
      llvm::SmallVector<std::pair<Value, int>, 4> seen_memrefs;
      // Position in the kernel signature of the current operand's first
      // argument.
      int kernel_p = 0;
      for (auto operand : launch.operands()) {
        auto memref = operand.getType().dyn_cast<MemRefType>();
        if (!memref) {
          // Scalar argument, advance kernel position by one.
          kernel_p++;
          continue;
        }
        for (auto previous : seen_memrefs) {
          if (!knowledge.haveSameShape(operand, previous.first)) {
            continue;
          }
          // Defensive: equal shapes imply equal rank; never pair up argument
          // lists of different lengths.
          if (previous.first.getType().cast<MemRefType>().getRank() !=
              memref.getRank()) {
            continue;
          }
          // We use the first equality found and replace uses of corresponding
          // size and stride information here.
          // TODO(herhut): This is not safe if we had a cast operation
          //               in between that changes stride information. The
          //               current analysis above would not consider this
          //               equal.
          // We need to replace sizes and strides.
          auto args_to_replace = memref.getRank() * 2;
          int previous_args_pos = previous.second;
          // Skip base, aligned and offset (3 arguments) of each memref and
          // take the sizes and strides that directly follow. Both ranges must
          // be taken from the front: the current memref is not necessarily
          // the last argument of the kernel.
          auto previous_args = kernel.getArguments()
                                   .drop_front(previous_args_pos + 3)
                                   .take_front(args_to_replace);
          auto current_args = kernel.getArguments()
                                  .drop_front(kernel_p + 3)
                                  .take_front(args_to_replace);
          for (auto pair : llvm::zip(previous_args, current_args)) {
            std::get<1>(pair).replaceAllUsesWith(std::get<0>(pair));
          }
          break;
        }
        seen_memrefs.push_back({operand, kernel_p});
        // Advance base, aligned, offset, strides and sizes many arguments.
        kernel_p += memref.getRank() * 2 + 3;
      }
    });
  }
};
} // namespace
/// Factory for the pass that propagates shape equalities into kernels.
std::unique_ptr<FunctionPass> CreatePropagateShapeKnowledgeToKernels() {
  std::unique_ptr<FunctionPass> pass =
      std::make_unique<PropagateShapeKnowledgeToKernels>();
  return pass;
}
} // namespace transforms
} // namespace kernel_gen
} // namespace mlir
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册