未验证 提交 0a862fd3 编写于 作者: W wangchaochaohu 提交者: GitHub

refine the precision of linspace Op using half way (#27452)

上级 fda54c02
...@@ -23,9 +23,16 @@ namespace operators { ...@@ -23,9 +23,16 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
// Fills out[0..size) with evenly spaced values from `start` to `stop`.
//
// Precision refinement (the point of this change): instead of computing every
// element as start + step * i, the first half of the range is walked forward
// from `start` and the second half backward from `stop`
// (stop - step * (size - index - 1)). This halves the worst-case number of
// accumulated step multiples in the rounding error and makes the last element
// exactly `stop` (index == size - 1 yields stop - step * 0).
//
// `step` is double regardless of T so low-precision output types still use a
// high-precision increment. Uses a grid-stride loop, so any <<<grid, block>>>
// configuration covers all `size` elements.
template <typename T>
__global__ void LinspaceKernel(T start, T stop, double step, int64_t size,
                               T* out) {
  int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
  for (; index < size; index += blockDim.x * gridDim.x) {
    if (index < size / 2) {
      // Front half: accumulate forward from start.
      out[index] = static_cast<T>(start + step * index);
    } else {
      // Back half: accumulate backward from stop for symmetric error.
      out[index] = static_cast<T>(stop - step * (size - index - 1));
    }
  }
}
...@@ -55,13 +62,15 @@ class CUDALinspaceKernel : public framework::OpKernel<T> { ...@@ -55,13 +62,15 @@ class CUDALinspaceKernel : public framework::OpKernel<T> {
framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t); framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t); framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);
framework::Tensor n; framework::Tensor n_start;
framework::TensorCopy(start_t, platform::CPUPlace(), &n); framework::Tensor n_stop;
T start = n.data<T>()[0]; framework::Tensor n_num;
framework::TensorCopy(stop_t, platform::CPUPlace(), &n); framework::TensorCopy(start_t, platform::CPUPlace(), &n_start);
T stop = n.data<T>()[0]; T start = n_start.data<T>()[0];
framework::TensorCopy(*num_t, platform::CPUPlace(), &n); framework::TensorCopy(stop_t, platform::CPUPlace(), &n_stop);
int32_t num = n.data<int32_t>()[0]; T stop = n_stop.data<T>()[0];
framework::TensorCopy(*num_t, platform::CPUPlace(), &n_num);
int64_t num = static_cast<int64_t>(n_num.data<int32_t>()[0]);
PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument( PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
"The num of linspace op should be larger " "The num of linspace op should be larger "
...@@ -72,14 +81,16 @@ class CUDALinspaceKernel : public framework::OpKernel<T> { ...@@ -72,14 +81,16 @@ class CUDALinspaceKernel : public framework::OpKernel<T> {
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
double step = 0; double step = 0;
if (num != 1) {
step = (static_cast<double>(stop - start)) / (num - 1);
}
auto stream = context.cuda_device_context().stream(); auto stream = context.cuda_device_context().stream();
int block = 512; int block = 512;
int grid = (num + block - 1) / block; int grid = (num + block - 1) / block;
LinspaceKernel<T><<<grid, block, 0, stream>>>(start, step, num, out_data); if (num != 1) {
step = (static_cast<double>(stop - start)) / (num - 1);
LinspaceKernel<T><<<grid, block, 0, stream>>>(start, stop, step, num,
out_data);
} else {
LinspaceSpecialKernel<T><<<grid, block, 0, stream>>>(start, out_data);
}
} }
}; };
......
...@@ -56,9 +56,15 @@ class CPULinspaceKernel : public framework::OpKernel<T> { ...@@ -56,9 +56,15 @@ class CPULinspaceKernel : public framework::OpKernel<T> {
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
if (num > 1) { if (num > 1) {
// step should be of double type for all types
double step = (static_cast<double>(stop - start)) / (num - 1); double step = (static_cast<double>(stop - start)) / (num - 1);
int half_num = num / 2;
for (int i = 0; i < num; ++i) { for (int i = 0; i < num; ++i) {
out_data[i] = static_cast<T>(start + step * i); if (i < half_num) {
out_data[i] = static_cast<T>(start + step * i);
} else {
out_data[i] = static_cast<T>(stop - step * (num - i - 1));
}
} }
} else { } else {
out_data[0] = static_cast<T>(start); out_data[0] = static_cast<T>(start);
......
...@@ -1424,7 +1424,7 @@ def linspace(start, stop, num, dtype=None, name=None): ...@@ -1424,7 +1424,7 @@ def linspace(start, stop, num, dtype=None, name=None):
stop(int|float|Tensor): The input :attr:`stop` is end variable of range. It is a scalar, \ stop(int|float|Tensor): The input :attr:`stop` is end variable of range. It is a scalar, \
or a Tensor of shape [1] with input data type int32, int64, float32 or float64. or a Tensor of shape [1] with input data type int32, int64, float32 or float64.
num(int|Tensor): The input :attr:`num` is given num of the sequence. It is an int scalar, \ num(int|Tensor): The input :attr:`num` is given num of the sequence. It is an int scalar, \
or a Tensor of shape [1] with data type int32 or int64. or a Tensor of shape [1] with data type int32.
dtype(np.dtype|str, optional): The data type of output tensor, it could be dtype(np.dtype|str, optional): The data type of output tensor, it could be
int32, int64, float32 and float64. Default: if None, the data type is float32. int32, int64, float32 and float64. Default: if None, the data type is float32.
name(str, optional): Normally there is no need for user to set this property. name(str, optional): Normally there is no need for user to set this property.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册