Commit 870f527a authored by Tamás Danyluk, committed by TensorFlower Gardener

[XLA:GPU] Add a flag to dump autotuned Triton fusions

This will be good when we want to analyze all Triton fusions of an HLO for example with ncu.

PiperOrigin-RevId: 549579171
Parent commit: 63e6a63d
......@@ -1176,6 +1176,13 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
int32_setter_for(&DebugOptions::set_xla_gpu_triton_fusion_level),
debug_options->xla_gpu_triton_fusion_level(),
"Triton fusion level, higher levels mean more fused operations."));
// Registers the boolean flag --xla_gpu_dump_autotuned_triton_fusions.
// Setter/getter pair binds the flag to the DebugOptions proto field of the
// same name; the current proto value serves as the default.
flag_list->push_back(tsl::Flag(
"xla_gpu_dump_autotuned_triton_fusions",
bool_setter_for(&DebugOptions::set_xla_gpu_dump_autotuned_triton_fusions),
debug_options->xla_gpu_dump_autotuned_triton_fusions(),
"Dumps autotuned Triton fusions to the directory specified by "
"xla_dump_to or stdout. Each fusion is dumped only once, as an optimized "
"HLO."));
} // NOLINT(readability/fn_size)
// Allocates flag_values and flag_objects; this function must not be called more
......
......@@ -603,6 +603,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_proto_cc",
"//tensorflow/compiler/xla/hlo/ir:hlo",
"//tensorflow/compiler/xla/service:dump",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:float_normalization",
"//tensorflow/compiler/xla/service:hlo_module_config",
......
......@@ -27,6 +27,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
......@@ -37,6 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/float_normalization.h"
#include "tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h"
......@@ -253,6 +255,26 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor {
TF_ASSIGN_OR_RETURN(
AutotuneResult best,
PickBestResult(results, root->ToString(), root->GetModule()->config()));
// When --xla_gpu_dump_autotuned_triton_fusions is set, re-extract the
// winning (best) Triton config as a standalone HloModule and dump its
// optimized HLO text, so the fusion can be analyzed offline (e.g. with ncu).
if (debug_opts.xla_gpu_dump_autotuned_triton_fusions()) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModule> module,
TritonGemmAutotuneExtractor(best.triton(),
GetGpuDeviceInfo(config_.GetExecutor()),
fusion.FusionInstruction()));
// Rename the extracted module after the fusion instruction so the dump
// file name identifies which fusion it contains.
module->set_name(std::string(fusion.FusionInstruction()->name()));
// Using the original module for its debug info and name in the first
// parameter. It's better to include the name of both the original module
// and the extracted module, to avoid name clashes.
// NOTE(review): the first argument dereferences `instr` (declared outside
// this fragment) — presumably the original fusion instruction; confirm it
// refers to the same module as `fusion.FusionInstruction()`.
DumpToFileInDirOrStdout(
/*module=*/*instr->GetModule(),
/*file_prefix=*/"",
/*file_suffix=*/
// `fusion_id_for_dump_` is a monotonically increasing per-visitor
// counter, ensuring each dump file gets a unique name.
absl::StrCat("triton_fusion_", fusion_id_for_dump_++, ".",
module->name(), ".optimized.txt"),
/*contents=*/module->ToString());
}
return best;
}
......@@ -362,6 +384,7 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor {
AutotuneConfig config_;
tsl::thread::ThreadPool* thread_pool_;
std::optional<AutotunerCompileUtil> autotuner_compile_util_;
int fusion_id_for_dump_ = 0;
};
// Search space for exhaustive matmul autotuning.
......
......@@ -578,7 +578,9 @@ message DebugOptions {
int32 xla_gpu_triton_fusion_level = 229;
// Next id: 232
bool xla_gpu_dump_autotuned_triton_fusions = 232;
// Next id: 233
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.