Unverified commit 0a862fd3, authored by W wangchaochaohu, committed by GitHub

refine the precision of linspace Op using the half-way approach (#27452)

Parent: fda54c02
@@ -23,9 +23,16 @@ namespace operators {
 using Tensor = framework::Tensor;
 template <typename T>
-__global__ void LinspaceKernel(T start, double step, int64_t size, T* out) {
-  CUDA_KERNEL_LOOP(index, size) {
-    out[index] = static_cast<T>(start + step * index);
+__global__ void LinspaceKernel(T start, T stop, double step, int64_t size,
+                               T* out) {
+  int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
+  for (; index < size; index += blockDim.x * gridDim.x) {
+    if (index < size / 2) {
+      out[index] = static_cast<T>(start + step * index);
+    } else {
+      out[index] = static_cast<T>(stop - step * (size - index - 1));
+    }
   }
 }
@@ -55,13 +62,15 @@ class CUDALinspaceKernel : public framework::OpKernel<T> {
     framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
     framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);
-    framework::Tensor n;
-    framework::TensorCopy(start_t, platform::CPUPlace(), &n);
-    T start = n.data<T>()[0];
-    framework::TensorCopy(stop_t, platform::CPUPlace(), &n);
-    T stop = n.data<T>()[0];
-    framework::TensorCopy(*num_t, platform::CPUPlace(), &n);
-    int32_t num = n.data<int32_t>()[0];
+    framework::Tensor n_start;
+    framework::Tensor n_stop;
+    framework::Tensor n_num;
+    framework::TensorCopy(start_t, platform::CPUPlace(), &n_start);
+    T start = n_start.data<T>()[0];
+    framework::TensorCopy(stop_t, platform::CPUPlace(), &n_stop);
+    T stop = n_stop.data<T>()[0];
+    framework::TensorCopy(*num_t, platform::CPUPlace(), &n_num);
+    int64_t num = static_cast<int64_t>(n_num.data<int32_t>()[0]);
     PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
                                   "The num of linspace op should be larger "
@@ -72,14 +81,16 @@ class CUDALinspaceKernel : public framework::OpKernel<T> {
     T* out_data = out->mutable_data<T>(context.GetPlace());
     double step = 0;
-    if (num != 1) {
-      step = (static_cast<double>(stop - start)) / (num - 1);
-    }
     auto stream = context.cuda_device_context().stream();
     int block = 512;
     int grid = (num + block - 1) / block;
-    LinspaceKernel<T><<<grid, block, 0, stream>>>(start, step, num, out_data);
+    if (num != 1) {
+      step = (static_cast<double>(stop - start)) / (num - 1);
+      LinspaceKernel<T><<<grid, block, 0, stream>>>(start, stop, step, num,
+                                                    out_data);
+    } else {
+      LinspaceSpecialKernel<T><<<grid, block, 0, stream>>>(start, out_data);
+    }
   }
 };
......
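For readers skimming the diff: the rewritten CUDA kernel walks the output with a grid-stride loop and fills the first half forward from `start` and the second half backward from `stop`, so the last element reduces to `stop - step * 0` and reproduces `stop` exactly. The launcher keeps `step` in double precision and routes `num == 1` to `LinspaceSpecialKernel`, which (by analogy with the CPU branch further down) writes the single element `start`. Below is a minimal NumPy sketch of the same half-way scheme; the function name is illustrative, not Paddle code.

```python
import numpy as np

def halfway_linspace(start, stop, num, dtype=np.float32):
    # Sketch of the updated kernel logic: the first half of the output is
    # built forward from `start`, the second half backward from `stop`.
    start, stop = dtype(start), dtype(stop)
    out = np.empty(num, dtype=dtype)
    if num == 1:
        out[0] = start                             # the num == 1 special case
        return out
    step = float(stop - start) / (num - 1)         # step kept in double precision
    for i in range(num):
        if i < num // 2:
            out[i] = start + step * i              # forward from start
        else:
            out[i] = stop - step * (num - i - 1)   # backward from stop
    return out

print(halfway_linspace(0.0, 10.0, 7))
```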
@@ -56,9 +56,15 @@ class CPULinspaceKernel : public framework::OpKernel<T> {
     T* out_data = out->mutable_data<T>(context.GetPlace());
     if (num > 1) {
       // step should be of double type for all types
       double step = (static_cast<double>(stop - start)) / (num - 1);
+      int half_num = num / 2;
       for (int i = 0; i < num; ++i) {
-        out_data[i] = static_cast<T>(start + step * i);
+        if (i < half_num) {
+          out_data[i] = static_cast<T>(start + step * i);
+        } else {
+          out_data[i] = static_cast<T>(stop - step * (num - i - 1));
+        }
       }
     } else {
       out_data[0] = static_cast<T>(start);
......
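The CPU path keeps the existing comment that the step "should be of double type for all types". One way to see the motivation, sketched under the assumption of a float32 output: rounding the step itself to the output dtype injects an error that the index then amplifies, whereas keeping the step in double leaves only the final cast of each element.

```python
import numpy as np

start, stop, num = np.float32(0.0), np.float32(1.0), 7

step_f32 = (stop - start) / np.float32(num - 1)   # step rounded to float32
step_f64 = float(stop - start) / (num - 1)        # step kept as double, as both kernels do

print(step_f32)   # 0.16666667
print(step_f64)   # 0.16666666666666666
# The rounding error in the float32 step is multiplied by the index in
# `start + step * i`, so it grows toward the end of the sequence.
print(abs(float(step_f32) - step_f64) * (num - 1))
```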
@@ -1424,7 +1424,7 @@ def linspace(start, stop, num, dtype=None, name=None):
         stop(int|float|Tensor): The input :attr:`stop` is start variable of range. It is a scalar, \
             or a Tensor of shape [1] with input data type int32, int64, float32 or float64.
         num(int|Tensor): The input :attr:`num` is given num of the sequence. It is an int scalar, \
-            or a Tensor of shape [1] with data type int32 or int64.
+            or a Tensor of shape [1] with data type int32.
         dtype(np.dtype|str, optional): The data type of output tensor, it could be
             int32, int64, float32 and float64. Default: if None, the data type is float32.
         name(str, optional): Normally there is no need for user to set this property.
......
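The docstring change above narrows the accepted dtype of a Tensor-valued `num` to int32. A brief usage sketch of the documented API, assuming the 2.0-style imperative mode; the expected values follow from the linspace definition.

```python
import paddle

# Five evenly spaced float32 values from 0 to 10, both endpoints included.
out = paddle.linspace(0, 10, 5, dtype='float32')
print(out.numpy())  # values: 0, 2.5, 5, 7.5, 10

# If `num` is given as a Tensor, it must be int32 per the updated docstring.
num = paddle.full(shape=[1], fill_value=5, dtype='int32')
out = paddle.linspace(0, 10, num, dtype='float32')
```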