未验证 提交 d9fac780 编写于 作者: C Charles-hit 提交者: GitHub

support slice op backward refuse forward and add high level unit test (#45960)

上级 da546c88
...@@ -2138,11 +2138,7 @@ ...@@ -2138,11 +2138,7 @@
forward : slice_grad (Tensor input, Tensor grad_out, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) -> Tensor(grad_input) forward : slice_grad (Tensor input, Tensor grad_out, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) -> Tensor(grad_input)
args : (Tensor grad_input_grad, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) args : (Tensor grad_input_grad, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis)
output : Tensor(grad_out_grad) output : Tensor(grad_out_grad)
infer_meta : invoke : slice(grad_input_grad, axes, starts, ends, infer_flags, decrease_axis)
func : UnchangedInferMeta
param : [grad_input_grad]
kernel :
func : slice
- backward_api : slice_grad - backward_api : slice_grad
forward : slice (Tensor input, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) -> Tensor(out) forward : slice (Tensor input, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) -> Tensor(out)
......
...@@ -22,6 +22,9 @@ import paddle.fluid as fluid ...@@ -22,6 +22,9 @@ import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
import paddle import paddle
from paddle.fluid.framework import _test_eager_guard, _enable_legacy_dygraph from paddle.fluid.framework import _test_eager_guard, _enable_legacy_dygraph
import gradient_checker
from decorator_helper import prog_scope
import paddle.fluid.layers as layers
paddle.enable_static() paddle.enable_static()
...@@ -840,6 +843,92 @@ class TestImperativeCUDAPinnedInput(unittest.TestCase): ...@@ -840,6 +843,92 @@ class TestImperativeCUDAPinnedInput(unittest.TestCase):
self.assertEqual(sliced.shape, [2, 70, 80]) self.assertEqual(sliced.shape, [2, 70, 80])
class TestSliceDoubleGradCheck(unittest.TestCase):
def slice_wrapper(self, x):
return paddle.slice(x[0],
axes=[0, 1, 2],
starts=[-3, 0, 2],
ends=[3, 2, 4])
@prog_scope()
def func(self, place):
# the shape of input variable should be clearly specified, not inlcude -1.
eps = 0.005
dtype = np.float32
data = layers.data('data', [4, 5, 6], False, dtype)
data.persistable = True
out = paddle.slice(data,
axes=[0, 1, 2],
starts=[-3, 0, 2],
ends=[3, 2, 4])
data_arr = np.random.uniform(-1, 1, data.shape).astype(dtype)
gradient_checker.double_grad_check([data],
out,
x_init=[data_arr],
place=place,
eps=eps)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
gradient_checker.double_grad_check_for_dygraph(self.slice_wrapper,
[data],
out,
x_init=[data_arr],
place=place)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestSliceTripleGradCheck(unittest.TestCase):
def slice_wrapper(self, x):
return paddle.slice(x[0],
axes=[0, 1, 2],
starts=[-3, 0, 2],
ends=[3, 2, 4])
@prog_scope()
def func(self, place):
# the shape of input variable should be clearly specified, not inlcude -1.
eps = 0.005
dtype = np.float32
data = layers.data('data', [4, 5, 6], False, dtype)
data.persistable = True
out = paddle.slice(data,
axes=[0, 1, 2],
starts=[-3, 0, 2],
ends=[3, 2, 4])
data_arr = np.random.uniform(-1, 1, data.shape).astype(dtype)
gradient_checker.triple_grad_check([data],
out,
x_init=[data_arr],
place=place,
eps=eps)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
gradient_checker.triple_grad_check_for_dygraph(self.slice_wrapper,
[data],
out,
x_init=[data_arr],
place=place)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册