Commit b710b1e0 authored by Ilia Sergachev, committed by TensorFlower Gardener

[XLA:GPU][NFC] Refactor Triton GEMM rewriter.

Add a comparison operator for fragments, improve encapsulation, fix comments.
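
For readers skimming the diff below, here is a condensed sketch (simplified stand-ins, not the actual XLA declarations) of the two mechanical changes: `DimensionOrder::Fragment` turns from a struct with public fields into a class with `dst_dim_number()`/`size()` accessors and a `set_size()` mutator, and `TensorIterationSpec::IterationSpecFragment` gains an `operator!=` so call sites stop comparing `stride` and `count` by hand.

```cpp
// Condensed sketch of the refactoring pattern applied by this commit; the types
// are reduced stand-ins for the XLA classes, not their real definitions.
#include <cstdint>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"

// Before: a plain struct with public `dst_dim_number` and `size` fields.
// After: fields are private, the destination dimension number is immutable,
// and callers go through accessors.
class Fragment {
 public:
  explicit Fragment(int dst_dim_number, int64_t size)
      : dst_dim_number_(dst_dim_number), size_(size) {}

  std::string ToString() const {
    return absl::StrCat(dst_dim_number_, ":", size_);
  }
  // Label carrying the dimension number of a defining operation.
  int dst_dim_number() const { return dst_dim_number_; }
  // Total number of elements in the fragment, ignoring slicing.
  int64_t size() const { return size_; }
  void set_size(int64_t size) { size_ = size; }

 private:
  const int dst_dim_number_;
  int64_t size_;
};

// The new comparison operator on iteration-spec fragments: `subfragments` is
// deliberately left out of the comparison, matching the field-by-field check
// that TensorIterationSpec::operator== previously spelled out inline.
struct IterationSpecFragment {
  int64_t stride;
  int64_t count;
  std::vector<int64_t> subfragments;

  bool operator!=(const IterationSpecFragment& other) const {
    return stride != other.stride || count != other.count;
  }
};
```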

PiperOrigin-RevId: 564990063
Parent 55bae7da
@@ -103,12 +103,7 @@ bool TensorIterationSpec::operator==(const TensorIterationSpec& other) const {
return false;
}
for (int fragment = 0; fragment < it_this->second.size(); ++fragment) {
if (it_this->second.size() != it_other->second.size()) {
return false;
}
if (it_this->second[fragment].stride !=
it_other->second[fragment].stride ||
it_this->second[fragment].count != it_other->second[fragment].count) {
if (it_this->second[fragment] != it_other->second[fragment]) {
return false;
}
}
@@ -209,12 +204,12 @@ class DimensionOrder {
<< "The split-K batch dimension has be preceded by the contracting "
"dimension it originates from by construction.";
target_dim_number =
dim_order.tensor_fragments_order_.back().dst_dim_number;
dim_order.tensor_fragments_order_.back().dst_dim_number();
}
dim_order.dim_fragments_orders_[target_dim_number].push_back(
dim_order.tensor_fragments_order_.size());
dim_order.tensor_fragments_order_.push_back(
{target_dim_number, hlo.shape().dimensions(i)});
Fragment{target_dim_number, hlo.shape().dimensions(i)});
}
return dim_order;
}
@@ -225,27 +220,34 @@ class DimensionOrder {
dim_order.dim_fragments_orders_[kSoftmaxReductionDimension].push_back(
dim_order.tensor_fragments_order_.size());
dim_order.tensor_fragments_order_.push_back(
{kSoftmaxReductionDimension, hlo.shape().dimensions_minor(0)});
Fragment{kSoftmaxReductionDimension, hlo.shape().dimensions_minor(0)});
for (int i = 1; i < hlo.shape().rank(); ++i) {
dim_order.dim_fragments_orders_[kSoftmaxBatchDimension].push_back(
dim_order.tensor_fragments_order_.size());
dim_order.tensor_fragments_order_.push_back(
{kSoftmaxBatchDimension, hlo.shape().dimensions_minor(i)});
Fragment{kSoftmaxBatchDimension, hlo.shape().dimensions_minor(i)});
}
return dim_order;
}
// Description of a continuous fragment of one dimension of a tensor.
struct Fragment {
// Label carrying the dimension number of a defining operation.
int dst_dim_number;
// Number of elements in the fragment.
int64_t size;
class Fragment {
public:
explicit Fragment(int dst_dim_number, int64_t size)
: dst_dim_number_(dst_dim_number), size_(size) {}
std::string ToString() const {
return absl::StrCat(dst_dim_number, ":", size);
return absl::StrCat(dst_dim_number_, ":", size_);
}
Fragment(int dst_dim_number, int64_t size)
: dst_dim_number(dst_dim_number), size(size) {}
// Label carrying the dimension number of a defining operation.
int dst_dim_number() const { return dst_dim_number_; }
// Total number of elements in the fragment ignoring slicing.
int64_t size() const { return size_; }
void set_size(int64_t size) { size_ = size; }
private:
const int dst_dim_number_;
int64_t size_;
};
using Fragments = std::vector<Fragment>;
using FragmentOrders = absl::flat_hash_map<int, std::vector<int>>;
@@ -311,29 +313,30 @@ TensorIterationSpec DimensionOrderToTensorIterationSpec(
for (int dim_order_index = 0; dim_order_index < dim_fragments.size();
++dim_order_index) {
const DimensionOrder::Fragment& fragment = dim_fragments[dim_order_index];
VLOG(6) << fragment.dst_dim_number << "\t" << fragment.size;
VLOG(6) << fragment.dst_dim_number() << "\t" << fragment.size();
DimIterationSpec& dim_spec = tensor_spec[fragment.dst_dim_number];
if (last_dim == fragment.dst_dim_number) {
// Contiguous dimension, split only logically. Merge it back.
DimIterationSpec& dim_spec = tensor_spec[fragment.dst_dim_number()];
if (last_dim == fragment.dst_dim_number()) {
// Remove previous 1-sized subfragment if present.
if (!dim_spec.empty() && !dim_spec.back().subfragments.empty() &&
dim_spec.back().subfragments.back() == 1) {
// Remove previous 1-sized subfragment.
dim_spec.back().subfragments.pop_back();
}
if (fragment.size > 1) {
// Contiguous dimension, split only logically. Merge it back.
if (fragment.size() > 1) {
CHECK(!dim_spec.empty());
dim_spec.back().count *= fragment.size;
dim_spec.back().subfragments.push_back(fragment.size);
dim_spec.back().count *= fragment.size();
dim_spec.back().subfragments.push_back(fragment.size());
}
} else {
remove_last_fragment_if_degenerate(last_dim);
// Add part of the dimension.
dim_spec.push_back({accumulated_stride, fragment.size, {fragment.size}});
dim_spec.push_back(TensorIterationSpec::IterationSpecFragment{
accumulated_stride, fragment.size(), {fragment.size()}});
}
accumulated_stride *= fragment.size;
last_dim = fragment.dst_dim_number;
accumulated_stride *= fragment.size();
last_dim = fragment.dst_dim_number();
}
remove_last_fragment_if_degenerate(last_dim);
tensor_spec.RemoveEmptyDimensions();
@@ -532,12 +535,12 @@ FusionDecision FusionContext::RequireSupportedDimOrder(
if (fragment == dim_fragments.cend()) {
break;
}
int64_t grouped_size = tensor_dim_fragments[*fragment].size;
int64_t grouped_size = tensor_dim_fragments[*fragment].size();
// Gather contiguous fragments.
while ((fragment + 1) != dim_fragments.cend() &&
*(fragment + 1) == *fragment + 1) {
++fragment;
grouped_size *= tensor_dim_fragments[*fragment].size;
grouped_size *= tensor_dim_fragments[*fragment].size();
}
if (grouped_size == 1) {
@@ -649,7 +652,7 @@ DimOrderUpdatesOrError FusionContext::HandleBitcast(
src_to_dst[&*src_dim].push_back(dst_fragments_order.size() - 1);
};
if (std::holds_alternative<SoftmaxProperties>(properties_) &&
src_dim->dst_dim_number ==
src_dim->dst_dim_number() ==
std::get<SoftmaxProperties>(properties_).softmax_batch_dimension) {
// Special handling for softmax batch dimension: allow arbitrary reshapes
// on it because it's guaranteed by the construction of the fusion to have
@@ -657,23 +660,23 @@ DimOrderUpdatesOrError FusionContext::HandleBitcast(
// Find a continuous group of fragments corresponding to this dimension in
// the source and assign the corresponding size in fragments of the
// destination ignoring the source ones.
dst_remaining_size = src_dim->size;
dst_remaining_size = src_dim->size();
while (src_dim + 1 != src_fragments_order.cend() &&
(src_dim + 1)->dst_dim_number == src_dim->dst_dim_number) {
(src_dim + 1)->dst_dim_number() == src_dim->dst_dim_number()) {
++src_dim;
dst_remaining_size *= src_dim->size;
dst_remaining_size *= src_dim->size();
}
while (dst_remaining_size > 1) {
CHECK(dst_dim_it != dst_dim_end);
add_new_fragment(
{src_dim->dst_dim_number, dst_shape.dimensions(*dst_dim_it)});
add_new_fragment(Fragment{src_dim->dst_dim_number(),
dst_shape.dimensions(*dst_dim_it)});
dst_remaining_size /= dst_shape.dimensions(*dst_dim_it);
++dst_dim_it;
}
continue;
}
if (dst_remaining_size >= src_dim->size) {
if (dst_remaining_size % src_dim->size) {
if (dst_remaining_size >= src_dim->size()) {
if (dst_remaining_size % src_dim->size()) {
return "Unsupported bitcast";
}
// Source dimension fragment completely fits into the destination one:
@@ -681,12 +684,12 @@ DimOrderUpdatesOrError FusionContext::HandleBitcast(
add_new_fragment(*src_dim);
// Update the size of the remaining part of the destination that is
// carried over to next source dimensions.
dst_remaining_size /= src_dim->size;
dst_remaining_size /= src_dim->size();
} else {
// Source is larger than destination.
// Assign further destination dimensions.
// Size of the not yet assigned part of the source dimension.
int64_t src_remaining_size = src_dim->size;
int64_t src_remaining_size = src_dim->size();
// Handle dimension splits.
if (dst_remaining_size > 1) {
// If there is a remaining fragment of a previous destination dimension
@@ -694,7 +697,8 @@ DimOrderUpdatesOrError FusionContext::HandleBitcast(
if (src_remaining_size % dst_remaining_size) {
return "Unsupported bitcast";
}
add_new_fragment({src_dim->dst_dim_number, dst_remaining_size});
add_new_fragment(
Fragment{src_dim->dst_dim_number(), dst_remaining_size});
// Update the size of the fragment remaining to assign.
src_remaining_size /= dst_remaining_size;
dst_remaining_size = 1;
@@ -714,7 +718,8 @@ DimOrderUpdatesOrError FusionContext::HandleBitcast(
dst_remaining_size = dst_dim_size / src_remaining_size;
new_fragment_size = src_remaining_size;
}
add_new_fragment({src_dim->dst_dim_number, new_fragment_size});
add_new_fragment(
Fragment{src_dim->dst_dim_number(), new_fragment_size});
src_remaining_size /= new_fragment_size;
++dst_dim_it;
}
@@ -731,7 +736,7 @@ DimOrderUpdatesOrError FusionContext::HandleBitcast(
}
if (!dst_fragments_order.empty()) {
dst_fragments_order.push_back(
{dst_fragments_order.back().dst_dim_number, 1});
Fragment{dst_fragments_order.back().dst_dim_number(), 1});
src_to_dst[&src_fragments_order.back()].push_back(
dst_fragments_order.size() - 1);
}
@@ -788,7 +793,7 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp(
std::vector<const Fragment*> subdim_group;
do {
CHECK(src_fragment_it != src_fragment_end);
subdim_size_accumulator *= src_fragment_it->size;
subdim_size_accumulator *= src_fragment_it->size();
subdim_group.push_back(&*src_fragment_it);
++src_fragment_it;
} while (subdim_size_accumulator < dim_size);
@@ -860,11 +865,11 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp(
for (const Fragment* subdim : dst_logical[dim_idx]) {
dst_fragments_order.push_back(*subdim);
src_to_dst[subdim] = dst_fragments_order.size() - 1;
dim_numbers_present_in_dst.insert(subdim->dst_dim_number);
dim_numbers_present_in_dst.insert(subdim->dst_dim_number());
if (std::holds_alternative<SoftmaxProperties>(properties_) &&
subdim->dst_dim_number == std::get<SoftmaxProperties>(properties_)
subdim->dst_dim_number() == std::get<SoftmaxProperties>(properties_)
.softmax_reduction_dimension) {
dst_dim_fragments_order[subdim->dst_dim_number].push_back(
dst_dim_fragments_order[subdim->dst_dim_number()].push_back(
dst_fragments_order.size() - 1);
}
}
@@ -874,7 +879,7 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp(
for (const int fragment_number : dim_sequence) {
const auto it = src_to_dst.find(&src_fragments_order[fragment_number]);
if (it == src_to_dst.cend()) {
if (src_fragments_order[fragment_number].size > 1 &&
if (src_fragments_order[fragment_number].size() > 1 &&
dim_numbers_present_in_dst.contains(dim_index)) {
return FusionDecision("Unsupported broadcast");
}
@@ -1119,9 +1124,6 @@ void FusionContext::TryToFuseWithInputsRecursively(
// become parameters of the fusion. Used to track the number of parameters
// of the fusion.
absl::flat_hash_set<const HloInstruction*> inputs;
// Currently only one physically unique dim order per scope is supported.
// Let it change while the scope has one input; afterwards require all
// of them to be physically compatible.
auto try_fuse_one = [&](HloInstruction& hlo) {
const DimOrderUpdatesOrError result = AnalyzeForFusion(
hlo, /*as_input=*/true, old_to_new_mapping, gpu_version);
@@ -1896,6 +1898,11 @@ std::string TensorIterationSpec::IterationSpecFragment::ToString() const {
absl::StrJoin(subfragments, ", "), "]}");
}
bool TensorIterationSpec::IterationSpecFragment::operator!=(
const IterationSpecFragment& other) const {
return stride != other.stride || count != other.count;
}
std::string TensorIterationSpec::ToString() const {
return absl::StrCat(
"{",
......
@@ -81,6 +81,7 @@ class TensorIterationSpec {
// of several HLO dimensions. Product of subfragments equals `count`.
std::vector<int64_t> subfragments;
bool operator!=(const IterationSpecFragment& other) const;
std::string ToString() const;
};
// Description of complex iteration over a sequence of several strides.
......
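
To close, a hedged sketch of how the equality check from the first hunk reads once per-fragment comparison goes through the new `operator!=`; the container aliases and the free function below are assumptions made for the example, not the file's real storage declarations.

```cpp
// Sketch of the simplified equality walk, delegating to the fragment
// comparison operator instead of comparing stride/count inline. The map and
// vector aliases are assumed for this example only.
#include <cstddef>
#include <cstdint>
#include <vector>

#include "absl/container/flat_hash_map.h"

struct IterationSpecFragment {
  int64_t stride = 0;
  int64_t count = 0;
  std::vector<int64_t> subfragments;
  bool operator!=(const IterationSpecFragment& other) const {
    return stride != other.stride || count != other.count;
  }
};

using DimIterationSpec = std::vector<IterationSpecFragment>;
// Hypothetical stand-in for the tensor spec's dimension -> fragments storage.
using TensorSpecStorage = absl::flat_hash_map<int, DimIterationSpec>;

// Mirrors the post-change shape of TensorIterationSpec::operator==: the size
// check per dimension happens once, and each fragment pair is compared via
// operator!= rather than field by field.
bool SpecsEqual(const TensorSpecStorage& lhs, const TensorSpecStorage& rhs) {
  if (lhs.size() != rhs.size()) return false;
  for (const auto& [dim, lhs_fragments] : lhs) {
    const auto it_other = rhs.find(dim);
    if (it_other == rhs.end() ||
        it_other->second.size() != lhs_fragments.size()) {
      return false;
    }
    for (size_t fragment = 0; fragment < lhs_fragments.size(); ++fragment) {
      if (lhs_fragments[fragment] != it_other->second[fragment]) {
        return false;
      }
    }
  }
  return true;
}
```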