Commit 279f26c7 authored by Eugene Zhulenev, committed by TensorFlower Gardener

[xla:gpu] Add time based and OOM cuda graph eviction policy

PiperOrigin-RevId: 549373738
Parent 0d5173b0
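The eviction timeout is wired through DebugOptions and the XLA flag registered below. A minimal sketch of picking a non-default value, assuming the usual DebugOptions flow (the 120-second value is only an example; the default added in this commit is 60 seconds, and the same flag can be passed via XLA_FLAGS):

  // Sketch: tune the new eviction timeout programmatically. Equivalent to
  // passing --xla_gpu_cuda_graph_eviction_timeout_seconds=120 in XLA_FLAGS.
  xla::DebugOptions opts = xla::GetDebugOptionsFromFlags();
  opts.set_xla_gpu_cuda_graph_eviction_timeout_seconds(120);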
@@ -109,6 +109,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_enable_persistent_temp_buffers(false);
opts.set_xla_gpu_cuda_graph_min_graph_size(5);
opts.set_xla_gpu_cuda_graph_enable_concurrent_region(false);
opts.set_xla_gpu_cuda_graph_eviction_timeout_seconds(60);
// Despite the name, fast min/max on GPUs does not seem to be any faster, and
// adds very counter-intuitive "NaN-swallowing" behavior.
@@ -905,7 +906,7 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"Instantiate a cuda graph after the time a captured function is executed "
"reaches the threshold."));
flag_list->push_back(tsl::Flag(
"xla_gpu_cuda_graph_capture_threshold",
"xla_gpu_cuda_graph_min_graph_size",
int32_setter_for(&DebugOptions::set_xla_gpu_cuda_graph_min_graph_size),
debug_options->xla_gpu_cuda_graph_min_graph_size(),
"Capture a region as a function to be launched as cuda graph if the "
@@ -917,6 +918,14 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_gpu_cuda_graph_enable_concurrent_region(),
"Identify concurrent regions in cuda graphs and execute them "
"concurrently."));
flag_list->push_back(tsl::Flag(
"xla_gpu_cuda_graph_eviction_timeout_seconds",
int32_setter_for(
&DebugOptions::set_xla_gpu_cuda_graph_eviction_timeout_seconds),
debug_options->xla_gpu_cuda_graph_eviction_timeout_seconds(),
"Timeout in seconds to evict instantiated Gpu graphs from device. When "
"XLA instantiates new Gpu graphs, it evicts graphs that were not "
"recently executed to free space on device."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_persistent_temp_buffers",
......
@@ -386,8 +386,11 @@ Status GpuRuntimeExecutable::Execute(
conv_runners_(executor)->snapshot();
#if GOOGLE_CUDA
std::shared_ptr<StreamExecutorGraphInstances> executor_graphs =
graph_instances_(executor);
StreamExecutorGraphInstances::Snapshot graph_instances =
graph_instances_(executor)->snapshot();
executor_graphs->snapshot();
CapturedFunctionExecutionCount::Snapshot execution_count =
captured_function_counts_(executor)->snapshot();
#endif // GOOGLE_CUDA
@@ -451,7 +454,8 @@ Status GpuRuntimeExecutable::Execute(
}
if (auto instantiated = graph_instances_.InstantiateAllGraphs(
run_options, executable, user_data, device_ptr);
run_options, executable, user_data, device_ptr,
debug_options_.xla_gpu_cuda_graph_eviction_timeout_seconds());
!instantiated.ok()) {
return InternalError("Failed to instantiate CUDA graphs: %s",
instantiated.message());
......
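The apparent reason `GraphInstances::operator()` now returns a std::shared_ptr (see the graph_launch.h change below) is lifetime: even if eviction erases the per-executor cache entry while this executable is running, Execute keeps the instances alive through its own reference. The change above, restated as a minimal caller-side sketch using the same names as the diff:

  // Hold a strong reference so a concurrent eviction cannot destroy the
  // instances while this snapshot is still in use.
  std::shared_ptr<StreamExecutorGraphInstances> executor_graphs =
      graph_instances_(executor);
  StreamExecutorGraphInstances::Snapshot graph_instances =
      executor_graphs->snapshot();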
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h"
#include <algorithm>
#include <array>
#include <atomic>
#include <cstddef>
@@ -73,15 +74,39 @@ static absl::StatusOr<OwnedCudaGraph> CaptureGraph(
// CUDA graphs caching.
//===----------------------------------------------------------------------===//
static absl::Mutex* GetGraphInstancesMutex() {
static auto* mu = new absl::Mutex();
return mu;
}
struct GraphInstances::Impl {
struct State {
// A flag signalling if `InstantiateAllGraphs` was already called and we
// have all Gpu graphs instantiated ahead of time.
bool instantiated = false;
// Last time graph instances were used by a particular stream executor.
uint64_t last_use_micros = 0;
std::shared_ptr<StreamExecutorGraphInstances> instances =
std::make_shared<StreamExecutorGraphInstances>();
};
// XLA module name that owns graph instances. We use it only to produce logs
// that can be attributed back to XLA executables.
std::string module_name;
// Number of graphs in the parent module.
int64_t num_graphs = 0;
mutable absl::Mutex mu;
absl::node_hash_map<se::StreamExecutor*, State> graphs ABSL_GUARDED_BY(mu);
};
// Keep track of instantiated graphs on each StreamExecutor; we use this
// information in the graph eviction policy.
using GraphInstancesState = absl::flat_hash_map<se::StreamExecutor*, int64_t>;
static absl::Mutex* GetGraphInstancesStateMutex() {
static auto* mu = new absl::Mutex();
return mu;
}
static GraphInstancesState& GetGraphInstancesState() {
static auto* state = new GraphInstancesState();
return *state;
@@ -89,38 +114,121 @@ static GraphInstancesState& GetGraphInstancesState() {
static int64_t NotifyGraphInstancesCreated(se::StreamExecutor* executor,
int64_t num_graphs) {
absl::MutexLock lock(GetGraphInstancesMutex());
absl::MutexLock lock(GetGraphInstancesStateMutex());
return GetGraphInstancesState()[executor] += num_graphs;
}
static int64_t NotifyGraphInstancesDestroyed(se::StreamExecutor* executor,
int64_t num_graphs) {
absl::MutexLock lock(GetGraphInstancesMutex());
absl::MutexLock lock(GetGraphInstancesStateMutex());
return GetGraphInstancesState()[executor] -= num_graphs;
}
// We keep track of all graph instances in the process to implement graph
// eviction on OOM. Graph instances are owned by GpuExecutable, so we rely on
// weak pointers to check if they are still alive.
using GraphInstancesVec = std::vector<std::weak_ptr<GraphInstances::Impl>>;
static absl::Mutex* GetGraphInstancesVecMutex() {
static auto* mu = new absl::Mutex();
return mu;
}
static GraphInstancesVec& GetGraphInstancesVec() {
static auto* vec = new GraphInstancesVec();
return *vec;
}
static void AddGraphInstances(std::weak_ptr<GraphInstances::Impl> impl) {
absl::MutexLock lock(GetGraphInstancesVecMutex());
GetGraphInstancesVec().push_back(std::move(impl));
}
// Evicts all graphs for a given executor in the current process.
static void EvictAllGraphs(
se::StreamExecutor* executor,
std::optional<uint64_t> eviction_timeout_seconds = std::nullopt) {
LOG(WARNING) << "Evict "
<< (eviction_timeout_seconds.has_value() ? "timed out" : "all")
<< " gpu graphs from executor " << executor;
TraceMe trace_instantiation([&] {
return TraceMeEncode("cuda.graph.evict_all_graphs",
{{"device_ordinal", executor->device_ordinal()}});
});
absl::MutexLock lock(GetGraphInstancesVecMutex());
auto& vec = GetGraphInstancesVec();
// Erase all expired graph instances.
vec.erase(std::remove_if(vec.begin(), vec.end(),
[](auto& weak_ptr) { return weak_ptr.expired(); }),
vec.end());
auto timed_out = [&](GraphInstances::Impl::State& state) -> bool {
auto diff = tsl::Env::Default()->NowMicros() - state.last_use_micros;
return (diff / (1000 * 1000)) > *eviction_timeout_seconds;
};
for (auto& weak_ptr : vec) {
auto ptr = weak_ptr.lock();
if (!ptr) continue;
if (!ptr->mu.TryLock()) continue;
auto it = ptr->graphs.find(executor);
if (it == ptr->graphs.end()) {
ptr->mu.Unlock();
continue;
}
// If we have a timeout value, then check it first; otherwise always evict
// graphs for a given executor.
bool is_timed_out = timed_out(it->second);
if (eviction_timeout_seconds.has_value() && !is_timed_out) {
ptr->mu.Unlock();
continue;
}
if (ptr->num_graphs > 0) {
VLOG(3) << "Evict " << ptr->num_graphs << " graphs for: @"
<< ptr->module_name << " at executor: " << executor
<< " (timed_out = " << is_timed_out << ")."
<< " Total remaining graphs at given executor: "
<< NotifyGraphInstancesDestroyed(executor, ptr->num_graphs);
}
ptr->graphs.erase(it);
ptr->mu.Unlock();
}
}
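The time-based policy above reduces to a predicate over the per-executor last_use_micros timestamp. A self-contained sketch of that check (ShouldEvict and its arguments are illustrative names, not part of the real code):

  #include <cstdint>
  #include <optional>

  // Evictable if there is no timeout (the OOM path evicts unconditionally) or
  // if the entry has been idle longer than the configured number of seconds.
  bool ShouldEvict(uint64_t now_micros, uint64_t last_use_micros,
                   std::optional<uint64_t> timeout_seconds) {
    if (!timeout_seconds.has_value()) return true;
    uint64_t idle_seconds = (now_micros - last_use_micros) / (1000 * 1000);
    return idle_seconds > *timeout_seconds;
  }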
GraphInstances::GraphInstances(std::string module_name, int64_t num_graphs)
: impl_(std::make_shared<Impl>()) {
impl_->module_name = std::move(module_name);
impl_->num_graphs = num_graphs;
VLOG(3) << "Construct graph instances cache for: @" << impl_->module_name
<< " (num_graphs = " << impl_->num_graphs << ")";
if (impl_->num_graphs > 0) {
VLOG(3) << "Construct graph instances cache for: @" << impl_->module_name
<< " (num_graphs = " << impl_->num_graphs << ")";
}
AddGraphInstances(impl_);
}
GraphInstances::~GraphInstances() {
VLOG(3) << "Destroy graph instances cache for: @" << impl_->module_name
<< " (num_graphs = " << impl_->num_graphs << ")";
absl::MutexLock lock(&impl_->mu);
for (auto& [executor, state] : impl_->graphs) {
VLOG(3) << "Destroy " << impl_->num_graphs << " graphs for: @"
<< impl_->module_name << " at executor: " << executor
<< ". Total remaining graphs at given executor: "
<< NotifyGraphInstancesDestroyed(executor, impl_->num_graphs);
if (impl_->num_graphs > 0) {
VLOG(3) << "Destroy graph instances cache for: @" << impl_->module_name
<< " (num_graphs = " << impl_->num_graphs << ")";
absl::MutexLock lock(&impl_->mu);
for (auto& [executor, state] : impl_->graphs) {
VLOG(3) << "Destroy " << impl_->num_graphs << " graphs for: @"
<< impl_->module_name << " at executor: " << executor
<< ". Total remaining graphs at given executor: "
<< NotifyGraphInstancesDestroyed(executor, impl_->num_graphs);
}
}
}
StreamExecutorGraphInstances* GraphInstances::operator()(
std::shared_ptr<StreamExecutorGraphInstances> GraphInstances::operator()(
se::StreamExecutor* executor) {
absl::MutexLock lock(&impl_->mu);
@@ -132,9 +240,9 @@ StreamExecutorGraphInstances* GraphInstances::operator()(
<< NotifyGraphInstancesCreated(executor, impl_->num_graphs);
}
State& state = it.first->second;
Impl::State& state = it.first->second;
state.last_use_micros = tsl::Env::Default()->NowMicros();
return &state.instances;
return state.instances;
}
bool GraphInstances::InstantiatedAllGraphs(
@@ -149,22 +257,29 @@ bool GraphInstances::InstantiatedAllGraphs(
Status GraphInstances::InstantiateAllGraphs(
const ServiceExecutableRunOptions* run_options,
const Executable& executable, const CustomCall::UserData& user_data,
void* ptr) {
void* ptr, std::optional<uint64_t> eviction_timeout_seconds) {
// We have only "main" function in the executable.
if (executable.num_functions() == 1) return OkStatus();
absl::MutexLock lock(&impl_->mu);
se::StreamExecutor* executor = run_options->stream()->parent();
State& state = impl_->graphs[executor];
Impl::State& state = impl_->graphs[executor];
// All Gpu graphs are already instantiated for a given executor.
if (state.instantiated) return OkStatus();
TraceMe trace("cuda.graph.instantiate_all");
// Initialize graph instances snapshot for a given executor.
StreamExecutorGraphInstances::Snapshot instances = state.instances.snapshot();
// Evict all timed-out graphs before trying to instantiate new ones.
EvictAllGraphs(executor, eviction_timeout_seconds);
// We'll retry graph instantiation on OOM errors after evicting all graphs
// instantiated on `executor`.
int32_t num_retries = 0;
StreamExecutorGraphInstances::Snapshot instances =
state.instances->snapshot();
// Instantiate all Gpu graphs by calling graph capture functions with fake
// arguments. Once we execute them for real the first time, they'll be updated
@@ -217,9 +332,19 @@ Status GraphInstances::InstantiateAllGraphs(
return GraphInstance(0, std::move(e));
};
TF_ASSIGN_OR_RETURN(GraphInstance * instance,
instances.GetOrCreate(ordinal, instantiate));
(void)instance;
absl::StatusOr<GraphInstance*> instance =
instances.GetOrCreate(ordinal, instantiate);
// Retry on OOM error after evicting all graphs from executor.
if (instance.status().code() == absl::StatusCode::kResourceExhausted &&
num_retries++ == 0) {
EvictAllGraphs(executor);
--ordinal; // we'll try to instantiate the same graph one more time
continue;
}
// Otherwise return an error to the caller.
if (!instance.ok()) return instance.status();
#endif // GOOGLE_CUDA
}
......
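A note on the retry above: instantiation failures now surface as kResourceExhausted (see the cuda_graph.cc change below), and since num_retries is shared across the pass, one eviction-and-retry is allowed per instantiation pass; the `--ordinal; continue;` pair compensates for the loop's own increment so the same graph is attempted once more after eviction. A stripped-down sketch of that control flow, with TryInstantiate standing in for the real GetOrCreate call and the loop bounds shown only for illustration:

  int32_t num_retries = 0;
  for (unsigned ordinal = 1; ordinal < executable.num_functions(); ++ordinal) {
    absl::StatusOr<GraphInstance*> instance = TryInstantiate(ordinal);
    if (instance.status().code() == absl::StatusCode::kResourceExhausted &&
        num_retries++ == 0) {
      EvictAllGraphs(executor);  // free device memory held by cached graphs
      --ordinal;                 // the loop's ++ordinal re-runs this graph
      continue;
    }
    if (!instance.ok()) return instance.status();
  }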
@@ -18,6 +18,7 @@ limitations under the License.
#include <atomic>
#include <memory>
#include <optional>
#include <string>
#include <utility>
@@ -87,48 +88,32 @@ class StreamExecutorGraphInstances
// end up with thousands of unused (or rarely used) graphs in device memory.
class GraphInstances {
public:
struct Impl;
GraphInstances(std::string module_name, int64_t num_graphs);
~GraphInstances();
StreamExecutorGraphInstances* operator()(se::StreamExecutor* executor);
std::shared_ptr<StreamExecutorGraphInstances> operator()(
se::StreamExecutor* executor);
// Instantiates all Gpu graphs defined by the given executable using
// user-provided run options. This guarantees that once we start execution, all
// Gpu graphs are ready, and will only require a cheap update operation and
// will not require allocating new resources (we avoid non-deterministic OOM
// errors).
Status InstantiateAllGraphs(const ServiceExecutableRunOptions* run_options,
const runtime::Executable& executable,
const runtime::CustomCall::UserData& user_data,
void* ptr);
//
// If the timeout is not nullopt, it will evict all previously instantiated
// graphs whose last use was more than `eviction_timeout_seconds` seconds ago.
Status InstantiateAllGraphs(
const ServiceExecutableRunOptions* run_options,
const runtime::Executable& executable,
const runtime::CustomCall::UserData& user_data, void* ptr,
std::optional<uint64_t> eviction_timeout_seconds = std::nullopt);
// Returns true if all Gpu graphs were already instantiated.
bool InstantiatedAllGraphs(const ServiceExecutableRunOptions* run_options,
const runtime::Executable& executable);
private:
struct State {
// A flag signalling if `InstantiateAllGraphs` was already called and we
// have all Gpu graph instantiated ahead of time.
bool instantiated = false;
// Last time graph instances were used by a particular stream executor.
uint64_t last_use_micros = 0;
StreamExecutorGraphInstances instances;
};
struct Impl {
// XLA module name that owns graph instances. We use it only to produce logs
// that can be attributed back to XLA executables.
std::string module_name;
// Number of graphs in the parent module.
int64_t num_graphs;
mutable absl::Mutex mu;
absl::node_hash_map<se::StreamExecutor*, State> graphs ABSL_GUARDED_BY(mu);
};
std::shared_ptr<Impl> impl_;
};
......
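Taken together, the updated header gives a hypothetical caller flow roughly like the one below (the module name, graph count, and temp_buffer are illustrative; the real call site is GpuRuntimeExecutable::Execute shown earlier):

  // One cache per XLA module, created when the executable is constructed.
  GraphInstances graphs(/*module_name=*/"jit_module", /*num_graphs=*/16);

  // Ahead-of-time instantiation; graphs idle for longer than the timeout on
  // this device may be evicted first to make room.
  TF_RETURN_IF_ERROR(graphs.InstantiateAllGraphs(
      run_options, executable, user_data, temp_buffer,
      /*eviction_timeout_seconds=*/60));

  // Per-executor instances are now handed out as a shared_ptr.
  std::shared_ptr<StreamExecutorGraphInstances> instances =
      graphs(run_options->stream()->parent());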
@@ -26,12 +26,6 @@ limitations under the License.
namespace stream_executor {
namespace gpu {
template <typename... Args>
static tsl::Status InternalError(const absl::FormatSpec<Args...>& format,
const Args&... args) {
return tsl::errors::Internal(absl::StrFormat(format, args...));
}
//===----------------------------------------------------------------------===//
// RAII helpers for CUDA graph types.
//===----------------------------------------------------------------------===//
@@ -80,8 +74,8 @@ tsl::Status OwnedCudaGraphExec::Update(OwnedCudaGraph graph) {
auto err = cudaGraphExecUpdate(get(), graph.get(), &updated);
if (err != cudaSuccess || updated.result != cudaGraphExecUpdateSuccess)
return InternalError("failed to update cuda graph: %s",
cudaGetErrorString(err));
return absl::InternalError(absl::StrFormat(
"failed to update cuda graph: %s", cudaGetErrorString(err)));
#else
cudaGraphExecUpdateResult updated;
@@ -89,8 +83,8 @@ tsl::Status OwnedCudaGraphExec::Update(OwnedCudaGraph graph) {
auto err = cudaGraphExecUpdate(get(), graph.get(), &error_node, &updated);
if (err != cudaSuccess || updated != cudaGraphExecUpdateSuccess)
return InternalError("Failed to update cuda graph %s",
cudaGetErrorString(err));
return absl::InternalError(absl::StrFormat("Failed to update cuda graph %s",
cudaGetErrorString(err)));
#endif
return tsl::OkStatus();
@@ -103,8 +97,8 @@ tsl::Status OwnedCudaGraphExec::Launch(stream_executor::Stream* stream) {
if (auto err = cudaGraphLaunch(get(), AsGpuStreamValue(stream));
err != cudaSuccess)
return InternalError("failed to run cuda graph: %s",
cudaGetErrorString(err));
return absl::InternalError(absl::StrFormat("failed to run cuda graph: %s",
cudaGetErrorString(err)));
return tsl::OkStatus();
}
@@ -133,20 +127,20 @@ tsl::StatusOr<OwnedCudaGraph> CaptureCudaGraph(
// Capture graph constructed by the exported graph capture function.
if (auto err = cudaStreamBeginCapture(gpu_stream, mode); err != cudaSuccess)
return InternalError("stream begin capture failed: %s",
cudaGetErrorString(err));
return absl::InternalError(absl::StrFormat(
"stream begin capture failed: %s", cudaGetErrorString(err)));
// Call into graph capture function.
auto captured = capture();
// Always stop capturing the stream before checking `captured` result.
if (auto err = cudaStreamEndCapture(gpu_stream, &graph); err != cudaSuccess)
return InternalError("stream end capture failed: %s",
cudaGetErrorString(err));
return absl::InternalError(absl::StrFormat("stream end capture failed: %s",
cudaGetErrorString(err)));
if (!captured.ok())
return InternalError("failed to capture CUDA graph: %s",
captured.message());
return absl::InternalError(absl::StrFormat(
"failed to capture CUDA graph: %s", captured.message()));
VLOG(5) << "Captured CUDA graph " << graph;
@@ -195,8 +189,16 @@ tsl::StatusOr<OwnedCudaGraphExec> InstantiateCudaGraph(OwnedCudaGraph graph) {
if (auto err = cudaGraphInstantiate(&exec, &*graph, nullptr, nullptr, 0);
#endif
err != cudaSuccess) {
return InternalError("graph instantiation failed: %s",
cudaGetErrorString(err));
if (err == cudaErrorMemoryAllocation) {
// OOM is a recoverable error; we evict all instantiated cuda graphs to
// free up some space (see graph_launch.cc). Clear error status.
return absl::ResourceExhaustedError(
absl::StrFormat("graph instantiation failed: %s",
cudaGetErrorString(cudaGetLastError())));
} else {
return absl::InternalError(absl::StrFormat(
"graph instantiation failed: %s", cudaGetErrorString(err)));
}
}
size_t id = CudaGraphSupport::NotifyGraphExecCreated();
@@ -211,8 +213,8 @@ tsl::StatusOr<bool> IsStreamCapturing(stream_executor::Stream* stream) {
cudaError_t err = cudaStreamIsCapturing(
stream_executor::gpu::AsGpuStreamValue(stream), &capture_status);
if (err != cudaSuccess) {
return InternalError("Failed to get stream's capture status: %s",
cudaGetErrorString(err));
return absl::InternalError(absl::StrFormat(
"Failed to get stream's capture status: %s", cudaGetErrorString(err)));
}
return capture_status == cudaStreamCaptureStatusActive;
......
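One detail in the OOM branch above: the message is built from cudaGetErrorString(cudaGetLastError()) rather than err because cudaGetLastError() both returns and resets the recorded error, which is what the "Clear error status" comment refers to. A tiny illustrative sketch of that behavior (assumes <cuda_runtime.h>, <cassert>, and an active CUDA context):

  cudaError_t recorded = cudaGetLastError();   // returns the pending error...
  assert(cudaGetLastError() == cudaSuccess);   // ...and resets it to cudaSuccess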
@@ -454,6 +454,11 @@ message DebugOptions {
// Identify concurrent regions in cuda graphs and execute them concurrently.
bool xla_gpu_cuda_graph_enable_concurrent_region = 215;
// Timeout in seconds to evict instantiated Gpu graphs from device. When XLA
// instantiates new Gpu graphs, it evicts graphs that were not recently
// executed to free space on device.
int32 xla_gpu_cuda_graph_eviction_timeout_seconds = 230;
// Allocate temp buffers once during the first execution of an executable.
// Reuse the allocated buffers in subsequent executions. Executables cannot
// run concurrently if this is enabled.
@@ -572,7 +577,7 @@ message DebugOptions {
int32 xla_gpu_triton_fusion_level = 229;
// Next id: 230
// Next id: 231
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
......