Commit 279f26c7

[xla:gpu] Add time based and OOM cuda graph eviction policy

PiperOrigin-RevId: 549373738

Authored July 19, 2023 by Eugene Zhulenev
Committed by TensorFlower Gardener on July 19, 2023
Parent: 0d5173b0

Showing 6 changed files with 211 additions and 81 deletions (+211 / -81):
tensorflow/compiler/xla/debug_options_flags.cc                +10   -1
tensorflow/compiler/xla/service/gpu/runtime/executable.cc      +6   -2
tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc  +152  -27
tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h    +13  -28
tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc    +24  -22
tensorflow/compiler/xla/xla.proto                              +6   -1
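In brief, this change adds two eviction policies for instantiated CUDA graphs. A time-based policy evicts graphs that were not used within `xla_gpu_cuda_graph_eviction_timeout_seconds` (default 60, wired through `DebugOptions`) before new graphs are instantiated, and an OOM policy evicts all graphs on an executor and retries instantiation once when `cudaGraphInstantiate` fails with `cudaErrorMemoryAllocation`. A minimal configuration sketch, assuming the C++ proto API generated from `xla.proto` (the setter exists per this diff; the wrapper function and include path are illustrative):

#include "tensorflow/compiler/xla/xla.pb.h"  // assumed generated-proto path

// Returns DebugOptions with a custom graph-eviction timeout. Graphs idle for
// longer than this many seconds become eviction candidates whenever new
// graphs are instantiated on the same device.
xla::DebugOptions MakeDebugOptionsWithEviction() {
  xla::DebugOptions opts;
  opts.set_xla_gpu_cuda_graph_eviction_timeout_seconds(120);  // default: 60
  return opts;
}

The same knob is exposed as a command-line flag of the same name, registered in `debug_options_flags.cc` below.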
tensorflow/compiler/xla/debug_options_flags.cc

@@ -109,6 +109,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_enable_persistent_temp_buffers(false);
   opts.set_xla_gpu_cuda_graph_min_graph_size(5);
   opts.set_xla_gpu_cuda_graph_enable_concurrent_region(false);
+  opts.set_xla_gpu_cuda_graph_eviction_timeout_seconds(60);

   // Despite the name, fast min/max on GPUs does not seem to be any faster, and
   // adds very counter-intuitive "NaN-swallowing" behavior.

@@ -905,7 +906,7 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       "Instantiate a cuda graph after the time a captured function is executed "
       "reaches the threshold."));
   flag_list->push_back(tsl::Flag(
-      "xla_gpu_cuda_graph_capture_threshold",
+      "xla_gpu_cuda_graph_min_graph_size",
       int32_setter_for(&DebugOptions::set_xla_gpu_cuda_graph_min_graph_size),
       debug_options->xla_gpu_cuda_graph_min_graph_size(),
       "Capture a region as a function to be launched as cuda graph if the "

@@ -917,6 +918,14 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       debug_options->xla_gpu_cuda_graph_enable_concurrent_region(),
       "Identify concurrent regions in cuda graphs and execute them "
       "concurrently."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_cuda_graph_eviction_timeout_seconds",
+      int32_setter_for(
+          &DebugOptions::set_xla_gpu_cuda_graph_eviction_timeout_seconds),
+      debug_options->xla_gpu_cuda_graph_eviction_timeout_seconds(),
+      "Timeout in seconds to evict instantiated Gpu graphs from device. When "
+      "XLA instantiates new Gpu graphs, it evicts graphs that were not "
+      "recently executed to free space on device."));
   flag_list->push_back(tsl::Flag(
       "xla_gpu_enable_persistent_temp_buffers",
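As a usage note, `DebugOptions`-backed flags like the one registered above are conventionally supplied through the `XLA_FLAGS` environment variable. A hedged sketch (the environment-variable mechanism is standard for XLA debug options but not part of this diff; `setenv` is POSIX and must run before XLA parses its flags):

#include <cstdlib>  // setenv is POSIX; availability is platform-dependent

// Must run before XLA initializes, typically at process start.
void ConfigureEvictionTimeout() {
  // Evict graphs idle for more than two minutes (built-in default: 60 s).
  setenv("XLA_FLAGS", "--xla_gpu_cuda_graph_eviction_timeout_seconds=120",
         /*overwrite=*/1);
}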
tensorflow/compiler/xla/service/gpu/runtime/executable.cc

@@ -386,8 +386,11 @@ Status GpuRuntimeExecutable::Execute(
       conv_runners_(executor)->snapshot();

 #if GOOGLE_CUDA
+  std::shared_ptr<StreamExecutorGraphInstances> executor_graphs =
+      graph_instances_(executor);
+
   StreamExecutorGraphInstances::Snapshot graph_instances =
-      graph_instances_(executor)->snapshot();
+      executor_graphs->snapshot();
   CapturedFunctionExecutionCount::Snapshot execution_count =
       captured_function_counts_(executor)->snapshot();
 #endif  // GOOGLE_CUDA

@@ -451,7 +454,8 @@ Status GpuRuntimeExecutable::Execute(
   }

   if (auto instantiated = graph_instances_.InstantiateAllGraphs(
-          run_options, executable, user_data, device_ptr);
+          run_options, executable, user_data, device_ptr,
+          debug_options_.xla_gpu_cuda_graph_eviction_timeout_seconds());
       !instantiated.ok()) {
     return InternalError("Failed to instantiate CUDA graphs: %s",
                          instantiated.message());
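A note on the `executor_graphs` local above: `GraphInstances::operator()` now returns `std::shared_ptr<StreamExecutorGraphInstances>` (see `graph_launch.h` below), so an in-flight `Execute` keeps the instances alive even if eviction concurrently erases the per-executor map entry. A minimal sketch of that lifetime guarantee, with illustrative names:

#include <cassert>
#include <map>
#include <memory>

struct Instances { int graphs = 3; };

std::map<int, std::shared_ptr<Instances>> per_executor;

int main() {
  per_executor[0] = std::make_shared<Instances>();

  // Execute(): take a reference before snapshotting/using the instances.
  std::shared_ptr<Instances> held = per_executor[0];

  // A concurrent eviction drops the map entry...
  per_executor.erase(0);

  // ...but the in-flight execution still owns the object.
  assert(held->graphs == 3);
}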
tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc

@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h"

+#include <algorithm>
 #include <array>
 #include <atomic>
 #include <cstddef>

@@ -73,15 +74,39 @@ static absl::StatusOr<OwnedCudaGraph> CaptureGraph(
 // CUDA graphs caching.
 //===----------------------------------------------------------------------===//

-static absl::Mutex* GetGraphInstancesMutex() {
-  static auto* mu = new absl::Mutex();
-  return mu;
-}
+struct GraphInstances::Impl {
+  struct State {
+    // A flag signalling if `InstantiateAllGraphs` was already called and we
+    // have all Gpu graph instantiated ahead of time.
+    bool instantiated = false;
+
+    // Last time graph instances were used by a particular stream executor.
+    uint64_t last_use_micros = 0;
+
+    std::shared_ptr<StreamExecutorGraphInstances> instances =
+        std::make_shared<StreamExecutorGraphInstances>();
+  };
+
+  // XLA module name that owns graph instances. We use it only to produce logs
+  // that can be attributed back to XLA executables.
+  std::string module_name;
+
+  // Number of graphs in the parent module.
+  int64_t num_graphs = 0;
+
+  mutable absl::Mutex mu;
+  absl::node_hash_map<se::StreamExecutor*, State> graphs ABSL_GUARDED_BY(mu);
+};

 // Keep track of instantiated graphs on each StreamExecutor, we use this
 // information in the graph eviction policy.
 using GraphInstancesState = absl::flat_hash_map<se::StreamExecutor*, int64_t>;

+static absl::Mutex* GetGraphInstancesStateMutex() {
+  static auto* mu = new absl::Mutex();
+  return mu;
+}
+
 static GraphInstancesState& GetGraphInstancesState() {
   static auto* state = new GraphInstancesState();
   return *state;

@@ -89,38 +114,121 @@ static GraphInstancesState& GetGraphInstancesState() {
 static int64_t NotifyGraphInstancesCreated(se::StreamExecutor* executor,
                                            int64_t num_graphs) {
-  absl::MutexLock lock(GetGraphInstancesMutex());
+  absl::MutexLock lock(GetGraphInstancesStateMutex());
   return GetGraphInstancesState()[executor] += num_graphs;
 }

 static int64_t NotifyGraphInstancesDestroyed(se::StreamExecutor* executor,
                                              int64_t num_graphs) {
-  absl::MutexLock lock(GetGraphInstancesMutex());
+  absl::MutexLock lock(GetGraphInstancesStateMutex());
   return GetGraphInstancesState()[executor] -= num_graphs;
 }

+// We keep track of all graph instances in the process, to implement graph
+// eviction on OOM. Graph instances owned by GpuExecutable, so we rely on
+// weak ptr to check if they are still alive.
+using GraphInstancesVec = std::vector<std::weak_ptr<GraphInstances::Impl>>;
+
+static absl::Mutex* GetGraphInstancesVecMutex() {
+  static auto* mu = new absl::Mutex();
+  return mu;
+}
+
+static GraphInstancesVec& GetGraphInstancesVec() {
+  static auto* vec = new GraphInstancesVec();
+  return *vec;
+}
+
+static void AddGraphInstances(std::weak_ptr<GraphInstances::Impl> impl) {
+  absl::MutexLock lock(GetGraphInstancesVecMutex());
+  GetGraphInstancesVec().push_back(std::move(impl));
+}
+
+// Evicts all graphs for a given executor in the current process.
+static void EvictAllGraphs(
+    se::StreamExecutor* executor,
+    std::optional<uint64_t> eviction_timeout_seconds = std::nullopt) {
+  LOG(WARNING) << "Evict "
+               << (eviction_timeout_seconds.has_value() ? "timed out" : "all")
+               << " gpu graphs from executor " << executor;
+
+  TraceMe trace_instantiation([&] {
+    return TraceMeEncode("cuda.graph.evict_all_graphs",
+                         {{"device_ordinal", executor->device_ordinal()}});
+  });
+
+  absl::MutexLock lock(GetGraphInstancesVecMutex());
+  auto& vec = GetGraphInstancesVec();
+
+  // Erase all expired graph instances.
+  vec.erase(std::remove_if(vec.begin(), vec.end(),
+                           [](auto& weak_ptr) { return weak_ptr.expired(); }),
+            vec.end());
+
+  auto timed_out = [&](GraphInstances::Impl::State& state) -> bool {
+    auto diff = tsl::Env::Default()->NowMicros() - state.last_use_micros;
+    return (diff / (1000 * 1000)) > *eviction_timeout_seconds;
+  };
+
+  for (auto& weak_ptr : vec) {
+    auto ptr = weak_ptr.lock();
+    if (!ptr) continue;
+
+    if (!ptr->mu.TryLock()) continue;
+
+    auto it = ptr->graphs.find(executor);
+    if (it == ptr->graphs.end()) {
+      ptr->mu.Unlock();
+      continue;
+    }
+
+    // If we have a timeout value, than check it first, otherwise always evict
+    // graphs for a given executor.
+    bool is_timed_out = timed_out(it->second);
+    if (eviction_timeout_seconds.has_value() && !is_timed_out) {
+      ptr->mu.Unlock();
+      continue;
+    }
+
+    if (ptr->num_graphs > 0) {
+      VLOG(3) << "Evict " << ptr->num_graphs << " graphs for: @"
+              << ptr->module_name << " at executor: " << executor
+              << " (timed_out = " << is_timed_out << ")."
+              << " Total remaining graphs at given executor: "
+              << NotifyGraphInstancesDestroyed(executor, ptr->num_graphs);
+    }
+    ptr->graphs.erase(it);
+    ptr->mu.Unlock();
+  }
+}
+
 GraphInstances::GraphInstances(std::string module_name, int64_t num_graphs)
     : impl_(std::make_shared<Impl>()) {
   impl_->module_name = std::move(module_name);
   impl_->num_graphs = num_graphs;
-  VLOG(3) << "Construct graph instances cache for: @" << impl_->module_name
-          << " (num_graphs = " << impl_->num_graphs << ")";
+  if (impl_->num_graphs > 0) {
+    VLOG(3) << "Construct graph instances cache for: @" << impl_->module_name
+            << " (num_graphs = " << impl_->num_graphs << ")";
+  }
+  AddGraphInstances(impl_);
 }

 GraphInstances::~GraphInstances() {
-  VLOG(3) << "Destroy graph instances cache for: @" << impl_->module_name
-          << " (num_graphs = " << impl_->num_graphs << ")";
-
-  absl::MutexLock lock(&impl_->mu);
-  for (auto& [executor, state] : impl_->graphs) {
-    VLOG(3) << "Destroy " << impl_->num_graphs << " graphs for: @"
-            << impl_->module_name << " at executor: " << executor
-            << ". Total remaining graphs at given executor: "
-            << NotifyGraphInstancesDestroyed(executor, impl_->num_graphs);
+  if (impl_->num_graphs > 0) {
+    VLOG(3) << "Destroy graph instances cache for: @" << impl_->module_name
+            << " (num_graphs = " << impl_->num_graphs << ")";
+
+    absl::MutexLock lock(&impl_->mu);
+    for (auto& [executor, state] : impl_->graphs) {
+      VLOG(3) << "Destroy " << impl_->num_graphs << " graphs for: @"
+              << impl_->module_name << " at executor: " << executor
+              << ". Total remaining graphs at given executor: "
+              << NotifyGraphInstancesDestroyed(executor, impl_->num_graphs);
+    }
   }
 }

-StreamExecutorGraphInstances* GraphInstances::operator()(
+std::shared_ptr<StreamExecutorGraphInstances> GraphInstances::operator()(
     se::StreamExecutor* executor) {
   absl::MutexLock lock(&impl_->mu);

@@ -132,9 +240,9 @@ StreamExecutorGraphInstances* GraphInstances::operator()(
             << NotifyGraphInstancesCreated(executor, impl_->num_graphs);
   }

-  State& state = it.first->second;
+  Impl::State& state = it.first->second;
   state.last_use_micros = tsl::Env::Default()->NowMicros();
-  return &state.instances;
+  return state.instances;
 }

 bool GraphInstances::InstantiatedAllGraphs(

@@ -149,22 +257,29 @@ bool GraphInstances::InstantiatedAllGraphs(
 Status GraphInstances::InstantiateAllGraphs(
     const ServiceExecutableRunOptions* run_options,
     const Executable& executable,
-    const CustomCall::UserData& user_data, void* ptr) {
+    const CustomCall::UserData& user_data, void* ptr,
+    std::optional<uint64_t> eviction_timeout_seconds) {
   // We have only "main" function in the executable.
   if (executable.num_functions() == 1) return OkStatus();

   absl::MutexLock lock(&impl_->mu);
   se::StreamExecutor* executor = run_options->stream()->parent();

-  State& state = impl_->graphs[executor];
+  Impl::State& state = impl_->graphs[executor];

   // All Gpu graphs are already instantiated for a given executor.
   if (state.instantiated) return OkStatus();

   TraceMe trace("cuda.graph.instantiate_all");

-  // Initialize graph instances snapshot for a given executor.
-  StreamExecutorGraphInstances::Snapshot instances =
-      state.instances.snapshot();
+  // Evict all timeout graphs before trying to instantiate new ones.
+  EvictAllGraphs(executor, eviction_timeout_seconds);
+
+  // We'll retry graph instantiation on OOM errors after evicting all graphs
+  // instantiated on `executor`.
+  int32_t num_retries = 0;
+
+  StreamExecutorGraphInstances::Snapshot instances =
+      state.instances->snapshot();

   // Instantiate all Gpu graphs by calling graph capture functions with fake
   // arguments. Once we'll execute them first time for real, they'll be updated

@@ -217,9 +332,19 @@ Status GraphInstances::InstantiateAllGraphs(
       return GraphInstance(0, std::move(e));
     };

-    TF_ASSIGN_OR_RETURN(GraphInstance * instance,
-                        instances.GetOrCreate(ordinal, instantiate));
-    (void)instance;
+    absl::StatusOr<GraphInstance*> instance =
+        instances.GetOrCreate(ordinal, instantiate);
+
+    // Retry on OOM error after evicting all graphs from executor.
+    if (instance.status().code() == absl::StatusCode::kResourceExhausted &&
+        num_retries++ == 0) {
+      EvictAllGraphs(executor);
+      --ordinal;  // we'll try to instantiate the same graph one more time
+      continue;
+    }
+
+    // Otherwise return an error to the caller.
+    if (!instance.ok()) return instance.status();
 #endif  // GOOGLE_CUDA
   }
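The eviction machinery above follows a common pattern: a process-wide registry of `weak_ptr`s lets `EvictAllGraphs` reach every live per-module cache without extending its lifetime, and `TryLock` skips busy caches instead of blocking instantiation. Note also that the timeout test converts the `NowMicros()` delta to whole seconds by integer division before comparing against the flag value. A self-contained sketch of the registry pattern (illustrative names, plain `std::mutex` standing in for `absl::Mutex`):

#include <algorithm>
#include <memory>
#include <mutex>
#include <vector>

struct Cache {
  std::mutex mu;
  bool evicted = false;  // stand-in for erasing per-executor graph state
};

static std::mutex registry_mu;
static std::vector<std::weak_ptr<Cache>> registry;

static void Register(std::weak_ptr<Cache> cache) {
  std::lock_guard<std::mutex> lock(registry_mu);
  registry.push_back(std::move(cache));
}

static void EvictAll() {
  std::lock_guard<std::mutex> lock(registry_mu);
  // Drop registry entries whose owner has already been destroyed.
  registry.erase(std::remove_if(registry.begin(), registry.end(),
                                [](auto& weak) { return weak.expired(); }),
                 registry.end());
  for (auto& weak : registry) {
    std::shared_ptr<Cache> cache = weak.lock();
    if (!cache) continue;                 // owner died between the two checks
    if (!cache->mu.try_lock()) continue;  // busy: skip rather than block
    cache->evicted = true;
    cache->mu.unlock();
  }
}

int main() {
  auto cache = std::make_shared<Cache>();
  Register(cache);
  EvictAll();
  return cache->evicted ? 0 : 1;
}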
tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h

@@ -18,6 +18,7 @@ limitations under the License.
 #include <atomic>
 #include <memory>
+#include <optional>
 #include <string>
 #include <utility>

@@ -87,48 +88,32 @@ class StreamExecutorGraphInstances
 // end up with thousands of unused (or rarely used) graphs in device memory.
 class GraphInstances {
  public:
+  struct Impl;
+
   GraphInstances(std::string module_name, int64_t num_graphs);
   ~GraphInstances();

-  StreamExecutorGraphInstances* operator()(se::StreamExecutor* executor);
+  std::shared_ptr<StreamExecutorGraphInstances> operator()(
+      se::StreamExecutor* executor);

   // Instantiates all Gpu graphs defined by the given executable using user
   // provided run options. This guarantees that once we start execution, all Gpu
   // graphs are ready, and will only require cheap update operation and will not
   // require allocating new resources (we avoid non deterministic OOM errors).
-  Status InstantiateAllGraphs(const ServiceExecutableRunOptions* run_options,
-                              const runtime::Executable& executable,
-                              const runtime::CustomCall::UserData& user_data,
-                              void* ptr);
+  //
+  // If timeout is not nullopt it will evict all previously instantiated graphs
+  // that were used more than `eviction_timeout_seconds` seconds ago.
+  Status InstantiateAllGraphs(
+      const ServiceExecutableRunOptions* run_options,
+      const runtime::Executable& executable,
+      const runtime::CustomCall::UserData& user_data, void* ptr,
+      std::optional<uint64_t> eviction_timeout_seconds = std::nullopt);

   // Returns true if all Gpu graphs were already instantiated.
   bool InstantiatedAllGraphs(const ServiceExecutableRunOptions* run_options,
                              const runtime::Executable& executable);

  private:
-  struct State {
-    // A flag signalling if `InstantiateAllGraphs` was already called and we
-    // have all Gpu graph instantiated ahead of time.
-    bool instantiated = false;
-
-    // Last time graph instances were used by a particular stream executor.
-    uint64_t last_use_micros = 0;
-
-    StreamExecutorGraphInstances instances;
-  };
-
-  struct Impl {
-    // XLA module name that owns graph instances. We use it only to produce
-    // logs that can be attributed back to XLA executables.
-    std::string module_name;
-
-    // Number of graphs in the parent module.
-    int64_t num_graphs;
-
-    mutable absl::Mutex mu;
-    absl::node_hash_map<se::StreamExecutor*, State> graphs ABSL_GUARDED_BY(mu);
-  };
-
   std::shared_ptr<Impl> impl_;
 };
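A hypothetical call-site sketch for the widened `InstantiateAllGraphs` signature (mirroring the `executable.cc` hunk above; `graph_instances` and `device_ptr` are illustrative names). Omitting the last argument leaves it at `std::nullopt`, which skips time-based eviction and preserves the previous behavior:

// With a timeout: evict graphs idle longer than 60 s, then instantiate.
Status s1 = graph_instances.InstantiateAllGraphs(
    run_options, executable, user_data, device_ptr,
    /*eviction_timeout_seconds=*/60);

// Without a timeout: no time-based eviction before instantiation.
Status s2 = graph_instances.InstantiateAllGraphs(
    run_options, executable, user_data, device_ptr);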
tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc

@@ -26,12 +26,6 @@ limitations under the License.
 namespace stream_executor {
 namespace gpu {

-template <typename... Args>
-static tsl::Status InternalError(const absl::FormatSpec<Args...>& format,
-                                 const Args&... args) {
-  return tsl::errors::Internal(absl::StrFormat(format, args...));
-}
-
 //===----------------------------------------------------------------------===//
 // RAII helpers for CUDA graph types.
 //===----------------------------------------------------------------------===//

@@ -80,8 +74,8 @@ tsl::Status OwnedCudaGraphExec::Update(OwnedCudaGraph graph) {
   auto err = cudaGraphExecUpdate(get(), graph.get(), &updated);
   if (err != cudaSuccess || updated.result != cudaGraphExecUpdateSuccess)
-    return InternalError("failed to update cuda graph: %s",
-                         cudaGetErrorString(err));
+    return absl::InternalError(absl::StrFormat(
+        "failed to update cuda graph: %s", cudaGetErrorString(err)));

 #else
   cudaGraphExecUpdateResult updated;

@@ -89,8 +83,8 @@ tsl::Status OwnedCudaGraphExec::Update(OwnedCudaGraph graph) {
   auto err = cudaGraphExecUpdate(get(), graph.get(), &error_node, &updated);
   if (err != cudaSuccess || updated != cudaGraphExecUpdateSuccess)
-    return InternalError("Failed to update cuda graph %s",
-                         cudaGetErrorString(err));
+    return absl::InternalError(absl::StrFormat(
+        "Failed to update cuda graph %s", cudaGetErrorString(err)));
 #endif

   return tsl::OkStatus();

@@ -103,8 +97,8 @@ tsl::Status OwnedCudaGraphExec::Launch(stream_executor::Stream* stream) {
   if (auto err = cudaGraphLaunch(get(), AsGpuStreamValue(stream));
       err != cudaSuccess)
-    return InternalError("failed to run cuda graph: %s",
-                         cudaGetErrorString(err));
+    return absl::InternalError(absl::StrFormat(
+        "failed to run cuda graph: %s", cudaGetErrorString(err)));

   return tsl::OkStatus();
 }

@@ -133,20 +127,20 @@ tsl::StatusOr<OwnedCudaGraph> CaptureCudaGraph(
   // Capture graph constructed by the exported graph capture function.
   if (auto err = cudaStreamBeginCapture(gpu_stream, mode); err != cudaSuccess)
-    return InternalError("stream begin capture failed: %s",
-                         cudaGetErrorString(err));
+    return absl::InternalError(absl::StrFormat(
+        "stream begin capture failed: %s", cudaGetErrorString(err)));

   // Call into graph capture function.
   auto captured = capture();

   // Always stop capturing the stream before checking `captured` result.
   if (auto err = cudaStreamEndCapture(gpu_stream, &graph); err != cudaSuccess)
-    return InternalError("stream end capture failed: %s",
-                         cudaGetErrorString(err));
+    return absl::InternalError(absl::StrFormat(
+        "stream end capture failed: %s", cudaGetErrorString(err)));

   if (!captured.ok())
-    return InternalError("failed to capture CUDA graph: %s",
-                         captured.message());
+    return absl::InternalError(absl::StrFormat(
+        "failed to capture CUDA graph: %s", captured.message()));

   VLOG(5) << "Captured CUDA graph " << graph;

@@ -195,8 +189,16 @@ tsl::StatusOr<OwnedCudaGraphExec> InstantiateCudaGraph(OwnedCudaGraph graph) {
   if (auto err = cudaGraphInstantiate(&exec, &*graph, nullptr, nullptr, 0);
 #endif
       err != cudaSuccess) {
-    return InternalError("graph instantiation failed: %s",
-                         cudaGetErrorString(err));
+    if (err == cudaErrorMemoryAllocation) {
+      // OOM is a recoverable error, we evict all instantiated cuda graphs to
+      // free up some space (see graph launch.cc). Clear error status.
+      return absl::ResourceExhaustedError(
+          absl::StrFormat("graph instantiation failed: %s",
+                          cudaGetErrorString(cudaGetLastError())));
+    } else {
+      return absl::InternalError(absl::StrFormat(
+          "graph instantiation failed: %s", cudaGetErrorString(err)));
+    }
   }

   size_t id = CudaGraphSupport::NotifyGraphExecCreated();

@@ -211,8 +213,8 @@ tsl::StatusOr<bool> IsStreamCapturing(stream_executor::Stream* stream) {
   cudaError_t err = cudaStreamIsCapturing(
       stream_executor::gpu::AsGpuStreamValue(stream), &capture_status);
   if (err != cudaSuccess) {
-    return InternalError("Failed to get stream's capture status: %s",
-                         cudaGetErrorString(err));
+    return absl::InternalError(
+        absl::StrFormat("Failed to get stream's capture status: %s",
+                        cudaGetErrorString(err)));
   }

   return capture_status == cudaStreamCaptureStatusActive;
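The key contract change in this file: instantiation OOM now surfaces as `absl::StatusCode::kResourceExhausted` rather than a generic internal error, which is what lets `graph_launch.cc` evict and retry. A small sketch of that caller-side pattern (the callables are hypothetical stand-ins, not XLA APIs):

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// `instantiate` and `evict` stand in for InstantiateCudaGraph(...) and the
// eviction hook in graph_launch.cc; the int payload is a placeholder.
template <typename Instantiate, typename Evict>
absl::StatusOr<int> InstantiateWithOneRetry(Instantiate instantiate,
                                            Evict evict) {
  absl::StatusOr<int> exec = instantiate();
  if (exec.status().code() == absl::StatusCode::kResourceExhausted) {
    evict();               // free device memory held by idle graphs
    exec = instantiate();  // single retry; a second OOM propagates
  }
  return exec;
}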
tensorflow/compiler/xla/xla.proto

@@ -454,6 +454,11 @@ message DebugOptions {
   // Identify concurrent regions in cuda graphs and execute them concurrently.
   bool xla_gpu_cuda_graph_enable_concurrent_region = 215;

+  // Timeout in seconds to evict instantiated Gpu graphs from device. When XLA
+  // instantiates new Gpu graphs, it evicts graphs that were not recently
+  // executed to free space on device.
+  int32 xla_gpu_cuda_graph_eviction_timeout_seconds = 230;
+
   // Allocate temp buffers once during the first execution of an executable.
   // Reuse the allocated buffers in subsequent executions. Executables cannot
   // run concurrently if this is enabled.

@@ -572,7 +577,7 @@ message DebugOptions {
   int32 xla_gpu_triton_fusion_level = 229;

-  // Next id: 230
+  // Next id: 231

   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.