未验证 提交 b61fa16a 编写于 作者: H hong 提交者: GitHub

add split backward yaml (#41746)

上级 c9c03e7b
......@@ -19,6 +19,7 @@ import numpy as np
from op_test import OpTest, convert_float_to_uint16
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard, core
from paddle.fluid.framework import _test_eager_guard
class TestSplitOp(OpTest):
......@@ -402,12 +403,30 @@ class API_TestDygraphSplit(unittest.TestCase):
with fluid.dygraph.guard():
input_1 = np.random.random([4, 6, 6]).astype("int32")
# input is a variable which shape is [4, 6, 6]
input = fluid.dygraph.to_variable(input_1)
input = paddle.to_tensor(input_1)
x0, x1, x2 = paddle.split(input, num_or_sections=3, axis=1)
x0_out = x0.numpy()
x1_out = x1.numpy()
x2_out = x2.numpy()
ex_x0, ex_x1, ex_x2 = np.split(input_1, 3, axis=1)
with _test_eager_guard():
# input is a variable which shape is [4, 6, 6]
input = paddle.to_tensor(input_1)
input.stop_gradient = False
x0, x1, x2 = paddle.split(input, num_or_sections=3, axis=1)
eager_x0_out = x0.numpy()
eager_x1_out = x1.numpy()
eager_x2_out = x2.numpy()
loss = x0.sum()
loss.backward()
manul_grad = np.zeros_like(input_1)
manul_grad[:, :2, :] = 1
self.assertTrue(np.allclose(input.gradient(), manul_grad))
self.assertTrue(np.allclose(ex_x0, eager_x0_out))
self.assertTrue(np.allclose(ex_x1, eager_x1_out))
self.assertTrue(np.allclose(ex_x2, eager_x2_out))
self.assertTrue(np.allclose(ex_x0, x0_out))
self.assertTrue(np.allclose(ex_x1, x1_out))
self.assertTrue(np.allclose(ex_x2, x2_out))
......@@ -416,7 +435,7 @@ class API_TestDygraphSplit(unittest.TestCase):
with fluid.dygraph.guard():
input_1 = np.random.random([4, 6, 6]).astype("bool")
# input is a variable which shape is [4, 6, 6]
input = fluid.dygraph.to_variable(input_1)
input = paddle.to_tensor(input_1)
x0, x1, x2 = paddle.split(input, num_or_sections=3, axis=1)
x0_out = x0.numpy()
x1_out = x1.numpy()
......@@ -430,7 +449,7 @@ class API_TestDygraphSplit(unittest.TestCase):
with fluid.dygraph.guard():
input_1 = np.random.random([4, 6, 6]).astype("int32")
# input is a variable which shape is [4, 6, 6]
input = fluid.dygraph.to_variable(input_1)
input = paddle.to_tensor(input_1)
num1 = paddle.full(shape=[1], fill_value=2, dtype='int32')
x0, x1, x2 = paddle.split(
input, num_or_sections=[num1, 2, 2], axis=1)
......@@ -446,7 +465,7 @@ class API_TestDygraphSplit(unittest.TestCase):
with fluid.dygraph.guard():
input_1 = np.random.random([4, 6, 6]).astype("int32")
# input is a variable which shape is [4, 6, 6]
input = fluid.dygraph.to_variable(input_1)
input = paddle.to_tensor(input_1)
num1 = paddle.full(shape=[1], fill_value=1, dtype='int32')
x0, x1, x2 = paddle.split(
input, num_or_sections=[2, 2, 2], axis=num1)
......
......@@ -1917,6 +1917,7 @@
args : (Tensor x, IntArray num_or_sections, Scalar(int) axis)
output : Tensor[]
invoke : split_impl(x, num_or_sections, axis)
backward : split_grad
- api : sqrt
args : (Tensor x)
......
......@@ -1523,7 +1523,7 @@
- backward_api : split_grad
forward : split (Tensor x, IntArray num_or_sections, Scalar axis) -> Tensor[](out)
args : (Tensor[] out_grad, Scalar axis)
args : (Tensor[] out_grad, Scalar axis = -1)
output : Tensor(x_grad)
invoke : concat( out_grad, axis)
# TODO(zhangyunfei) The config of double grad and triple grad will be supported in the future.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册