Commit 279f26c7 authored by Eugene Zhulenev, committed by TensorFlower Gardener

[xla:gpu] Add time based and OOM cuda graph eviction policy

PiperOrigin-RevId: 549373738
Parent 0d5173b0
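The change below adds two eviction paths for instantiated GPU graphs: a time-based path that, before instantiating graphs for an executable, evicts graphs that have been idle longer than xla_gpu_cuda_graph_eviction_timeout_seconds (default 60), and an OOM path that evicts every graph on an executor and retries once when instantiation runs out of device memory. A minimal sketch of the time-based check, assuming a microsecond clock such as tsl::Env::Default()->NowMicros(); the helper name ShouldEvict is illustrative and not part of the patch:

#include <cstdint>
#include <optional>

// Sketch of the eviction predicate: with a timeout, only entries idle longer
// than the timeout are evicted; without a timeout (the OOM path), everything
// on the executor is evicted.
bool ShouldEvict(uint64_t last_use_micros, uint64_t now_micros,
                 std::optional<uint64_t> eviction_timeout_seconds) {
  if (!eviction_timeout_seconds.has_value()) return true;
  uint64_t idle_seconds = (now_micros - last_use_micros) / (1000 * 1000);
  return idle_seconds > *eviction_timeout_seconds;
}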
@@ -109,6 +109,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_enable_persistent_temp_buffers(false);
   opts.set_xla_gpu_cuda_graph_min_graph_size(5);
   opts.set_xla_gpu_cuda_graph_enable_concurrent_region(false);
+  opts.set_xla_gpu_cuda_graph_eviction_timeout_seconds(60);
   // Despite the name, fast min/max on GPUs does not seem to be any faster, and
   // adds very counter-intuitive "NaN-swallowing" behavior.
@@ -905,7 +906,7 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       "Instantiate a cuda graph after the time a captured function is executed "
       "reaches the threshold."));
   flag_list->push_back(tsl::Flag(
-      "xla_gpu_cuda_graph_capture_threshold",
+      "xla_gpu_cuda_graph_min_graph_size",
      int32_setter_for(&DebugOptions::set_xla_gpu_cuda_graph_min_graph_size),
      debug_options->xla_gpu_cuda_graph_min_graph_size(),
      "Capture a region as a function to be launched as cuda graph if the "
@@ -917,6 +918,14 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
      debug_options->xla_gpu_cuda_graph_enable_concurrent_region(),
      "Identify concurrent regions in cuda graphs and execute them "
      "concurrently."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_cuda_graph_eviction_timeout_seconds",
+      int32_setter_for(
+          &DebugOptions::set_xla_gpu_cuda_graph_eviction_timeout_seconds),
+      debug_options->xla_gpu_cuda_graph_eviction_timeout_seconds(),
+      "Timeout in seconds to evict instantiated Gpu graphs from device. When "
+      "XLA instantiates new Gpu graphs, it evicts graphs that were not "
+      "recently executed to free space on device."));
   flag_list->push_back(tsl::Flag(
      "xla_gpu_enable_persistent_temp_buffers",
......
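Like the other xla_gpu_* debug options registered above, the new timeout should be adjustable at runtime through the XLA_FLAGS environment variable, for example XLA_FLAGS=--xla_gpu_cuda_graph_eviction_timeout_seconds=120 to keep idle graphs alive longer than the 60 second default (the value here is purely illustrative).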
@@ -386,8 +386,11 @@ Status GpuRuntimeExecutable::Execute(
       conv_runners_(executor)->snapshot();
 
 #if GOOGLE_CUDA
+  std::shared_ptr<StreamExecutorGraphInstances> executor_graphs =
+      graph_instances_(executor);
+
   StreamExecutorGraphInstances::Snapshot graph_instances =
-      graph_instances_(executor)->snapshot();
+      executor_graphs->snapshot();
   CapturedFunctionExecutionCount::Snapshot execution_count =
       captured_function_counts_(executor)->snapshot();
 #endif  // GOOGLE_CUDA
@@ -451,7 +454,8 @@ Status GpuRuntimeExecutable::Execute(
   }
 
   if (auto instantiated = graph_instances_.InstantiateAllGraphs(
-          run_options, executable, user_data, device_ptr);
+          run_options, executable, user_data, device_ptr,
+          debug_options_.xla_gpu_cuda_graph_eviction_timeout_seconds());
       !instantiated.ok()) {
     return InternalError("Failed to instantiate CUDA graphs: %s",
                          instantiated.message());
......
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h"
 
+#include <algorithm>
 #include <array>
 #include <atomic>
 #include <cstddef>
@@ -73,15 +74,39 @@ static absl::StatusOr<OwnedCudaGraph> CaptureGraph(
 // CUDA graphs caching.
 //===----------------------------------------------------------------------===//
 
-static absl::Mutex* GetGraphInstancesMutex() {
-  static auto* mu = new absl::Mutex();
-  return mu;
-}
+struct GraphInstances::Impl {
+  struct State {
+    // A flag signalling if `InstantiateAllGraphs` was already called and we
+    // have all Gpu graph instantiated ahead of time.
+    bool instantiated = false;
+
+    // Last time graph instances were used by a particular stream executor.
+    uint64_t last_use_micros = 0;
+
+    std::shared_ptr<StreamExecutorGraphInstances> instances =
+        std::make_shared<StreamExecutorGraphInstances>();
+  };
+
+  // XLA module name that owns graph instances. We use it only to produce logs
+  // that can be attributed back to XLA executables.
+  std::string module_name;
+
+  // Number of graphs in the parent module.
+  int64_t num_graphs = 0;
+
+  mutable absl::Mutex mu;
+  absl::node_hash_map<se::StreamExecutor*, State> graphs ABSL_GUARDED_BY(mu);
+};
 
 // Keep track of instantiated graphs on each StreamExecutor, we use this
 // information in the graph eviction policy.
 using GraphInstancesState = absl::flat_hash_map<se::StreamExecutor*, int64_t>;
 
+static absl::Mutex* GetGraphInstancesStateMutex() {
+  static auto* mu = new absl::Mutex();
+  return mu;
+}
+
 static GraphInstancesState& GetGraphInstancesState() {
   static auto* state = new GraphInstancesState();
   return *state;
@@ -89,38 +114,121 @@ static GraphInstancesState& GetGraphInstancesState() {
 
 static int64_t NotifyGraphInstancesCreated(se::StreamExecutor* executor,
                                            int64_t num_graphs) {
-  absl::MutexLock lock(GetGraphInstancesMutex());
+  absl::MutexLock lock(GetGraphInstancesStateMutex());
   return GetGraphInstancesState()[executor] += num_graphs;
 }
 
 static int64_t NotifyGraphInstancesDestroyed(se::StreamExecutor* executor,
                                              int64_t num_graphs) {
-  absl::MutexLock lock(GetGraphInstancesMutex());
+  absl::MutexLock lock(GetGraphInstancesStateMutex());
   return GetGraphInstancesState()[executor] -= num_graphs;
 }
 
+// We keep track of all graph instances in the process, to implement graph
+// eviction on OOM. Graph instances owned by GpuExecutable, so we rely on
+// weak ptr to check if they are still alive.
+using GraphInstancesVec = std::vector<std::weak_ptr<GraphInstances::Impl>>;
+
+static absl::Mutex* GetGraphInstancesVecMutex() {
+  static auto* mu = new absl::Mutex();
+  return mu;
+}
+
+static GraphInstancesVec& GetGraphInstancesVec() {
+  static auto* vec = new GraphInstancesVec();
+  return *vec;
+}
+
+static void AddGraphInstances(std::weak_ptr<GraphInstances::Impl> impl) {
+  absl::MutexLock lock(GetGraphInstancesVecMutex());
+  GetGraphInstancesVec().push_back(std::move(impl));
+}
+
+// Evicts all graphs for a given executor in the current process.
+static void EvictAllGraphs(
+    se::StreamExecutor* executor,
+    std::optional<uint64_t> eviction_timeout_seconds = std::nullopt) {
+  LOG(WARNING) << "Evict "
+               << (eviction_timeout_seconds.has_value() ? "timed out" : "all")
+               << " gpu graphs from executor " << executor;
+
+  TraceMe trace_instantiation([&] {
+    return TraceMeEncode("cuda.graph.evict_all_graphs",
+                         {{"device_ordinal", executor->device_ordinal()}});
+  });
+
+  absl::MutexLock lock(GetGraphInstancesVecMutex());
+  auto& vec = GetGraphInstancesVec();
+
+  // Erase all expired graph instances.
+  vec.erase(std::remove_if(vec.begin(), vec.end(),
+                           [](auto& weak_ptr) { return weak_ptr.expired(); }),
+            vec.end());
+
+  auto timed_out = [&](GraphInstances::Impl::State& state) -> bool {
+    auto diff = tsl::Env::Default()->NowMicros() - state.last_use_micros;
+    return (diff / (1000 * 1000)) > *eviction_timeout_seconds;
+  };
+
+  for (auto& weak_ptr : vec) {
+    auto ptr = weak_ptr.lock();
+    if (!ptr) continue;
+
+    if (!ptr->mu.TryLock()) continue;
+
+    auto it = ptr->graphs.find(executor);
+    if (it == ptr->graphs.end()) {
+      ptr->mu.Unlock();
+      continue;
+    }
+
+    // If we have a timeout value, than check it first, otherwise always evict
+    // graphs for a given executor.
+    bool is_timed_out = timed_out(it->second);
+    if (eviction_timeout_seconds.has_value() && !is_timed_out) {
+      ptr->mu.Unlock();
+      continue;
+    }
+
+    if (ptr->num_graphs > 0) {
+      VLOG(3) << "Evict " << ptr->num_graphs << " graphs for: @"
+              << ptr->module_name << " at executor: " << executor
+              << " (timed_out = " << is_timed_out << ")."
+              << " Total remaining graphs at given executor: "
+              << NotifyGraphInstancesDestroyed(executor, ptr->num_graphs);
    }
+
+    ptr->graphs.erase(it);
+    ptr->mu.Unlock();
+  }
+}
+
 GraphInstances::GraphInstances(std::string module_name, int64_t num_graphs)
     : impl_(std::make_shared<Impl>()) {
   impl_->module_name = std::move(module_name);
   impl_->num_graphs = num_graphs;
-  VLOG(3) << "Construct graph instances cache for: @" << impl_->module_name
-          << " (num_graphs = " << impl_->num_graphs << ")";
+  if (impl_->num_graphs > 0) {
+    VLOG(3) << "Construct graph instances cache for: @" << impl_->module_name
+            << " (num_graphs = " << impl_->num_graphs << ")";
+  }
+  AddGraphInstances(impl_);
 }
 
 GraphInstances::~GraphInstances() {
-  VLOG(3) << "Destroy graph instances cache for: @" << impl_->module_name
-          << " (num_graphs = " << impl_->num_graphs << ")";
-
-  absl::MutexLock lock(&impl_->mu);
-  for (auto& [executor, state] : impl_->graphs) {
-    VLOG(3) << "Destroy " << impl_->num_graphs << " graphs for: @"
-            << impl_->module_name << " at executor: " << executor
-            << ". Total remaining graphs at given executor: "
-            << NotifyGraphInstancesDestroyed(executor, impl_->num_graphs);
+  if (impl_->num_graphs > 0) {
+    VLOG(3) << "Destroy graph instances cache for: @" << impl_->module_name
+            << " (num_graphs = " << impl_->num_graphs << ")";
+
+    absl::MutexLock lock(&impl_->mu);
+    for (auto& [executor, state] : impl_->graphs) {
+      VLOG(3) << "Destroy " << impl_->num_graphs << " graphs for: @"
+              << impl_->module_name << " at executor: " << executor
+              << ". Total remaining graphs at given executor: "
+              << NotifyGraphInstancesDestroyed(executor, impl_->num_graphs);
+    }
   }
 }
 
-StreamExecutorGraphInstances* GraphInstances::operator()(
+std::shared_ptr<StreamExecutorGraphInstances> GraphInstances::operator()(
     se::StreamExecutor* executor) {
   absl::MutexLock lock(&impl_->mu);
@@ -132,9 +240,9 @@ StreamExecutorGraphInstances* GraphInstances::operator()(
             << NotifyGraphInstancesCreated(executor, impl_->num_graphs);
   }
 
-  State& state = it.first->second;
+  Impl::State& state = it.first->second;
   state.last_use_micros = tsl::Env::Default()->NowMicros();
-  return &state.instances;
+  return state.instances;
 }
 
 bool GraphInstances::InstantiatedAllGraphs(
@@ -149,22 +257,29 @@ bool GraphInstances::InstantiatedAllGraphs(
 Status GraphInstances::InstantiateAllGraphs(
     const ServiceExecutableRunOptions* run_options,
     const Executable& executable, const CustomCall::UserData& user_data,
-    void* ptr) {
+    void* ptr, std::optional<uint64_t> eviction_timeout_seconds) {
   // We have only "main" function in the executable.
   if (executable.num_functions() == 1) return OkStatus();
 
   absl::MutexLock lock(&impl_->mu);
   se::StreamExecutor* executor = run_options->stream()->parent();
 
-  State& state = impl_->graphs[executor];
+  Impl::State& state = impl_->graphs[executor];
 
   // All Gpu graphs are already instantiated for a given executor.
   if (state.instantiated) return OkStatus();
 
   TraceMe trace("cuda.graph.instantiate_all");
 
-  // Initialize graph instances snapshot for a given executor.
-  StreamExecutorGraphInstances::Snapshot instances = state.instances.snapshot();
+  // Evict all timeout graphs before trying to instantiate new ones.
+  EvictAllGraphs(executor, eviction_timeout_seconds);
+
+  // We'll retry graph instantiation on OOM errors after evicting all graphs
+  // instantiated on `executor`.
+  int32_t num_retries = 0;
+
+  StreamExecutorGraphInstances::Snapshot instances =
+      state.instances->snapshot();
 
   // Instantiate all Gpu graphs by calling graph capture functions with fake
   // arguments. Once we'll execute them first time for real, they'll be updated
@@ -217,9 +332,19 @@ Status GraphInstances::InstantiateAllGraphs(
       return GraphInstance(0, std::move(e));
     };
 
-    TF_ASSIGN_OR_RETURN(GraphInstance * instance,
-                        instances.GetOrCreate(ordinal, instantiate));
-    (void)instance;
+    absl::StatusOr<GraphInstance*> instance =
+        instances.GetOrCreate(ordinal, instantiate);
+
+    // Retry on OOM error after evicting all graphs from executor.
+    if (instance.status().code() == absl::StatusCode::kResourceExhausted &&
+        num_retries++ == 0) {
+      EvictAllGraphs(executor);
+      --ordinal;  // we'll try to instantiate the same graph one more time
+      continue;
+    }
+
+    // Otherwise return an error to the caller.
+    if (!instance.ok()) return instance.status();
 
 #endif  // GOOGLE_CUDA
   }
......
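Taken together, the runtime changes above make InstantiateAllGraphs do three things in order: evict graphs on the target executor that have been idle past the timeout, instantiate the executable's graphs, and, if an instantiation fails with kResourceExhausted, evict all graphs on that executor once (EvictAllGraphs without a timeout) and retry the same ordinal before surfacing the error.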
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <atomic>
 #include <memory>
+#include <optional>
 #include <string>
 #include <utility>
@@ -87,48 +88,32 @@ class StreamExecutorGraphInstances
 // end up with thousands of unused (or rarely used) graphs in device memory.
 class GraphInstances {
  public:
+  struct Impl;
+
   GraphInstances(std::string module_name, int64_t num_graphs);
   ~GraphInstances();
 
-  StreamExecutorGraphInstances* operator()(se::StreamExecutor* executor);
+  std::shared_ptr<StreamExecutorGraphInstances> operator()(
+      se::StreamExecutor* executor);
 
   // Instantiates all Gpu graphs defined by the given executable using user
   // provided run options. This guarantees that once we start execution, all Gpu
   // graphs are ready, and will only require cheap update operation and will not
   // require allocating new resources (we avoid non deterministic OOM errors).
-  Status InstantiateAllGraphs(const ServiceExecutableRunOptions* run_options,
-                              const runtime::Executable& executable,
-                              const runtime::CustomCall::UserData& user_data,
-                              void* ptr);
+  //
+  // If timeout is not nullopt it will evict all previously instantiated graphs
+  // that were used more than `eviction_timeout_seconds` seconds ago.
+  Status InstantiateAllGraphs(
+      const ServiceExecutableRunOptions* run_options,
+      const runtime::Executable& executable,
+      const runtime::CustomCall::UserData& user_data, void* ptr,
+      std::optional<uint64_t> eviction_timeout_seconds = std::nullopt);
 
   // Returns true if all Gpu graphs were already instantiated.
   bool InstantiatedAllGraphs(const ServiceExecutableRunOptions* run_options,
                              const runtime::Executable& executable);
 
  private:
-  struct State {
-    // A flag signalling if `InstantiateAllGraphs` was already called and we
-    // have all Gpu graph instantiated ahead of time.
-    bool instantiated = false;
-
-    // Last time graph instances were used by a particular stream executor.
-    uint64_t last_use_micros = 0;
-
-    StreamExecutorGraphInstances instances;
-  };
-
-  struct Impl {
-    // XLA module name that owns graph instances. We use it only to produce logs
-    // that can be attributed back to XLA executables.
-    std::string module_name;
-
-    // Number of graphs in the parent module.
-    int64_t num_graphs;
-
-    mutable absl::Mutex mu;
-    absl::node_hash_map<se::StreamExecutor*, State> graphs ABSL_GUARDED_BY(mu);
-  };
-
   std::shared_ptr<Impl> impl_;
 };
......
@@ -26,12 +26,6 @@ limitations under the License.
 namespace stream_executor {
 namespace gpu {
 
-template <typename... Args>
-static tsl::Status InternalError(const absl::FormatSpec<Args...>& format,
-                                 const Args&... args) {
-  return tsl::errors::Internal(absl::StrFormat(format, args...));
-}
-
 //===----------------------------------------------------------------------===//
 // RAII helpers for CUDA graph types.
 //===----------------------------------------------------------------------===//
@@ -80,8 +74,8 @@ tsl::Status OwnedCudaGraphExec::Update(OwnedCudaGraph graph) {
   auto err = cudaGraphExecUpdate(get(), graph.get(), &updated);
   if (err != cudaSuccess || updated.result != cudaGraphExecUpdateSuccess)
-    return InternalError("failed to update cuda graph: %s",
-                         cudaGetErrorString(err));
+    return absl::InternalError(absl::StrFormat(
+        "failed to update cuda graph: %s", cudaGetErrorString(err)));
 
 #else
   cudaGraphExecUpdateResult updated;
@@ -89,8 +83,8 @@ tsl::Status OwnedCudaGraphExec::Update(OwnedCudaGraph graph) {
   auto err = cudaGraphExecUpdate(get(), graph.get(), &error_node, &updated);
   if (err != cudaSuccess || updated != cudaGraphExecUpdateSuccess)
-    return InternalError("Failed to update cuda graph %s",
-                         cudaGetErrorString(err));
+    return absl::InternalError(absl::StrFormat("Failed to update cuda graph %s",
+                                               cudaGetErrorString(err)));
 #endif
 
   return tsl::OkStatus();
@@ -103,8 +97,8 @@ tsl::Status OwnedCudaGraphExec::Launch(stream_executor::Stream* stream) {
   if (auto err = cudaGraphLaunch(get(), AsGpuStreamValue(stream));
       err != cudaSuccess)
-    return InternalError("failed to run cuda graph: %s",
-                         cudaGetErrorString(err));
+    return absl::InternalError(absl::StrFormat("failed to run cuda graph: %s",
+                                               cudaGetErrorString(err)));
 
   return tsl::OkStatus();
 }
@@ -133,20 +127,20 @@ tsl::StatusOr<OwnedCudaGraph> CaptureCudaGraph(
   // Capture graph constructed by the exported graph capture function.
   if (auto err = cudaStreamBeginCapture(gpu_stream, mode); err != cudaSuccess)
-    return InternalError("stream begin capture failed: %s",
-                         cudaGetErrorString(err));
+    return absl::InternalError(absl::StrFormat(
+        "stream begin capture failed: %s", cudaGetErrorString(err)));
 
   // Call into graph capture function.
   auto captured = capture();
 
   // Always stop capturing the stream before checking `captured` result.
   if (auto err = cudaStreamEndCapture(gpu_stream, &graph); err != cudaSuccess)
-    return InternalError("stream end capture failed: %s",
-                         cudaGetErrorString(err));
+    return absl::InternalError(absl::StrFormat("stream end capture failed: %s",
+                                               cudaGetErrorString(err)));
 
   if (!captured.ok())
-    return InternalError("failed to capture CUDA graph: %s",
-                         captured.message());
+    return absl::InternalError(absl::StrFormat(
+        "failed to capture CUDA graph: %s", captured.message()));
 
   VLOG(5) << "Captured CUDA graph " << graph;
@@ -195,8 +189,16 @@ tsl::StatusOr<OwnedCudaGraphExec> InstantiateCudaGraph(OwnedCudaGraph graph) {
   if (auto err = cudaGraphInstantiate(&exec, &*graph, nullptr, nullptr, 0);
 #endif
       err != cudaSuccess) {
-    return InternalError("graph instantiation failed: %s",
-                         cudaGetErrorString(err));
+    if (err == cudaErrorMemoryAllocation) {
+      // OOM is a recoverable error, we evict all instantiated cuda graphs to
+      // free up some space (see graph launch.cc). Clear error status.
+      return absl::ResourceExhaustedError(
+          absl::StrFormat("graph instantiation failed: %s",
+                          cudaGetErrorString(cudaGetLastError())));
+    } else {
+      return absl::InternalError(absl::StrFormat(
+          "graph instantiation failed: %s", cudaGetErrorString(err)));
+    }
   }
 
   size_t id = CudaGraphSupport::NotifyGraphExecCreated();
@@ -211,8 +213,8 @@ tsl::StatusOr<bool> IsStreamCapturing(stream_executor::Stream* stream) {
   cudaError_t err = cudaStreamIsCapturing(
       stream_executor::gpu::AsGpuStreamValue(stream), &capture_status);
   if (err != cudaSuccess) {
-    return InternalError("Failed to get stream's capture status: %s",
-                         cudaGetErrorString(err));
+    return absl::InternalError(absl::StrFormat(
+        "Failed to get stream's capture status: %s", cudaGetErrorString(err)));
   }
 
   return capture_status == cudaStreamCaptureStatusActive;
......
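The stream_executor change above is what enables that retry: cudaErrorMemoryAllocation from cudaGraphInstantiate is now reported as absl::ResourceExhaustedError instead of a generic internal error, so the caller in graph_launch.cc can recognize a recoverable OOM, and cudaGetLastError() is invoked to clear CUDA's sticky error state before eviction and retry; all other failures still return absl::InternalError.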
@@ -454,6 +454,11 @@ message DebugOptions {
   // Identify concurrent regions in cuda graphs and execute them concurrently.
   bool xla_gpu_cuda_graph_enable_concurrent_region = 215;
 
+  // Timeout in seconds to evict instantiated Gpu graphs from device. When XLA
+  // instantiates new Gpu graphs, it evicts graphs that were not recently
+  // executed to free space on device.
+  int32 xla_gpu_cuda_graph_eviction_timeout_seconds = 230;
+
   // Allocate temp buffers once during the first execution of an executable.
   // Reuse the allocated buffers in subsequent executions. Executables cannot
   // run concurrently if this is enabled.
@@ -572,7 +577,7 @@ message DebugOptions {
   int32 xla_gpu_triton_fusion_level = 229;
 
-  // Next id: 230
+  // Next id: 231
 
   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.
......