未验证 提交 e073313d 编写于 作者: Y Yuanle Liu 提交者: GitHub

Fix arange gpu kernel (#49273)

上级 2259ced1
...@@ -57,10 +57,11 @@ void ArangeKernel(const Context& dev_ctx, ...@@ -57,10 +57,11 @@ void ArangeKernel(const Context& dev_ctx,
T* out_data = dev_ctx.template Alloc<T>(out); T* out_data = dev_ctx.template Alloc<T>(out);
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
int block = std::min(size, static_cast<int64_t>(256)); int64_t block = std::min(size, static_cast<int64_t>(256));
PADDLE_ENFORCE_NE( if (block == 0) {
block, 0, errors::OutOfRange("The value of block cannot be 0.")); return;
int grid = (size + block - 1) / block; }
int64_t grid = (size + block - 1) / block;
Range<T><<<grid, block, 0, stream>>>(start_value, step_value, size, out_data); Range<T><<<grid, block, 0, stream>>>(start_value, step_value, size, out_data);
} }
......
...@@ -70,6 +70,12 @@ class TestInt64ArangeOp(TestArangeOp): ...@@ -70,6 +70,12 @@ class TestInt64ArangeOp(TestArangeOp):
self.case = (-1, -10, -2) self.case = (-1, -10, -2)
class TestZeroSizeArangeOp(TestArangeOp):
def init_config(self):
self.dtype = np.int32
self.case = (0, 0, 1)
class TestArangeOpError(unittest.TestCase): class TestArangeOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册