提交 62bd5df8 编写于 作者: E Eugene Zhulenev 提交者: TensorFlower Gardener

[xla:runtime] Use structs to define gpu runtime custom calls

LLVM seems to have troubles with inlining custom call handlers defined by function pointers. When we use struct, then the custom call body is typically fully inlined into the CustomCallHandler template instantiation and generates better code.

PiperOrigin-RevId: 481300474
上级 2b124645
......@@ -68,13 +68,27 @@ se::KernelBase* GpuExecutableKernelsCache::Set(
// Define the kernel launch custom call.
//===----------------------------------------------------------------------===//
static absl::Status LaunchFunc(
namespace {
struct KernelLaunch {
LLVM_ATTRIBUTE_ALWAYS_INLINE
absl::Status operator()(
const ServiceExecutableRunOptions* run_options, const std::string* ptx,
const std::vector<uint8_t>* cubin, se::DeviceMemoryBase* temp_buffer,
GpuExecutableKernelsCache* kernels_cache, int32_t grid_size_x,
int32_t grid_size_y, int32_t grid_size_z, int32_t block_size_x,
int32_t block_size_y, int32_t block_size_z,
CustomCall::RemainingArgs args, std::string_view name) const;
static KernelLaunch Handler() { return KernelLaunch(); }
};
} // namespace
absl::Status KernelLaunch::operator()(
const ServiceExecutableRunOptions* run_options, const std::string* ptx,
const std::vector<uint8_t>* cubin, se::DeviceMemoryBase* temp_buffer,
GpuExecutableKernelsCache* kernels_cache, int32_t grid_size_x,
int32_t grid_size_y, int32_t grid_size_z, int32_t block_size_x,
int32_t block_size_y, int32_t block_size_z, CustomCall::RemainingArgs args,
std::string_view name) {
std::string_view name) const {
se::Stream* stream = run_options->stream();
se::StreamExecutor* executor = stream->parent();
......@@ -147,7 +161,7 @@ static bool Launch(runtime::ExecutionContext* ctx, void** args, void** attrs,
.Arg<int32_t>() // block_size_x
.RemainingArgs() // args
.Attr<std::string_view>("kernel")
.To<checks>(LaunchFunc)
.To<checks>(KernelLaunch::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs, rets));
......
......@@ -33,19 +33,36 @@ using ::xla::runtime::HloTrace;
using ::tsl::profiler::ScopedAnnotationStack;
static absl::StatusOr<int64_t> ActivityStart(runtime::HloTrace annotation) {
return ScopedAnnotationStack::ActivityStart([&] {
// We use the same tracing annotation scheme as the ThunkSequence (see
// implementation of `GetThunkInfo` in `ir_emitter_unnested.cc`).
return absl::StrFormat("Thunk:#hlo_op=%s,hlo_module=%s,program_id=%d#",
annotation.hlo_op, annotation.module,
annotation.program_id);
});
}
//===----------------------------------------------------------------------===//
static absl::Status ActivityEnd(int64_t activity_id) {
return absl::OkStatus();
}
namespace {
struct ActivityStart {
LLVM_ATTRIBUTE_ALWAYS_INLINE
absl::StatusOr<int64_t> operator()(runtime::HloTrace annotation) const {
return ScopedAnnotationStack::ActivityStart([&] {
// We use the same tracing annotation scheme as the ThunkSequence (see
// implementation of `GetThunkInfo` in `ir_emitter_unnested.cc`).
return absl::StrFormat("Thunk:#hlo_op=%s,hlo_module=%s,program_id=%d#",
annotation.hlo_op, annotation.module,
annotation.program_id);
});
}
static ActivityStart Handler() { return ActivityStart(); }
};
struct ActivityEnd {
LLVM_ATTRIBUTE_ALWAYS_INLINE
absl::Status operator()(int64_t activity_id) const {
ScopedAnnotationStack::ActivityEnd(activity_id);
return absl::OkStatus();
}
static ActivityEnd Handler() { return ActivityEnd(); }
};
} // namespace
//===----------------------------------------------------------------------===//
......@@ -54,7 +71,7 @@ static bool Start(runtime::ExecutionContext* ctx, void** args, void** attrs,
static auto* handler = CustomCall::Bind("xla.trace.activity_start")
.Attr<HloTrace>("annotation")
.Ret<int64_t>()
.To<checks>(ActivityStart)
.To<checks>(ActivityStart::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs, rets));
......@@ -64,7 +81,7 @@ static bool End(runtime::ExecutionContext* ctx, void** args, void** attrs,
void** rets) {
static auto* handler = CustomCall::Bind("xla.trace.activity_end")
.Arg<int64_t>()
.To<checks>(ActivityEnd)
.To<checks>(ActivityEnd::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs, rets));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册