未验证 提交 78add057 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] support 0D for paddle.transpose/reshape/stack/tile/unsqueeze (#46555)

上级 19438131
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......@@ -32,6 +33,9 @@ inline void TransCompute(const int dim,
phi::DenseTensor* out,
const std::vector<int>& axis) {
switch (dim) {
case 0:
phi::Copy<DeviceContext>(dev_ctx, in, dev_ctx.GetPlace(), false, out);
break;
case 1:
phi::funcs::Transpose<DeviceContext, T, 1> trans1;
trans1(dev_ctx, in, out, axis);
......
......@@ -3713,7 +3713,7 @@ void TileInferMeta(const MetaTensor& x,
repeat_times_data.size()));
PADDLE_ENFORCE_GE(
repeat_times_data.size(),
1,
0,
errors::InvalidArgument(
"The size of the shape of input 'repeat_times' for tile op "
"must be positive integers, but the value received is %d.",
......@@ -3746,7 +3746,7 @@ void TileInferMeta(const MetaTensor& x,
}
out->set_dims(phi::make_ddim(out_shape));
if (out_shape[0] == x_dims[0]) {
if (out_rank > 0 && (out_shape[0] == x_dims[0])) {
out->share_lod(x);
}
out->set_dtype(x.dtype());
......
......@@ -35,6 +35,9 @@ void TransposeKernel(const Context& ctx,
}
int rank = axis.size();
switch (rank) {
case 0:
phi::Copy<Context>(ctx, x, ctx.GetPlace(), false, out);
break;
case 1:
funcs::Transpose<Context, T, 1> trans1;
trans1(ctx, x, out, axis);
......
......@@ -35,6 +35,10 @@ void TransposeKernel(const Context& ctx,
if (out->numel() == 0) {
return;
}
if (axis.size() == 0) {
phi::Copy<Context>(ctx, x, ctx.GetPlace(), false, out);
return;
}
paddle::operators::TransposeGPUKernelDriver<T>(ctx, x, axis, out);
}
......
......@@ -54,6 +54,10 @@ void Tile(const Context& dev_ctx,
vec_x_dims.size(),
repeat_times.size()));
if (Rank == 0) {
phi::Copy<DeviceContext>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
return;
}
Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
for (size_t i = 0; i < repeat_times.size(); ++i) {
bcast_dims[i] = repeat_times[i];
......@@ -71,6 +75,7 @@ void Tile(const Context& dev_ctx,
auto eigen_out = EigenTensor<T, Rank>::From(*out, out_dims);
auto& place = *dev_ctx.eigen_device();
// use 32-bit index to speed up
bool use_32bit_index = eigen_out.size() < Eigen::NumTraits<int>::highest();
if (use_32bit_index) {
......@@ -93,6 +98,9 @@ void TileKernel(const Context& dev_ctx,
rank = std::max(rank, repeat_times_size);
switch (rank) {
case 0:
Tile<Context, T, 0>(dev_ctx, x, repeat_times_data, out);
break;
case 1:
Tile<Context, T, 1>(dev_ctx, x, repeat_times_data, out);
break;
......
......@@ -78,5 +78,15 @@ class TestCase2(TestBase):
self.attrs = {"perm": [4, 0, 2, 3, 1]}
class TestCase_ZeroDim(TestBase):
def set_data_feed(self):
data = np.random.uniform(size=[])
self.feed_fp32 = {"x": data.astype(np.float32)}
def set_op_attrs(self):
self.attrs = {"perm": []}
if __name__ == "__main__":
unittest.main()
......@@ -54,6 +54,12 @@ class TestTransposeOp(OpTest):
self.check_grad_with_place(self.place, ['X'], 'Out')
class TestCase_ZeroDim(TestTransposeOp):
def init_shape_axis(self):
self.shape = ()
self.axis = ()
class TestCase0(TestTransposeOp):
def init_shape_axis(self):
......
......@@ -46,6 +46,30 @@ class TestReshapeOp(OpTest):
self.check_grad(["X"], "Out")
class TestReshapeOp_ZeroDim1(OpTest):
def init_data(self):
self.ori_shape = ()
self.new_shape = (1)
self.infered_shape = (1)
class TestReshapeOp_ZeroDim2(OpTest):
def init_data(self):
self.ori_shape = (1)
self.new_shape = ()
self.infered_shape = ()
class TestReshapeOp_ZeroDim3(OpTest):
def init_data(self):
self.ori_shape = ()
self.new_shape = (-1)
self.infered_shape = (1)
class TestReshapeBF16Op(OpTest):
def setUp(self):
......@@ -526,6 +550,58 @@ class TestReshapeZeroTensor(unittest.TestCase):
zero_tensor.reshape([2, 3])
class TestReshapeAPI_ZeroDim(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
x = paddle.rand([])
x.stop_gradient = False
out = paddle.reshape(x, [1])
out.backward()
self.assertEqual(out.shape, [1])
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [1])
out = paddle.reshape(x, [-1, 1])
out.backward()
self.assertEqual(out.shape, [1, 1])
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [1, 1])
paddle.enable_static()
def test_static(self):
main_prog = fluid.Program()
with fluid.program_guard(main_prog, fluid.Program()):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.reshape(x, [-1])
fluid.backward.append_backward(out)
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(fluid.framework.grad_var_name(x.name))
out_grad = block.var(fluid.framework.grad_var_name(out.name))
# Test compile shape
self.assertEqual(x.shape, ())
self.assertEqual(out.shape, (1, ))
self.assertEqual(x_grad.shape, ())
self.assertEqual(out_grad.shape, (1, ))
exe = fluid.Executor()
result = exe.run(main_prog, fetch_list=[x, out, x_grad, out_grad])
# Test runtime shape
self.assertEqual(result[0].shape, ())
self.assertEqual(result[1].shape, (1, ))
self.assertEqual(result[2].shape, ())
self.assertEqual(result[3].shape, (1, ))
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -19,6 +19,8 @@ import paddle.fluid as fluid
from op_test import OpTest, convert_float_to_uint16
from paddle.fluid.framework import Program, program_guard
paddle.enable_static()
class TestStackOpBase(OpTest):
......@@ -99,6 +101,12 @@ class TestStackOp6(TestStackOpBase):
self.axis = 3
class TestStackOp_ZeroDim(TestStackOpBase):
def initParameters(self):
self.input_dim = ()
class TestStackBF16Op(OpTest):
def initDefaultParameters(self):
......@@ -293,5 +301,26 @@ class TestStackOpWithNegativeShape(unittest.TestCase):
rtol=1e-05)
class TestStackAPI_ZeroDim(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
x1 = paddle.rand([])
x2 = paddle.rand([])
x1.stop_gradient = False
x2.stop_gradient = False
out = paddle.stack([x1, x2])
out.backward()
self.assertEqual(out.shape, [2])
self.assertEqual(x1.grad.shape, [])
self.assertEqual(x2.grad.shape, [])
self.assertEqual(out.grad.shape, [2])
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
......@@ -46,6 +46,27 @@ class TestTileOpRank1(OpTest):
self.check_grad(['X'], 'Out')
class TestTileOpRank_ZeroDim1(TestTileOpRank1):
def init_data(self):
self.ori_shape = []
self.repeat_times = []
class TestTileOpRank_ZeroDim2(TestTileOpRank1):
def init_data(self):
self.ori_shape = []
self.repeat_times = [2]
class TestTileOpRank_ZeroDim3(TestTileOpRank1):
def init_data(self):
self.ori_shape = []
self.repeat_times = [2, 3]
# with dimension expanding
class TestTileOpRank2Expanding(TestTileOpRank1):
......@@ -338,6 +359,36 @@ class TestTileTripleGradCheck(unittest.TestCase):
self.func(p)
class TestTileAPI_ZeroDim(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
x = paddle.rand([])
x.stop_gradient = False
out = paddle.tile(x, [])
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [])
out = paddle.tile(x, [3])
out.backward()
self.assertEqual(out.shape, [3])
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [3])
out = paddle.tile(x, [2, 3])
out.backward()
self.assertEqual(out.shape, [2, 3])
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [2, 3])
paddle.enable_static()
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -127,6 +127,13 @@ class TestCase9(TestTransposeOp):
self.axis = (6, 1, 3, 5, 0, 2, 4, 7)
class TestCase_ZeroDim(TestTransposeOp):
def initTestCase(self):
self.shape = ()
self.axis = ()
class TestAutoTuneTransposeOp(OpTest):
def setUp(self):
......@@ -601,6 +608,24 @@ class TestTransposeTripleGradCheck(unittest.TestCase):
self.func(p)
class TestTransposeAPI_ZeroDim(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
x = paddle.rand([])
x.stop_gradient = False
out = paddle.transpose(x, [])
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [])
paddle.enable_static()
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -89,6 +89,30 @@ class TestUnsqueezeOp4(TestUnsqueezeOp):
self.new_shape = (10, 1, 1, 2, 5, 1)
class TestUnsqueezeOp_ZeroDim1(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = ()
self.axes = (-1, )
self.new_shape = (1)
class TestUnsqueezeOp_ZeroDim2(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = ()
self.axes = (-1, 1)
self.new_shape = (1, 1)
class TestUnsqueezeOp_ZeroDim3(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = ()
self.axes = (0, 1, 2)
self.new_shape = (1, 1, 1)
# axes is a list(with tensor)
class TestUnsqueezeOp_AxesTensorList(OpTest):
......@@ -284,5 +308,35 @@ class TestUnsqueezeInplaceAPI(TestUnsqueezeAPI):
self.unsqueeze = paddle.unsqueeze_
class TestUnsqueezeAPI_ZeroDim(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
x = paddle.rand([])
x.stop_gradient = False
out = paddle.unsqueeze(x, [-1])
out.backward()
self.assertEqual(out.shape, [1])
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [1])
out = paddle.unsqueeze(x, [-1, 1])
out.backward()
self.assertEqual(out.shape, [1, 1])
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [1, 1])
out = paddle.unsqueeze(x, [0, 1, 2])
out.backward()
self.assertEqual(out.shape, [1, 1, 1])
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [1, 1, 1])
paddle.enable_static()
if __name__ == "__main__":
unittest.main()
......@@ -115,6 +115,30 @@ class TestUnsqueezeOp4(TestUnsqueezeOp):
self.new_shape = (10, 1, 1, 2, 5, 1)
class TestUnsqueezeOp_ZeroDim1(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = ()
self.axes = (-1, )
self.new_shape = (1)
class TestUnsqueezeOp_ZeroDim2(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = ()
self.axes = (-1, 1)
self.new_shape = (1, 1)
class TestUnsqueezeOp_ZeroDim3(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = ()
self.axes = (0, 1, 2)
self.new_shape = (1, 1, 1)
class API_TestUnsqueeze(unittest.TestCase):
def test_out(self):
......
......@@ -60,6 +60,13 @@ class TestXPUTransposeOp(XPUOpTest):
self.axis = (1, 0)
class TestCase_ZeroDim(TestXPUTransposeOp):
def initTestCase(self):
self.shape = ()
self.axis = ()
class TestCase0(TestXPUTransposeOp):
def initTestCase(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册