From ee3540881debe9223c67ef56b808636a690de412 Mon Sep 17 00:00:00 2001
From: HansBug
Date: Tue, 2 Aug 2022 15:03:43 +0800
Subject: [PATCH] dev(hansbug): add stream.py

---
 treetensor/torch/stream.py | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)
 create mode 100644 treetensor/torch/stream.py

diff --git a/treetensor/torch/stream.py b/treetensor/torch/stream.py
new file mode 100644
index 000000000..b7ebf4f46
--- /dev/null
+++ b/treetensor/torch/stream.py
@@ -0,0 +1,35 @@
+import random
+from typing import Optional, List
+
+import torch
+
+_stream_pool: Optional[List[torch.cuda.Stream]] = None
+_global_streams: Optional[List[torch.cuda.Stream]] = None
+
+__all__ = [
+    'stream',
+]
+
+
+def stream(cnt):
+    assert torch.cuda.is_available(), "CUDA is not supported."
+
+    global _stream_pool, _global_streams
+    if _stream_pool is None:
+        _stream_pool = [torch.cuda.current_stream()]
+
+    if cnt is None:  # disable stream support
+        _global_streams = None
+    else:  # use the given number of streams
+        while len(_stream_pool) < cnt:
+            _stream_pool.append(torch.cuda.Stream())
+
+        _global_streams = _stream_pool[:cnt]
+
+
+def stream_call(func, *args, **kwargs):
+    if _global_streams is not None:
+        with torch.cuda.stream(random.choice(_global_streams)):
+            return func(*args, **kwargs)
+    else:
+        return func(*args, **kwargs)

--
GitLab
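
Usage sketch for the new helpers (assumptions: a CUDA-capable machine, since stream() asserts torch.cuda.is_available(); importing directly from the module path added by this patch; the matmul workload and tensor shapes are illustrative only, not part of the commit):

    import torch

    from treetensor.torch.stream import stream, stream_call

    # Build a pool of 4 CUDA streams (the current stream plus 3 new ones)
    # and enable stream dispatching for stream_call().
    stream(4)

    x = torch.randn(1024, 1024, device='cuda')
    y = torch.randn(1024, 1024, device='cuda')

    # Each wrapped call runs on a randomly chosen stream from the pool.
    result = stream_call(torch.matmul, x, y)

    # Wait for work on all streams to finish before consuming the result.
    torch.cuda.synchronize()

    # Passing None disables the pool; later calls run on the current stream.
    stream(None)
    result = stream_call(torch.matmul, x, y)

Because stream_call() picks a stream at random per call, any cross-stream ordering or synchronization the workload needs is left to the caller.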