Unverified commit 336bc20b, authored by zhangyikun02, committed by GitHub

tile op support 0D input for xpu (#53237)

Parent c7c5635e
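Summary: the XPU tile kernel previously rejected rank-0 inputs and empty repeat_times; this commit relaxes both checks to >= 0, promotes the working rank to cover repeat_times, and extends the identity fast path to 0-D inputs. For orientation, the shapes the new tests expect follow numpy's np.tile semantics (a quick illustrative check, not Paddle code):

    import numpy as np

    x = np.array(1.0)                          # 0-D scalar tensor
    assert np.tile(x, []).shape == ()          # empty repeat_times: still 0-D
    assert np.tile(x, [3]).shape == (3,)       # rank promoted to len(repeat_times)
    assert np.tile(x, [2, 3]).shape == (2, 3)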
@@ -31,11 +31,13 @@ void TileKernel(const Context& dev_ctx,
                 DenseTensor* out) {
   using XPUType = typename XPUTypeTrait<T>::Type;
   auto rank = x.dims().size();
-  PADDLE_ENFORCE_GE(
-      rank,
-      1,
-      errors::InvalidArgument(
-          "The rank of the input 'x' for tile op must be a positive "
-          "integer, but the value received is %d.",
-          rank));
+  std::vector<int64_t> repeat_times = repeat_times_arr.GetData();
+  int repeat_times_size = repeat_times.size();
+  rank = std::max(rank, repeat_times_size);
+  PADDLE_ENFORCE_GE(rank,
+                    0,
+                    errors::InvalidArgument(
+                        "The rank of the input 'x' for tile op must be >= 0, "
+                        "but the value received is %d.",
+                        rank));
   PADDLE_ENFORCE_LE(
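Note: the hunk above hoists the repeat_times extraction ahead of the rank check so the check runs on the promoted rank, std::max(rank, repeat_times_size). A minimal Python sketch of the resulting output-shape rule (a hypothetical helper illustrating tile semantics, not Paddle API):

    def tiled_out_shape(in_shape, repeat_times):
        # Promote rank to max(len(in_shape), len(repeat_times)), pad the
        # shorter list with 1s on the left, then multiply element-wise.
        rank = max(len(in_shape), len(repeat_times))
        shape = [1] * (rank - len(in_shape)) + list(in_shape)
        reps = [1] * (rank - len(repeat_times)) + list(repeat_times)
        return [s * r for s, r in zip(shape, reps)]

    assert tiled_out_shape([], [2, 3]) == [2, 3]   # 0-D input, rank promoted to 2
    assert tiled_out_shape([4], []) == [4]         # empty repeat_times: no-op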
@@ -46,14 +48,12 @@ void TileKernel(const Context& dev_ctx,
           "must be less than or equal to %d, but the value received is %d.",
           MAX_RANK_SUPPORTED,
           rank));
-  std::vector<int64_t> repeat_times = repeat_times_arr.GetData();
-  int repeat_times_size = repeat_times.size();
   PADDLE_ENFORCE_GE(
       repeat_times_size,
-      1,
+      0,
       errors::InvalidArgument(
           "The number of elements of the input 'repeat_times' for tile "
-          "op must be positive, but the value received is %d.",
+          "op must be >= 0, but the value received is %d.",
           repeat_times_size));
   PADDLE_ENFORCE_LE(
       repeat_times_size,
@@ -102,20 +102,15 @@ void TileKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(out);
   std::vector<int64_t> temp(repeat_times.size(), 1);
-  if (repeat_times == temp) {
+  if (rank == 0 || repeat_times == temp) {
     out->Resize(x.dims());
     dev_ctx.template Alloc<T>(out);
-    if (std::is_same<T, double>::value) {
-      int r = xpu::copy(dev_ctx.x_context(),
-                        reinterpret_cast<const int8_t*>(x.data<double>()),
-                        reinterpret_cast<int8_t*>(out->data<double>()),
-                        8 * x.numel());
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
-    } else {
-      int r = xpu::copy(
-          dev_ctx.x_context(), x.data<T>(), out->data<T>(), x.numel());
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
-    }
+    int64_t count = x.numel() * sizeof(T);
+    int r = xpu::copy(dev_ctx.x_context(),
+                      reinterpret_cast<const int8_t*>(x.data<T>()),
+                      reinterpret_cast<int8_t*>(out->data<T>()),
+                      count);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
     return;
   }
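Note on the last hunk: the old code special-cased double (copying 8 * x.numel() bytes through int8_t pointers, since xpu::copy lacked a double overload) and used an element-typed copy otherwise; the new code folds both paths into one byte-wise copy of x.numel() * sizeof(T) bytes. It also takes the fast path when the promoted rank is 0. A rough Python sketch of the early-exit predicate (illustrative only):

    def tile_is_plain_copy(rank, repeat_times):
        # Mirrors the kernel's `if (rank == 0 || repeat_times == temp)`:
        # `rank` is already promoted to max(input rank, len(repeat_times)),
        # so rank == 0 implies a 0-D input with empty repeat_times, and
        # all-ones repeat factors leave the data unchanged either way.
        return rank == 0 or all(r == 1 for r in repeat_times)

    assert tile_is_plain_copy(0, [])
    assert tile_is_plain_copy(2, [1, 1])
    assert not tile_is_plain_copy(2, [2, 3])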
@@ -90,6 +90,21 @@ class XPUTestTileOpRank1(XPUOpTestWrapper):
             self.ori_shape = (2, 4, 5, 7)
             self.repeat_times = (3, 2, 1, 2)
+
+    class TestTileOpRank_ZeroDim1(TestTileOpRank1):
+        def init_data(self):
+            self.ori_shape = []
+            self.repeat_times = []
+
+    class TestTileOpRank_ZeroDim2(TestTileOpRank1):
+        def init_data(self):
+            self.ori_shape = []
+            self.repeat_times = [2]
+
+    class TestTileOpRank_ZeroDim3(TestTileOpRank1):
+        def init_data(self):
+            self.ori_shape = []
+            self.repeat_times = [2, 3]
 
 
 # Situation 2: repeat_times is a list (with tensor)
 class XPUTestTileOpRank1_tensor_attr(XPUOpTestWrapper):
@@ -209,5 +224,36 @@ class TestTileAPI(unittest.TestCase):
         assert np.array_equal(out_3.numpy(), np.tile(np_x, (2, 3)))
 
+
+class TestTileAPI_ZeroDim(unittest.TestCase):
+    def test_dygraph(self):
+        paddle.disable_static()
+
+        x = paddle.rand([])
+        x.stop_gradient = False
+
+        out = paddle.tile(x, [])
+        out.retain_grads()
+        out.backward()
+        self.assertEqual(out.shape, [])
+        self.assertEqual(x.grad.shape, [])
+        self.assertEqual(out.grad.shape, [])
+
+        out = paddle.tile(x, [3])
+        out.retain_grads()
+        out.backward()
+        self.assertEqual(out.shape, [3])
+        self.assertEqual(x.grad.shape, [])
+        self.assertEqual(out.grad.shape, [3])
+
+        out = paddle.tile(x, [2, 3])
+        out.retain_grads()
+        out.backward()
+        self.assertEqual(out.shape, [2, 3])
+        self.assertEqual(x.grad.shape, [])
+        self.assertEqual(out.grad.shape, [2, 3])
+
+        paddle.enable_static()
+
 
 if __name__ == "__main__":
     unittest.main()
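A note on the gradient assertions in TestTileAPI_ZeroDim: tile's backward pass sums the upstream gradient over every repeated copy, so x.grad keeps the input's 0-D shape no matter how large the tiled output is. In numpy terms (an illustrative sketch, not Paddle code):

    import numpy as np

    # Forward: tile a 0-D value to shape (2, 3).
    x = np.array(5.0)
    out = np.tile(x, [2, 3])       # shape (2, 3)

    # Backward with an all-ones upstream gradient: each of the 6 copies
    # contributes 1, so the gradient w.r.t. x is their sum, still 0-D.
    grad_out = np.ones_like(out)
    grad_x = grad_out.sum()        # shape (), value 6.0
    assert grad_x.shape == ()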