未验证 提交 4f4ce1f2 编写于 作者: L Lyon 提交者: GitHub

Support localtensor slice (#4985)

* add scalar input support

* format

* register local tensor slice methods
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 09552182
......@@ -339,6 +339,7 @@ class Tensor:
def backward(self, gradient=None, retain_graph=False, create_graph=False):
flow.autograd.backward(self, gradient, retain_graph, create_graph)
@register_local_tensor_method()
def _get_slice_obj(self, key):
def get_or_default(x, default):
return x if x is not None else default
......@@ -385,6 +386,7 @@ class Tensor:
return starts, stops, steps, shape
@_auto_determine
@register_local_tensor_method()
def __getitem__(self, key):
# TODO: support inplace __getitem__
start, stop, step, _ = self._get_slice_obj(key)
......@@ -392,6 +394,7 @@ class Tensor:
return res
@_auto_determine
@register_local_tensor_method()
def __setitem__(self, key, value):
start, stop, step, shape = self._get_slice_obj(key)
if isinstance(value, (int, float)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册