Commit 6158b870 authored by G George Karpenkov, committed by TensorFlower Gardener

[NFC] [XLA:GPU] Provide a ToString() method and logging for TritonFusionAnalysis

Sample output:

```
ir_emitter_triton.cc:755] TritonFusionAnalysis{
LHS: IterSpec{parameter_0.1: {0: [{stride=128, count=480, subfragments=[480]}], 1: [{stride=1, count=128, subfragments=[128]}]}}
RHS: IterSpec{parameter_1.1: {1: [{stride=1, count=128, subfragments=[128]}], 0: [{stride=128, count=16, subfragments=[16]}]}}
OUTPUT: IterSpec{dot.1: {0: [{stride=1, count=480, subfragments=[480]}], 1: [{stride=480, count=16, subfragments=[16]}]}}
}
```
PiperOrigin-RevId: 561052627
Parent commit: 3530d58a
......@@ -73,6 +73,12 @@ class TensorIterationSpec {
// Logical subfragments when this iteration is composed
// of several HLO dimensions. Product of subfragments equals `count`.
std::vector<int64_t> subfragments;
// Formats this fragment as "{stride=S, count=C, subfragments=[a, b, ...]}".
std::string ToString() const {
  std::string result = "{stride=";
  absl::StrAppend(&result, stride, ", count=", count, ", subfragments=[");
  absl::StrAppend(&result, absl::StrJoin(subfragments, ", "), "]}");
  return result;
}
};
// Description of complex iteration over a sequence of several strides.
// Describes a logically contiguous dimension of a tensor physically
......@@ -99,14 +105,28 @@ class TensorIterationSpec {
// Compares physical layouts of tensors ignoring subfragments of dimensions.
bool operator==(const TensorIterationSpec& other) const;
// Renders the whole spec as "{dim: [frag, frag], dim: [frag]}", where each
// fragment is formatted by Fragment::ToString().
std::string ToString() const {
  auto append_dim = [](std::string* out, const auto& entry) {
    absl::StrAppend(out, entry.first, ": [");
    const char* separator = "";
    for (const auto& fragment : entry.second) {
      absl::StrAppend(out, separator, fragment.ToString());
      separator = ", ";
    }
    absl::StrAppend(out, "]");
  };
  return absl::StrCat(
      "{", absl::StrJoin(dim_iteration_specs_, ", ", append_dim), "}");
}
private:
StorageType dim_iteration_specs_;
};
// Analysis of tensor iteration orders within tiled fusions.
class TritonFusionAnalysis {
TritonFusionAnalysis() {}
Status ExecuteForDotFusion(const HloInstruction& dot, int split_k);
public:
......@@ -136,10 +156,48 @@ class TritonFusionAnalysis {
return parameters_.at(scope);
}
std::string ToString() const {
return absl::StrCat(
"TritonFusionAnalysis{\n",
absl::StrJoin(iter_specs_, ",\n",
[&](std::string* s, const auto& kv) {
absl::StrAppend(
s, ScopeToString(kv.first), ": ",
IterationSpecByInstructionMapToString(kv.second));
}),
"\n}");
}
private:
// NOTE(review): this is the pre-change declaration of `iter_specs_` left over
// from the diff view; the commit replaces it with the aliased declaration
// further below — only one of the two should survive in the real header.
absl::flat_hash_map<
Scope, absl::flat_hash_map<const HloInstruction*, TensorIterationSpec>>
iter_specs_;
// Iteration specs keyed by the HLO instruction they describe.
using IterationSpecByInstructionMap =
absl::flat_hash_map<const HloInstruction*, TensorIterationSpec>;
// Per-scope (LHS / RHS / OUTPUT) map of per-instruction iteration specs.
using IterationSpecByInstructionByScopeMap =
absl::flat_hash_map<Scope, IterationSpecByInstructionMap>;
static std::string IterationSpecByInstructionMapToString(
const IterationSpecByInstructionMap& m) {
return absl::StrCat("IterSpec{",
absl::StrJoin(m, ", ",
[&](std::string* s, const auto& kv) {
absl::StrAppend(s, kv.first->name(),
": ",
kv.second.ToString());
}),
"}");
}
// Human-readable name of a fusion scope, for logging.
// The switch covers every Scope enumerator, but without a trailing return
// the compiler warns (-Wreturn-type) and an out-of-range enum value would
// flow off the end of a value-returning function, which is undefined
// behavior. The fallback return below closes that hole without changing
// the result for any valid Scope.
static std::string ScopeToString(Scope s) {
  switch (s) {
    case Scope::LHS:
      return "LHS";
    case Scope::RHS:
      return "RHS";
    case Scope::OUTPUT:
      return "OUTPUT";
  }
  return "UNKNOWN SCOPE";  // Unreachable for valid enumerator values.
}
// Tensor iteration specs, keyed first by scope, then by instruction.
IterationSpecByInstructionByScopeMap iter_specs_;
// HLO computation parameters per scope.
absl::flat_hash_map<Scope, absl::flat_hash_set<const HloInstruction*>>
parameters_;
......
......@@ -750,8 +750,10 @@ StatusOr<LaunchDimensions> MatMulImpl(
CHECK_GE(block_n, 16);
const DotDimensionNumbers& dims = dot_instr->dot_dimension_numbers();
TF_ASSIGN_OR_RETURN(const auto analysis, TritonFusionAnalysis::Execute(
*dot_instr->parent(), split_k));
TF_ASSIGN_OR_RETURN(
const TritonFusionAnalysis analysis,
TritonFusionAnalysis::Execute(*dot_instr->parent(), split_k));
VLOG(6) << analysis.ToString();
// Rely on dot decomposer: there is just one contracting and one
// non-contracting dimension on each side + batch ones optionally.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册