未验证 提交 336bc20b 编写于 作者: Z zhangyikun02 提交者: GitHub

tile op support 0D input for xpu (#53237)

上级 c7c5635e
......@@ -31,13 +31,15 @@ void TileKernel(const Context& dev_ctx,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto rank = x.dims().size();
PADDLE_ENFORCE_GE(
rank,
1,
errors::InvalidArgument(
"The rank of the input 'x' for tile op must be a positive "
"integer, but the value received is %d.",
rank));
std::vector<int64_t> repeat_times = repeat_times_arr.GetData();
int repeat_times_size = repeat_times.size();
rank = std::max(rank, repeat_times_size);
PADDLE_ENFORCE_GE(rank,
0,
errors::InvalidArgument(
"The rank of the input 'x' for tile op must be a >=0 "
"integer, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank,
MAX_RANK_SUPPORTED,
......@@ -46,14 +48,12 @@ void TileKernel(const Context& dev_ctx,
"must be less than or equal to %d, but the value received is %d.",
MAX_RANK_SUPPORTED,
rank));
std::vector<int64_t> repeat_times = repeat_times_arr.GetData();
int repeat_times_size = repeat_times.size();
PADDLE_ENFORCE_GE(
repeat_times_size,
1,
0,
errors::InvalidArgument(
"The number of elements of the input 'repeat_times' for tile "
"op must be positive, but the value received is %d.",
"op must be >=0, but the value received is %d.",
repeat_times_size));
PADDLE_ENFORCE_LE(
repeat_times_size,
......@@ -102,20 +102,15 @@ void TileKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(out);
std::vector<int64_t> temp(repeat_times.size(), 1);
if (repeat_times == temp) {
if (rank == 0 || repeat_times == temp) {
out->Resize(x.dims());
dev_ctx.template Alloc<T>(out);
if (std::is_same<T, double>::value) {
int r = xpu::copy(dev_ctx.x_context(),
reinterpret_cast<const int8_t*>(x.data<double>()),
reinterpret_cast<int8_t*>(out->data<double>()),
8 * x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
} else {
int r = xpu::copy(
dev_ctx.x_context(), x.data<T>(), out->data<T>(), x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
}
int64_t count = x.numel() * sizeof(T);
int r = xpu::copy(dev_ctx.x_context(),
reinterpret_cast<const int8_t*>(x.data<T>()),
reinterpret_cast<int8_t*>(out->data<T>()),
count);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
return;
}
......
......@@ -90,6 +90,21 @@ class XPUTestTileOpRank1(XPUOpTestWrapper):
self.ori_shape = (2, 4, 5, 7)
self.repeat_times = (3, 2, 1, 2)
# 0-D case 1: scalar input, empty repeat_times -> output stays 0-D (PR #53237).
class TestTileOpRank_ZeroDim1(TestTileOpRank1):
    def init_data(self):
        # [] denotes a rank-0 (scalar) tensor shape.
        self.ori_shape = []
        # Empty repeat_times: tile is a no-op on the scalar.
        self.repeat_times = []
# 0-D case 2: scalar input tiled by [2] -> expected 1-D output of shape [2].
class TestTileOpRank_ZeroDim2(TestTileOpRank1):
    def init_data(self):
        # Rank-0 input; repeat_times has more entries than input rank,
        # so the input is broadcast up to rank 1 before tiling.
        self.ori_shape = []
        self.repeat_times = [2]
# 0-D case 3: scalar input tiled by [2, 3] -> expected 2-D output of shape [2, 3].
class TestTileOpRank_ZeroDim3(TestTileOpRank1):
    def init_data(self):
        # Rank-0 input broadcast up to rank 2 by the two repeat factors.
        self.ori_shape = []
        self.repeat_times = [2, 3]
# Situation 2: repeat_times is a list (with tensor)
class XPUTestTileOpRank1_tensor_attr(XPUOpTestWrapper):
......@@ -209,5 +224,36 @@ class TestTileAPI(unittest.TestCase):
assert np.array_equal(out_3.numpy(), np.tile(np_x, (2, 3)))
class TestTileAPI_ZeroDim(unittest.TestCase):
    """Dygraph checks for paddle.tile on a 0-D (scalar) tensor.

    For every repeat_times the gradient w.r.t. the scalar input must stay
    0-D, while the output (and its gradient) takes the tiled shape.
    """

    def test_dygraph(self):
        paddle.disable_static()

        x = paddle.rand([])
        x.stop_gradient = False

        # (repeat_times, expected output shape) — gradients accumulate into
        # x.grad across iterations, exactly as the three sequential calls did.
        for repeat_times, expect_shape in (([], []), ([3], [3]), ([2, 3], [2, 3])):
            out = paddle.tile(x, repeat_times)
            out.retain_grads()
            out.backward()
            self.assertEqual(out.shape, expect_shape)
            self.assertEqual(x.grad.shape, [])
            self.assertEqual(out.grad.shape, expect_shape)

        paddle.enable_static()
# Run all test cases in this module when executed directly.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册