Unverified commit fecea4c5, authored by zqw_1997, committed by GitHub

[Cherry-pick 2.5][Zero-Dim] paddle.static.data, squeeze, unbind, unstack, gather_nd and einsum support 0D (#53602)

* add test cases, test=allcase

* fix test cases, test=allcase

* fix test cases, test=allcase

* assert_allclose, test=allcase

* 1e-5 to 1e-4, test=allcase

* change rtol from 1e-4 to 1e-3, test=allcase

* test=allcase

* test=allcase

* test=allcase

* test=allcase

* test=allcase

* fix test cases, test=allcase

* fix test cases, test=allcase

* modify the test_squeeze to not use Tensor type axis, test=allcase

* add grad check for unbind and unstack, test=allcase

* check for squeeze axis tensor type, test=allcase

* fix bug, test=allcase

* test=allcase

* test=allcase

* test=allcase

* test=allcase

* test=allcase

* test=allcase

* test=allcase

* test=allcase

* test=allcase

* test=allcase

* test=allcase

* test=allcase

* test=allcase

* test=allcase

* test=allcase

* test=allcase
Parent 3a247cba
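The APIs named in the title now accept and return 0-D (scalar) tensors instead of forcing shape-[1] tensors. A minimal dynamic-graph sketch of the behaviors this cherry-pick covers; the values are illustrative only and mirror the new test cases further down in this diff (paddle.static.data with shape=[] is exercised in the static-graph tests below):

import paddle

# squeeze: axis 0 of a 0-D tensor is accepted and the result stays 0-D
x = paddle.full([], 2.0)
print(paddle.squeeze(x, axis=0).shape)                 # []

# unbind / unstack on a shape-[1] tensor yield 0-D pieces
[piece] = paddle.unbind(paddle.full([1], 0.0), 0)
print(piece.shape)                                     # []
[piece] = paddle.unstack(paddle.full([1], 0.0), 0)
print(piece.shape)                                     # []

# gather_nd with a full-rank index returns a 0-D value
m = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
idx = paddle.full([2], 1, 'int64')
print(paddle.gather_nd(m, idx).shape)                  # []

# einsum reductions ('i->' and 'i,i->') produce 0-D outputs
v = paddle.rand([5])
print(paddle.einsum('i,i->', v, v).shape)              # []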
@@ -3762,6 +3762,9 @@ void SqueezeInferMeta(const MetaTensor& x,
   if (!config.is_runtime && axes.FromTensor()) {
     // compile time infershape, set all elements to -1.
     int output_size = x.dims().size() - axes.GetData().size();
+    if (x.dims().size() == 0 && output_size == -1) {
+      output_size = 0;
+    }
     std::vector<int64_t> vec_out_dims(output_size, -1);
     out->set_dims(phi::make_ddim(vec_out_dims));
   } else {
......
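When the squeeze axes are supplied as a Tensor, compile-time shape inference only knows the input rank and the number of axes, so for a 0-D input with one axis it would compute an output rank of 0 - 1 = -1, an invalid size for the output-dims vector; the hunk above clamps that to 0 so the compile-time shape is inferred as []. A minimal static-graph sketch of the case this enables, mirroring the new static test_squeeze added later in this diff (values are illustrative):

import paddle

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.full([], 3.0)                    # 0-D input
    axis = paddle.full([], 0, dtype='int32')    # axis passed as a Tensor
    out = paddle.squeeze(x, axis=axis)          # inferred rank clamped to 0

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
res = exe.run(main, fetch_list=[out])
print(res[0].shape)                             # ()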
@@ -580,7 +580,7 @@ def _as_lodtensor(data, place, dtype=None):
        else dtype
    )
    if np.isscalar(data):
-        data = np.array([data]).astype(dtype)
+        data = np.array(data).astype(dtype)
    elif isinstance(data, (list, tuple)):
        data = np.array(data)
        if data.dtype == np.object_:
......
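With this executor change, a plain Python scalar in the feed dict is converted with np.array(data), giving a 0-D array rather than a shape-(1,) array, so it matches a paddle.static.data placeholder declared with shape=[]. A small sketch, assuming CPU execution; it mirrors the new test_static_data case near the end of this diff:

import paddle

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[], dtype='float32')  # 0-D placeholder
    y = x + 1.0

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
# the scalar 2.5 is fed directly and arrives in the program as a 0-D array
res = exe.run(main, feed={'x': 2.5}, fetch_list=[y])
print(res[0].shape, float(res[0]))                                # () 3.5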
@@ -96,7 +96,7 @@ def create_paddle_case(op_type, callback):
        paddle.enable_static()
        with program_guard(Program(), Program()):
            x = paddle.static.data(name='x', shape=[4], dtype='int64')
-            y = paddle.static.data(name='y', shape=[1], dtype='int64')
+            y = paddle.static.data(name='y', shape=[], dtype='int64')
            op = eval("paddle.%s" % (self.op_type))
            out = op(x, y)
            exe = fluid.Executor(self.place)
......
@@ -31,6 +31,13 @@ class TestApiStaticDataError(unittest.TestCase):
            x3 = paddle.static.data(name="x3", shape=[2, 25])
            self.assertEqual(x3.dtype, core.VarDesc.VarType.FP64)

+    def test_0D(self):
+        with program_guard(Program(), Program()):
+            x1 = paddle.static.data(name="x1_0D", shape=[])
+            self.assertEqual(x1.dtype, core.VarDesc.VarType.FP32)
+            x2 = paddle.static.data(name="x2_0D", shape=(), dtype="bool")
+            self.assertEqual(x2.dtype, core.VarDesc.VarType.BOOL)
+
    def test_error(self):
        with program_guard(Program(), Program()):
......
@@ -66,7 +66,7 @@ class TestDeg2radAPI(unittest.TestCase):
class TestDeg2radAPI2(TestDeg2radAPI):
    # Test input data type is int
    def setUp(self):
-        self.x_np = 180
+        self.x_np = [180]
        self.x_shape = [1]
        self.out_np = np.pi
        self.x_dtype = 'int64'
......
@@ -22,7 +22,7 @@ from paddle import fluid
class TestExecutor(unittest.TestCase):
    def net(self):
-        lr = paddle.static.data(name="lr", shape=[1], dtype='float32')
+        lr = paddle.static.data(name="lr", shape=[], dtype='float32')
        x = paddle.static.data(name="x", shape=[None, 1], dtype='float32')
        y = paddle.static.data(name="y", shape=[None, 1], dtype='float32')
        y_predict = paddle.static.nn.fc(x, size=1)
......
@@ -25,8 +25,8 @@ paddle.enable_static()
class TestGcdAPI(unittest.TestCase):
    def setUp(self):
-        self.x_np = 12
-        self.y_np = 20
+        self.x_np = [12]
+        self.y_np = [20]
        self.x_shape = [1]
        self.y_shape = [1]

@@ -81,14 +81,14 @@ class TestGcdAPI3(TestGcdAPI):
    def setUp(self):
        self.x_np = 0
        self.y_np = 20
-        self.x_shape = [1]
-        self.y_shape = [1]
+        self.x_shape = []
+        self.y_shape = []


class TestGcdAPI4(TestGcdAPI):
    def setUp(self):
-        self.x_np = 0
-        self.y_np = 0
+        self.x_np = [0]
+        self.y_np = [0]
        self.x_shape = [1]
        self.y_shape = [1]

@@ -97,5 +97,5 @@ class TestGcdAPI5(TestGcdAPI):
    def setUp(self):
        self.x_np = 12
        self.y_np = -20
-        self.x_shape = [1]
-        self.y_shape = [1]
+        self.x_shape = []
+        self.y_shape = []
@@ -27,8 +27,8 @@ class TestLcmAPI(unittest.TestCase):
    def setUp(self):
        self.x_np = 12
        self.y_np = 20
-        self.x_shape = [1]
-        self.y_shape = [1]
+        self.x_shape = []
+        self.y_shape = []

    def test_static_graph(self):
        startup_program = fluid.Program()

@@ -81,14 +81,14 @@ class TestLcmAPI3(TestLcmAPI):
    def setUp(self):
        self.x_np = 0
        self.y_np = 20
-        self.x_shape = [1]
-        self.y_shape = [1]
+        self.x_shape = []
+        self.y_shape = []


class TestLcmAPI4(TestLcmAPI):
    def setUp(self):
-        self.x_np = 0
-        self.y_np = 0
+        self.x_np = [0]
+        self.y_np = [0]
        self.x_shape = [1]
        self.y_shape = [1]

@@ -97,5 +97,5 @@ class TestLcmAPI5(TestLcmAPI):
    def setUp(self):
        self.x_np = 12
        self.y_np = -20
-        self.x_shape = [1]
-        self.y_shape = [1]
+        self.x_shape = []
+        self.y_shape = []
@@ -141,7 +141,7 @@ class TestPutAlongAxisAPI(unittest.TestCase):
        self.place = [paddle.CPUPlace()]
        self.axis = 0
        self.value_np = 99.0
-        self.value_shape = [1]
+        self.value_shape = []
        self.x_feed = copy.deepcopy(self.x_np)
        if core.is_compiled_with_cuda():
            self.place.append(paddle.CUDAPlace(0))

@@ -240,7 +240,7 @@ class TestPutAlongAxisAPICase2(TestPutAlongAxisAPI):
        self.place = [paddle.CPUPlace()]
        self.axis = 0
        self.value_np = 99.0
-        self.value_shape = [1]
+        self.value_shape = []
        self.x_feed = copy.deepcopy(self.x_np)
        if core.is_compiled_with_cuda():
            self.place.append(paddle.CUDAPlace(0))

@@ -258,7 +258,7 @@ class TestPutAlongAxisAPICase3(TestPutAlongAxisAPI):
        self.place = [paddle.CPUPlace()]
        self.axis = 0
        self.value_np = 99.0
-        self.value_shape = [1]
+        self.value_shape = []
        self.x_feed = copy.deepcopy(self.x_np)
        if core.is_compiled_with_cuda():
            self.place.append(paddle.CUDAPlace(0))
......
@@ -65,7 +65,7 @@ class TestRad2degAPI(unittest.TestCase):
class TestRad2degAPI2(TestRad2degAPI):
    def setUp(self):
-        self.x_np = np.pi / 2
+        self.x_np = [np.pi / 2]
        self.x_shape = [1]
        self.out_np = 90
        self.x_dtype = 'float32'

@@ -83,7 +83,7 @@ class TestRad2degAPI2(TestRad2degAPI):
class TestRad2degAPI3(TestRad2degAPI):
    # Test input data type is int
    def setUp(self):
-        self.x_np = 1
+        self.x_np = [1]
        self.x_shape = [1]
        self.out_np = 180 / np.pi
        self.x_dtype = 'int64'
......
@@ -83,7 +83,7 @@ class TestTrapezoidAPI(unittest.TestCase):
                )
                if self.dx is not None:
                    dx = paddle.static.data(
-                        name="dx", shape=[1], dtype='float32'
+                        name="dx", shape=[], dtype='float32'
                    )
                exe = paddle.static.Executor(place)
......
@@ -29,7 +29,7 @@ class TestUnbind(unittest.TestCase):
        x_1 = paddle.static.data(shape=[2, 3], dtype='float32', name='x_1')
        [out_0, out_1] = tensor.unbind(input=x_1, axis=0)
        input_1 = np.random.random([2, 3]).astype("float32")
-        axis = paddle.static.data(shape=[1], dtype='int32', name='axis')
+        axis = paddle.static.data(shape=[], dtype='int32', name='axis')
        exe = fluid.Executor(place=fluid.CPUPlace())

        [res_1, res_2] = exe.run(

@@ -87,7 +87,7 @@ class TestLayersUnbind(unittest.TestCase):
        x_1 = paddle.static.data(shape=[2, 3], dtype='float32', name='x_1')
        [out_0, out_1] = paddle.unbind(input=x_1, axis=0)
        input_1 = np.random.random([2, 3]).astype("float32")
-        axis = paddle.static.data(shape=[1], dtype='int32', name='axis')
+        axis = paddle.static.data(shape=[], dtype='int32', name='axis')
        exe = fluid.Executor(place=fluid.CPUPlace())

        [res_1, res_2] = exe.run(
......
@@ -17,6 +17,7 @@
# 0D Tensor's shape is always [], numel is 1
# which can be created by paddle.rand([])

+import os
import unittest

import numpy as np
@@ -1712,6 +1713,75 @@ class TestSundryAPI(unittest.TestCase):
        self.assertEqual(x.grad.shape, [2, 3])
        self.assertEqual(out.grad.shape, [2])

+    def test_gather_nd(self):
+        x1 = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
+        x2 = paddle.to_tensor(
+            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
+        )
+
+        index1 = paddle.full([1], 1, 'int64')
+        index2 = paddle.full([2], 1, 'int64')
+
+        out1 = paddle.gather_nd(x1, index1)
+        out2 = paddle.gather_nd(x2, index2)
+
+        out1.retain_grads()
+        out2.retain_grads()
+
+        out1.backward()
+        out2.backward()
+
+        self.assertEqual(out1.shape, [])
+        self.assertEqual(out2.shape, [])
+        np.testing.assert_array_equal(out1, np.array(3.0))
+        np.testing.assert_array_equal(out2, np.array(5.0))
+        self.assertEqual(x1.grad.shape, [5])
+        self.assertEqual(x2.grad.shape, [2, 3])
+        self.assertEqual(out1.grad.shape, [])
+        self.assertEqual(out2.grad.shape, [])
+
+    def test_einsum(self):
+        os.environ['FLAGS_new_einsum'] = "0"
+        x = paddle.rand([5])
+        # sum
+        out1 = paddle.einsum('i->', x)
+        expect1 = np.einsum('i->', x)
+        # dot
+        out2 = paddle.einsum('i,i->', x, x)
+        expect2 = np.einsum('i,i->', x, x)
+
+        out1.retain_grads()
+        out2.retain_grads()
+
+        out1.backward()
+        out2.backward()
+
+        self.assertEqual(out1.shape, [])
+        self.assertEqual(out2.shape, [])
+        np.testing.assert_allclose(out1, expect1, rtol=1e-03)
+        np.testing.assert_allclose(out2, expect2, rtol=1e-03)
+
+    def test_einsum_V2(self):
+        os.environ['FLAGS_new_einsum'] = "1"
+        x = paddle.rand([5])
+        # sum
+        out1 = paddle.einsum('i->', x)
+        expect1 = np.einsum('i->', x)
+        # dot
+        out2 = paddle.einsum('i,i->', x, x)
+        expect2 = np.einsum('i,i->', x, x)
+
+        out1.retain_grads()
+        out2.retain_grads()
+
+        out1.backward()
+        out2.backward()
+
+        self.assertEqual(out1.shape, [])
+        self.assertEqual(out2.shape, [])
+        np.testing.assert_allclose(out1, expect1, rtol=1e-03)
+        np.testing.assert_allclose(out2, expect2, rtol=1e-03)
+
    def test_scatter_1D(self):
        x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
        index = paddle.full([], 2, 'int64')
@@ -2324,6 +2394,56 @@ class TestSundryAPI(unittest.TestCase):
        self.assertEqual(out1.shape, [2, 3, 12, 12])
        self.assertEqual(input_x.grad.shape, [2, 3, 6, 6])

+    def test_unstack(self):
+        x1 = paddle.full([1], 0)
+        x2 = paddle.full([2], 2)
+        x1.retain_grads()
+        x2.retain_grads()
+        x1.stop_gradient = False
+        x2.stop_gradient = False
+
+        [out1] = paddle.unstack(x1, 0)
+        out1.retain_grads()
+        out1.backward()
+        [out2_1, out2_2] = paddle.unstack(x2, 0)
+        out2 = paddle.add_n([out2_1, out2_2])
+        out2.retain_grads()
+        out2.backward()
+
+        self.assertEqual(out1.shape, [])
+        self.assertEqual(out1.numpy(), 0)
+
+        self.assertEqual(out2_1.shape, [])
+        self.assertEqual(out2_1.numpy(), 2)
+        self.assertEqual(out2_2.shape, [])
+        self.assertEqual(out2_2.numpy(), 2)
+        self.assertEqual(x2.grad.shape, [2])
+
+    def test_unbind(self):
+        x1 = paddle.full([1], 0)
+        x2 = paddle.full([2], 2)
+        x1.retain_grads()
+        x2.retain_grads()
+        x1.stop_gradient = False
+        x2.stop_gradient = False
+
+        [out1] = paddle.unbind(x1, 0)
+        out1.retain_grads()
+        out1.backward()
+        [out2_1, out2_2] = paddle.unbind(x2, 0)
+        out2 = paddle.add_n([out2_1, out2_2])
+        out2.retain_grads()
+        out2.backward()
+
+        self.assertEqual(out1.shape, [])
+        self.assertEqual(out1.numpy(), 0)
+
+        self.assertEqual(out2_1.shape, [])
+        self.assertEqual(out2_1.numpy(), 2)
+        self.assertEqual(out2_2.shape, [])
+        self.assertEqual(out2_2.numpy(), 2)
+        self.assertEqual(x2.grad.shape, [2])
+
    def test_maseked_select(self):
        x = paddle.rand([])
        x.stop_gradient = False
@@ -2338,6 +2458,26 @@ class TestSundryAPI(unittest.TestCase):
        self.assertEqual(x.grad.shape, [])
        self.assertEqual(x.grad.numpy(), 1)

+    def test_squeeze(self):
+        x1 = paddle.full([], 2)
+        x1.stop_gradient = False
+        x1.retain_grads()
+        out1 = paddle.squeeze(x1, axis=0)
+        out1.retain_grads()
+        out1.backward()
+        self.assertEqual(out1.shape, [])
+        self.assertEqual(x1.grad.shape, [])
+
+        x2 = paddle.full([], 3)
+        x3 = paddle.full([1], 0, dtype='int32')
+        x2.stop_gradient = False
+        x2.retain_grads()
+        out2 = paddle.squeeze(x2, axis=x3)
+        out2.retain_grads()
+        out2.backward()
+        self.assertEqual(out2.shape, [])
+        self.assertEqual(x2.grad.shape, [])
+
    def test_unsqueeze(self):
        x1 = paddle.full([], 2)
        x1.stop_gradient = False
@@ -3713,6 +3853,42 @@ class TestSundryAPIStatic(unittest.TestCase):
        self.assertEqual(res[1].shape, (2, 3))
        self.assertEqual(res[2].shape, (2,))

+    @prog_scope()
+    def test_gather_nd(self):
+        x1 = paddle.full([10], 1.0, 'float32')
+        x1.stop_gradient = False
+        x2 = paddle.full([2, 3], 1.0, 'float32')
+        x2.stop_gradient = False
+
+        index1 = paddle.full([1], 1, 'int64')
+        index2 = paddle.full([2], 1, 'int64')
+
+        out1 = paddle.gather_nd(x1, index1)
+        out2 = paddle.gather_nd(x2, index2)
+        paddle.static.append_backward(out1.sum())
+        paddle.static.append_backward(out2.sum())
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            fetch_list=[
+                out1,
+                out2,
+                x1.grad_name,
+                x2.grad_name,
+                out1.grad_name,
+                out2.grad_name,
+            ],
+        )
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, ())
+        np.testing.assert_array_equal(res[0], 1.0)
+        np.testing.assert_array_equal(res[1], 1.0)
+        self.assertEqual(res[2].shape, (10,))
+        self.assertEqual(res[3].shape, (2, 3))
+        self.assertEqual(res[4].shape, ())
+        self.assertEqual(res[5].shape, ())
+
    @prog_scope()
    def test_scatter_1D(self):
        x = paddle.full([10], 1.0, 'float32')
@@ -4329,6 +4505,50 @@ class TestSundryAPIStatic(unittest.TestCase):
        self.assertEqual(res1[0].shape, (2, 3, 12, 12))
        self.assertEqual(res1[1].shape, (2, 3, 6, 6))

+    @prog_scope()
+    def test_unstack(self):
+        x1 = paddle.full([1], 0, 'float32')
+        x1.stop_gradient = False
+        out1 = paddle.unstack(x1, 0)
+        out1 = paddle.add_n(out1)
+        paddle.static.append_backward(out1)
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, feed={}, fetch_list=[out1, x1.grad_name])
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (1,))
+
+        x2 = paddle.full([2], 2, 'float32')
+        x2.stop_gradient = False
+        out2 = paddle.unstack(x2, 0)
+        out2_sum = paddle.add_n(out2)
+        paddle.static.append_backward(out2_sum)
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, feed={}, fetch_list=[out2_sum, x2.grad_name])
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (2,))
+
+    @prog_scope()
+    def test_unbind(self):
+        x1 = paddle.full([1], 0, 'float32')
+        x1.stop_gradient = False
+        out1 = paddle.unbind(x1, 0)
+        out1 = paddle.add_n(out1)
+        paddle.static.append_backward(out1)
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, feed={}, fetch_list=[out1, x1.grad_name])
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (1,))
+
+        x2 = paddle.full([2], 2, 'float32')
+        x2.stop_gradient = False
+        out2 = paddle.unbind(x2, 0)
+        out2_sum = paddle.add_n(out2)
+        paddle.static.append_backward(out2_sum)
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, feed={}, fetch_list=[out2_sum, x2.grad_name])
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (2,))
+
    @prog_scope()
    def test_maseked_select(self):
        x = paddle.rand([])
@@ -4345,6 +4565,34 @@ class TestSundryAPIStatic(unittest.TestCase):
        self.assertEqual(res[3].shape, ())
        self.assertEqual(res[3], 1)

+    @prog_scope()
+    def test_squeeze(self):
+        x1 = paddle.full([], 2)
+        x1.stop_gradient = False
+        out1 = paddle.squeeze(x1, axis=0)
+        paddle.static.append_backward(out1.sum())
+
+        x2 = paddle.full([], 3)
+        x3 = paddle.full([], 0, dtype='int32')
+        x2.stop_gradient = False
+        out2 = paddle.squeeze(x2, axis=x3)
+        paddle.static.append_backward(out2.sum())
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            fetch_list=[
+                out1,
+                out2,
+                x1.grad_name,
+                x2.grad_name,
+            ],
+        )
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, ())
+        self.assertEqual(res[2].shape, ())
+        self.assertEqual(res[3].shape, ())
+
    @prog_scope()
    def test_unsqueeze(self):
        x1 = paddle.full([], 2)
@@ -4403,6 +4651,39 @@ class TestSundryAPIStatic(unittest.TestCase):
        res = self.exe.run(prog, feed={"x": x_tensor}, fetch_list=[out])
        self.assertEqual(res[0].shape, (3, 4, 2))

+    @prog_scope()
+    def test_static_data(self):
+        x1 = paddle.static.data(name="x1", shape=[])
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            feed={
+                "x1": np.array(1.0, dtype='float32'),
+            },
+            fetch_list=[
+                x1.name,
+            ],
+        )
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[0], np.array(1.0))
+
+        x2 = paddle.static.data(name="x2", shape=[])
+        x3 = paddle.static.data(name="x3", shape=[])
+        y = x2 + x3
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            feed={
+                "x2": 100.5,
+                "x3": 200.5,
+            },
+            fetch_list=[
+                y.name,
+            ],
+        )
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[0], 301.0)
+
    @prog_scope()
    def test_prelu(self):
        x1 = paddle.full([], 1.0, 'float32')
......
@@ -966,8 +966,8 @@ def einsum(equation, *operands):
            # dot
            print(paddle.einsum('i,i->', x, x))
-            # Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
-            #        [1.45936954])
+            # Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #        1.45936954)

            # outer
            print(paddle.einsum("i,j->ij", x, y))
......