diff --git a/paddle/fluid/operators/linspace_op.cu b/paddle/fluid/operators/linspace_op.cu index c51e8785263b5de7a897f3865ed2dabdf93adfaa..a4f0693323297c286d24b169f1120e4017992a9b 100644 --- a/paddle/fluid/operators/linspace_op.cu +++ b/paddle/fluid/operators/linspace_op.cu @@ -23,9 +23,16 @@ namespace operators { using Tensor = framework::Tensor; template -__global__ void LinspaceKernel(T start, double step, int64_t size, T* out) { - CUDA_KERNEL_LOOP(index, size) { - out[index] = static_cast(start + step * index); +__global__ void LinspaceKernel(T start, T stop, double step, int64_t size, + T* out) { + int64_t index = blockIdx.x * blockDim.x + threadIdx.x; + + for (; index < size; index += blockDim.x * gridDim.x) { + if (index < size / 2) { + out[index] = static_cast(start + step * index); + } else { + out[index] = static_cast(stop - step * (size - index - 1)); + } } } @@ -55,13 +62,15 @@ class CUDALinspaceKernel : public framework::OpKernel { framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t); framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t); - framework::Tensor n; - framework::TensorCopy(start_t, platform::CPUPlace(), &n); - T start = n.data()[0]; - framework::TensorCopy(stop_t, platform::CPUPlace(), &n); - T stop = n.data()[0]; - framework::TensorCopy(*num_t, platform::CPUPlace(), &n); - int32_t num = n.data()[0]; + framework::Tensor n_start; + framework::Tensor n_stop; + framework::Tensor n_num; + framework::TensorCopy(start_t, platform::CPUPlace(), &n_start); + T start = n_start.data()[0]; + framework::TensorCopy(stop_t, platform::CPUPlace(), &n_stop); + T stop = n_stop.data()[0]; + framework::TensorCopy(*num_t, platform::CPUPlace(), &n_num); + int64_t num = static_cast(n_num.data()[0]); PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument( "The num of linspace op should be larger " @@ -72,14 +81,16 @@ class CUDALinspaceKernel : public framework::OpKernel { T* out_data = out->mutable_data(context.GetPlace()); double step = 0; - if (num != 1) { - step = (static_cast(stop - start)) / (num - 1); - } - auto stream = context.cuda_device_context().stream(); int block = 512; int grid = (num + block - 1) / block; - LinspaceKernel<<>>(start, step, num, out_data); + if (num != 1) { + step = (static_cast(stop - start)) / (num - 1); + LinspaceKernel<<>>(start, stop, step, num, + out_data); + } else { + LinspaceSpecialKernel<<>>(start, out_data); + } } }; diff --git a/paddle/fluid/operators/linspace_op.h b/paddle/fluid/operators/linspace_op.h index 2c30a66ef8e937127fb69a459a901164934b5b13..d8e0fefe175869171cac9c8d3798880e844dbe35 100644 --- a/paddle/fluid/operators/linspace_op.h +++ b/paddle/fluid/operators/linspace_op.h @@ -56,9 +56,15 @@ class CPULinspaceKernel : public framework::OpKernel { T* out_data = out->mutable_data(context.GetPlace()); if (num > 1) { + // step should be of double type for all types double step = (static_cast(stop - start)) / (num - 1); + int half_num = num / 2; for (int i = 0; i < num; ++i) { - out_data[i] = static_cast(start + step * i); + if (i < half_num) { + out_data[i] = static_cast(start + step * i); + } else { + out_data[i] = static_cast(stop - step * (num - i - 1)); + } } } else { out_data[0] = static_cast(start); diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 0ce7c098e2d53cbcea0f491a6a816c388d14ea4b..cf52f3b00fb2739d186021dc51d6aa0f506be706 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -1424,7 +1424,7 @@ def linspace(start, stop, num, dtype=None, name=None): stop(int|float|Tensor): The input :attr:`stop` is start variable of range. It is a scalar, \ or a Tensor of shape [1] with input data type int32, int64, float32 or float64. num(int|Tensor): The input :attr:`num` is given num of the sequence. It is an int scalar, \ - or a Tensor of shape [1] with data type int32 or int64. + or a Tensor of shape [1] with data type int32. dtype(np.dtype|str, optional): The data type of output tensor, it could be int32, int64, float32 and float64. Default: if None, the data type is float32. name(str, optional): Normally there is no need for user to set this property.