未验证 提交 57d629a1 编写于 作者: J Joe Mayer 提交者: GitHub

Empty tensor size check (#4186)

* Size for transformer engine.

* adding kwargs

* args tuple

* format updates
上级 8145b5e4
......@@ -250,9 +250,11 @@ def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.
def get_new_tensor_fn_for_dtype(dtype: torch.dtype) -> Callable:
def new_tensor(cls, *args) -> Tensor:
def new_tensor(cls, *args, **kwargs) -> Tensor:
device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
tensor = _orig_torch_empty(0, device=device).new_empty(*args)
if not args:
args = (0, )
tensor = _orig_torch_empty(0, device=device).new_empty(*args, **kwargs)
if tensor.is_floating_point():
tensor = tensor.to(dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册