Commit c8cee324 authored by Johannes Reifferscheid, committed by TensorFlower Gardener

Triton emitter: remove some unnecessary helpers/templates.

PiperOrigin-RevId: 563735974
Parent d8359707
@@ -487,6 +487,7 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_query",
"//xla/mlir_hlo",
"//xla/mlir_hlo:map_mhlo_to_scalar_op",
"//xla/service:dump",
"//xla/service/gpu/llvm_gpu_backend",
......
@@ -64,6 +64,7 @@ limitations under the License.
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/OwningOpRef.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/ValueRange.h" // from @llvm-project
@@ -86,6 +87,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/literal.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h"
#include "xla/primitive_util.h"
#include "xla/service/dump.h"
@@ -214,39 +216,13 @@ Value Cast(ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) {
<< llvm_ir::DumpToString(dst_element_ty);
}
Type ElementType(Value v) {
Type src_ty = v.getType();
if (auto src_shaped_ty = src_ty.dyn_cast<mlir::ShapedType>()) {
return src_shaped_ty.getElementType();
}
return src_ty;
}
// Get the value of the scalar constant's literal in a C++ type.
template <typename T>
T ScalarConstantValue(const HloInstruction& instr) {
T ScalarConstantValue(const HloInstruction& instr, PrimitiveType dst_type) {
CHECK(hlo_query::IsScalarConstant(&instr));
PrimitiveType dst_type;
if constexpr (std::is_integral_v<T>) {
if constexpr (std::numeric_limits<T>::is_signed) {
dst_type = S64;
} else {
dst_type = U64;
}
} else {
dst_type = F64;
}
StatusOr<Literal> converted = instr.literal().Convert(dst_type);
TF_CHECK_OK(converted.status());
if constexpr (std::is_integral_v<T>) {
if constexpr (std::numeric_limits<T>::is_signed) {
return converted.value().GetFirstElement<int64_t>();
} else {
return converted.value().GetFirstElement<uint64_t>();
}
} else {
return converted.value().GetFirstElement<double>();
}
return converted.value().GetFirstElement<T>();
}
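
The simplified helper converts the scalar literal to the explicitly requested primitive type and then reads the first element as T. A minimal usage sketch, assuming a hypothetical HloInstruction `c` that holds a scalar constant (mirroring the EmitConstant call sites further down):

  // Hypothetical call sites; S64, U64 and F64 are XLA PrimitiveType values
  // chosen to match the requested C++ result type.
  int64_t as_signed = ScalarConstantValue<int64_t>(c, S64);
  uint64_t as_unsigned = ScalarConstantValue<uint64_t>(c, U64);
  double as_double = ScalarConstantValue<double>(c, F64);
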
// Create a scalar constant.
@@ -279,7 +255,7 @@ ma::ConstantOp CreateConst(ImplicitLocOpBuilder& b, Type type, T value,
}
Value Subtract(ImplicitLocOpBuilder& b, ValueRange values) {
if (ElementType(values[0]).isa<mlir::IntegerType>()) {
if (mlir::getElementTypeOrSelf(values[0]).isa<mlir::IntegerType>()) {
return b.create<ma::SubIOp>(values[0], values[1]);
} else {
return b.create<ma::SubFOp>(values[0], values[1]);
@@ -287,34 +263,28 @@ Value Subtract(ImplicitLocOpBuilder& b, ValueRange values) {
}
Value Compare(ImplicitLocOpBuilder& b, ValueRange values,
ComparisonDirection direction) {
if (ElementType(values[0]).isa<mlir::IntegerType>()) {
mlir::mhlo::ComparisonDirection direction) {
if (mlir::getElementTypeOrSelf(values[0]).isa<mlir::IntegerType>()) {
return b.create<ma::CmpIOp>(
mlir::mhlo::impl::getCmpPredicate<ma::CmpIPredicate>(
mlir::mhlo::symbolizeComparisonDirection(
ComparisonDirectionToString(direction))
.value(),
/*isSigned=*/true)
mlir::mhlo::impl::getCmpPredicate<ma::CmpIPredicate>(direction,
/*isSigned=*/true)
.value(),
values[0], values[1]);
}
return b.create<ma::CmpFOp>(
mlir::mhlo::impl::getCmpPredicate<ma::CmpFPredicate>(
mlir::mhlo::symbolizeComparisonDirection(
ComparisonDirectionToString(direction))
.value(),
/*isSigned=*/true)
mlir::mhlo::impl::getCmpPredicate<ma::CmpFPredicate>(direction,
/*isSigned=*/true)
.value(),
values[0], values[1]);
}
Value Maximum(ImplicitLocOpBuilder& b, ValueRange values) {
auto cmp = Compare(b, values, ComparisonDirection::kGt);
auto cmp = Compare(b, values, mlir::mhlo::ComparisonDirection::GT);
return b.create<ma::SelectOp>(cmp, values[0], values[1]);
}
Value Minimum(ImplicitLocOpBuilder& b, ValueRange values) {
auto cmp = Compare(b, values, ComparisonDirection::kLt);
auto cmp = Compare(b, values, mlir::mhlo::ComparisonDirection::LT);
return b.create<ma::SelectOp>(cmp, values[0], values[1]);
}
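
Compare expects the mhlo enum directly, so a caller that starts from an HLO instruction translates the XLA comparison direction first. A minimal sketch, assuming hypothetical values `lhs` and `rhs` and an HloInstruction `hlo` with opcode kCompare (the same round trip EmitElementwise performs below):

  // Round-trip the XLA enum through its string name into
  // mlir::mhlo::ComparisonDirection, then emit the comparison.
  auto direction = mlir::mhlo::symbolizeComparisonDirection(
                       ComparisonDirectionToString(hlo.comparison_direction()))
                       .value();
  Value cmp = Compare(b, {lhs, rhs}, direction);
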
@@ -345,9 +315,7 @@ using TensorValue = mlir::TypedValue<mlir::RankedTensorType>;
Value Broadcast(ImplicitLocOpBuilder& b, TensorValue value,
ArrayRef<int64_t> shape) {
auto type =
mlir::RankedTensorType::get(shape, value.getType().getElementType());
return b.create<mt::BroadcastOp>(type, value);
return b.create<mt::BroadcastOp>(value.getType().clone(shape), value);
}
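
Broadcast builds its result type with ShapedType::clone, which keeps the operand's element type and only swaps in the broadcasted shape. A small equivalence sketch, assuming a hypothetical TensorValue `src` and target shape {8, 16}; for tensors without an encoding attribute the two types are identical:

  // e.g. src of type tensor<16xf32> yields tensor<8x16xf32> either way.
  llvm::SmallVector<int64_t, 2> new_shape = {8, 16};
  mlir::Type via_clone = src.getType().clone(new_shape);
  mlir::Type via_get = mlir::RankedTensorType::get(
      new_shape, src.getType().getElementType());
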
Value Range(ImplicitLocOpBuilder& b, int32_t limit) {
@@ -361,7 +329,8 @@ Value AddPtr(ImplicitLocOpBuilder& b, Value ptr, Value offset) {
Value EmitElementwise(ImplicitLocOpBuilder& b, absl::string_view libdevice_path,
const HloInstruction& hlo, ValueRange inputs) {
if (ElementType(inputs[0]).isF32() || ElementType(inputs[0]).isF64()) {
if (mlir::getElementTypeOrSelf(inputs[0]).isF32() ||
mlir::getElementTypeOrSelf(inputs[0]).isF64()) {
auto dev_fn_id = GetTargetDeviceFunctionID(hlo.opcode());
if (dev_fn_id.ok()) {
return b.create<mt::PureExternElementwiseOp>(
@@ -371,7 +340,8 @@ Value EmitElementwise(ImplicitLocOpBuilder& b, absl::string_view libdevice_path,
llvm::Triple("nvptx64-unknown-unknown")));
}
}
const bool is_integer = ElementType(inputs[0]).isa<mlir::IntegerType>();
const bool is_integer =
mlir::getElementTypeOrSelf(inputs[0]).isa<mlir::IntegerType>();
switch (hlo.opcode()) {
case HloOpcode::kCopy:
@@ -418,11 +388,15 @@ Value EmitElementwise(ImplicitLocOpBuilder& b, absl::string_view libdevice_path,
}
return b.create<ma::DivFOp>(inputs[0], inputs[1]);
case HloOpcode::kCompare:
return Compare(b, inputs, hlo.comparison_direction());
return Compare(
b, inputs,
mlir::mhlo::symbolizeComparisonDirection(
ComparisonDirectionToString(hlo.comparison_direction()))
.value());
case HloOpcode::kSelect:
return b.create<ma::SelectOp>(
Compare(b, {inputs[0], ZerosLike(b, inputs[0])},
ComparisonDirection::kNe),
mlir::mhlo::ComparisonDirection::NE),
inputs[1], inputs[2]);
default:
LOG(FATAL) << "Unsupported operation " << hlo.ToString();
@@ -450,12 +424,12 @@ Value EmitConstant(ImplicitLocOpBuilder& b, const HloInstruction& constant) {
Type ty = TritonType(b, constant.shape().element_type());
if (constant.shape().IsInteger()) {
if (constant.shape().element_type() == U64) {
return CreateConst(b, ty, ScalarConstantValue<uint64_t>(constant));
return CreateConst(b, ty, ScalarConstantValue<uint64_t>(constant, U64));
} else {
return CreateConst(b, ty, ScalarConstantValue<int64_t>(constant));
return CreateConst(b, ty, ScalarConstantValue<int64_t>(constant, S64));
}
}
return CreateConst(b, ty, ScalarConstantValue<double>(constant));
return CreateConst(b, ty, ScalarConstantValue<double>(constant, F64));
}
struct DimProperties {
@@ -755,13 +729,27 @@ struct GeneralizeKernelSignaturePass
}
};
} // namespace
// Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n].
// TODO(b/270937368): Split this up into smaller functions.
template <typename IndexT>
StatusOr<LaunchDimensions> MatMulImpl(
mlir::OpBuilder builder, absl::string_view libdevice_path,
const HloDotInstruction* dot_instr, mlir::triton::FuncOp fn,
const AutotuneResult::TritonGemmKey& config, int shmem_budget) {
StatusOr<LaunchDimensions> MatMul(mlir::OpBuilder builder,
absl::string_view libdevice_path,
const HloComputation* computation,
mlir::triton::FuncOp fn,
const AutotuneResult::TritonGemmKey& config,
int shmem_budget) {
const HloDotInstruction* dot_instr = DynCast<HloDotInstruction>(
hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot));
// Use 32-bit indexing if addressing any of the inputs or the output (which
// could grow if split_k is set) does not cross the INT_MAX boundary.
// Otherwise, fall back to 64-bit indexing, which is slower.
bool use_64bit_indexing =
ShapeUtil::ElementsIn(dot_instr->operand(0)->shape()) > INT_MAX ||
ShapeUtil::ElementsIn(dot_instr->operand(1)->shape()) > INT_MAX ||
ShapeUtil::ElementsIn(dot_instr->shape()) * config.split_k() > INT_MAX;
mlir::Type index_ty = builder.getIntegerType(use_64bit_indexing ? 64 : 32);
const HloInstruction* root = dot_instr->parent()->root_instruction();
CHECK(!root->shape().IsTuple());
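
As a concrete illustration of this threshold (hypothetical shapes): a [50000, 50000] operand holds 2.5e9 elements, which exceeds INT_MAX (about 2.147e9), so index_ty becomes i64; with [16384, 16384] operands (about 2.68e8 elements) and split_k = 1, all offsets stay 32-bit.
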
@@ -772,12 +760,6 @@ StatusOr<LaunchDimensions> MatMulImpl(
ImplicitLocOpBuilder b(loc, builder);
Type i32_ty = b.getI32Type();
Type i64_ty = b.getI64Type();
Type int_ty;
if constexpr (std::is_same_v<IndexT, int64_t>) {
int_ty = i64_ty;
} else {
int_ty = i32_ty;
}
const int split_k = config.split_k();
const int block_m = config.block_m();
@@ -927,8 +909,8 @@ StatusOr<LaunchDimensions> MatMulImpl(
// Extend int32 indexes to int64, if necessary.
auto convert_scalar = [&](Value value) -> Value {
if constexpr (std::is_same_v<IndexT, int64_t>) {
return b.create<ma::ExtSIOp>(int_ty, value);
if (index_ty.getIntOrFloatBitWidth() == 64) {
return b.create<ma::ExtSIOp>(index_ty, value);
}
return value;
};
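
convert_scalar widens a 32-bit offset only when the kernel was set up for 64-bit indexing; otherwise the value passes through unchanged. An equivalent standalone sketch, assuming a hypothetical i32 value `v` and using the same ma:: dialect alias as the surrounding code:

  // Sign-extend to the 64-bit index type only when it is actually in use.
  Value widened = v;
  if (index_ty.getIntOrFloatBitWidth() == 64) {
    widened = b.create<ma::ExtSIOp>(index_ty, v);
  }
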
@@ -1059,7 +1041,7 @@ StatusOr<LaunchDimensions> MatMulImpl(
b.create<mt::ExpandDimsOp>(range_k, 0),
Splat(b, elements_in_tile, {1, block_k}))
.getResult()
.template cast<TensorValue>(),
.cast<TensorValue>(),
{block_m, block_k});
Value rhs_mask = Broadcast(
b,
@@ -1067,7 +1049,7 @@ StatusOr<LaunchDimensions> MatMulImpl(
b.create<mt::ExpandDimsOp>(range_k, 1),
Splat(b, elements_in_tile, {block_k, 1}))
.getResult()
.template cast<TensorValue>(),
.cast<TensorValue>(),
{block_k, block_n});
dot_input_lhs = b.create<ma::SelectOp>(lhs_mask, dot_input_lhs,
ZerosLike(b, dot_input_lhs));
@@ -1131,7 +1113,7 @@ StatusOr<LaunchDimensions> MatMulImpl(
add_dim(dim);
}
IndexT stride_batch = 0;
int64_t stride_batch = 0;
if (scope != TritonFusionAnalysis::Scope::RHS && lhs_nc_split) {
const TensorIterationSpec::DimIterationSpec* spec =
analysis.IterSpec(scope, hlo, tiled_dimensions[0].index);
@@ -1158,8 +1140,9 @@ StatusOr<LaunchDimensions> MatMulImpl(
}
}
if (stride_batch != 0) {
Value offset_batch = b.create<ma::MulIOp>(
convert_scalar(pid_batch), CreateConst(b, int_ty, stride_batch));
Value offset_batch =
b.create<ma::MulIOp>(convert_scalar(pid_batch),
CreateConst(b, index_ty, stride_batch));
base = AddPtr(b, base, offset_batch);
}
@@ -1167,9 +1150,10 @@ StatusOr<LaunchDimensions> MatMulImpl(
const TensorIterationSpec::DimIterationSpec* spec = analysis.IterSpec(
TritonFusionAnalysis::Scope::OUTPUT, hlo, split_k_out_idx);
if (spec != nullptr) {
IndexT stride_split_k = spec->at(0).stride;
Value offset_split_k = b.create<ma::MulIOp>(
convert_scalar(pid_k), CreateConst(b, int_ty, stride_split_k));
int64_t stride_split_k = spec->at(0).stride;
Value offset_split_k =
b.create<ma::MulIOp>(convert_scalar(pid_k),
CreateConst(b, index_ty, stride_split_k));
base = AddPtr(b, base, offset_split_k);
}
}
@@ -1283,32 +1267,6 @@ StatusOr<LaunchDimensions> MatMulImpl(
{config.num_warps() * WarpSize(), 1, 1}};
}
} // namespace
StatusOr<LaunchDimensions> MatMul(mlir::OpBuilder builder,
absl::string_view libdevice_path,
const HloComputation* computation,
mlir::triton::FuncOp fn,
const AutotuneResult::TritonGemmKey& config,
int shmem_budget) {
const HloDotInstruction* dot_instr = DynCast<HloDotInstruction>(
hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot));
// Use 32-bit indexing if addressing any of the inputs or the output (which
// could grow if split_k is set) does not cross the INT_MAX boundary.
// Otherwise, fall back to 64-bit indexing, which is slower.
bool use_64bit_indexing =
ShapeUtil::ElementsIn(dot_instr->operand(0)->shape()) > INT_MAX ||
ShapeUtil::ElementsIn(dot_instr->operand(1)->shape()) > INT_MAX ||
ShapeUtil::ElementsIn(dot_instr->shape()) * config.split_k() > INT_MAX;
if (use_64bit_indexing) {
return MatMulImpl<int64_t>(builder, libdevice_path, dot_instr, fn, config,
shmem_budget);
} else {
return MatMulImpl<int32_t>(builder, libdevice_path, dot_instr, fn, config,
shmem_budget);
}
}
StatusOr<LaunchDimensions> SoftMax(mlir::OpBuilder builder,
absl::string_view libdevice_path,
const HloComputation* computation,
......