未验证 提交 cfc9bf76 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] fix slice's input mistake (#44855)

* [Eager] fix slice's input mistake

* add tests for slice
上级 2140e825
...@@ -635,7 +635,7 @@ class TestSliceApiEager(unittest.TestCase): ...@@ -635,7 +635,7 @@ class TestSliceApiEager(unittest.TestCase):
axes=axes, axes=axes,
starts=paddle.to_tensor(starts), starts=paddle.to_tensor(starts),
ends=paddle.to_tensor(ends)) ends=paddle.to_tensor(ends))
self.assertTrue(np.array_equal(a_1.numpy(), a_2.numpy()))
a_1.backward() a_1.backward()
grad_truth = paddle.zeros_like(a) grad_truth = paddle.zeros_like(a)
grad_truth[-3:3, 0:2, 2:4] = 1 grad_truth[-3:3, 0:2, 2:4] = 1
......
...@@ -208,7 +208,7 @@ def slice(input, axes, starts, ends): ...@@ -208,7 +208,7 @@ def slice(input, axes, starts, ends):
if isinstance(item, tmp_tensor_type) else item for item in ends if isinstance(item, tmp_tensor_type) else item for item in ends
] ]
elif isinstance(ends, tmp_tensor_type): elif isinstance(ends, tmp_tensor_type):
etensor_t = ends.numpy() tensor_t = ends.numpy()
ends = [ele for ele in tensor_t] ends = [ele for ele in tensor_t]
infer_flags = list(-1 for i in range(len(axes))) infer_flags = list(-1 for i in range(len(axes)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册