未验证 提交 417e5baf 编写于 作者: C chenxujun 提交者: GitHub

Add logspace tests (#52956)

上级 cb81befa
......@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/logspace_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
......@@ -25,25 +26,34 @@ namespace phi {
template <typename T>
__global__ void LogspaceKernelInner(
T start, T stop, double step, T base, int64_t size, T* out) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType mt_start = static_cast<MPType>(start);
MPType mt_stop = static_cast<MPType>(stop);
MPType mt_base = static_cast<MPType>(base);
int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
for (; index < size; index += blockDim.x * gridDim.x) {
if (index < size / 2) {
out[index] =
static_cast<T>(pow(static_cast<double>(base),
static_cast<double>(start + step * index)));
static_cast<T>(pow(static_cast<double>(mt_base),
static_cast<double>(mt_start + step * index)));
} else {
out[index] = static_cast<T>(
pow(static_cast<double>(base),
static_cast<double>(stop - step * (size - index - 1))));
pow(static_cast<double>(mt_base),
static_cast<double>(mt_stop - step * (size - index - 1))));
}
}
}
template <typename T>
__global__ void LogspaceSpecialKernel(T start, T base, T* out) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType mt_start = static_cast<MPType>(start);
MPType mt_base = static_cast<MPType>(base);
out[0] = static_cast<T>(
pow(static_cast<double>(base), static_cast<double>(start)));
pow(static_cast<double>(mt_base), static_cast<double>(mt_start)));
}
template <typename T, typename Context>
......@@ -54,6 +64,8 @@ void LogspaceKernel(const Context& ctx,
const DenseTensor& base,
DataType dtype,
DenseTensor* out) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
auto start_t = phi::funcs::TransDataType(ctx, start, dtype);
auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype);
auto base_t = phi::funcs::TransDataType(ctx, base, dtype);
......@@ -71,6 +83,9 @@ void LogspaceKernel(const Context& ctx,
phi::Copy(ctx, base_t, phi::CPUPlace(), false, &n_base);
T base_data = n_base.data<T>()[0];
MPType mt_start_data = static_cast<MPType>(start_data);
MPType mt_stop_data = static_cast<MPType>(stop_data);
PADDLE_ENFORCE_GT(
num,
0,
......@@ -86,7 +101,7 @@ void LogspaceKernel(const Context& ctx,
int block = 512;
int grid = (num + block - 1) / block;
if (num != 1) {
step = (static_cast<double>(stop_data - start_data)) / (num - 1);
step = (static_cast<double>(mt_stop_data - mt_start_data)) / (num - 1);
LogspaceKernelInner<T><<<grid, block, 0, stream>>>(
start_data, stop_data, step, base_data, num, out_data);
} else {
......@@ -104,4 +119,6 @@ PD_REGISTER_KERNEL(logspace,
float,
int32_t,
int64_t,
double) {}
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -15,9 +15,10 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.fluid import core
class TestLogspaceOpCommonCase(OpTest):
......@@ -41,6 +42,54 @@ class TestLogspaceOpCommonCase(OpTest):
self.check_output()
class TestLogspaceFP16Op(TestLogspaceOpCommonCase):
def init_data(self):
self.dtype = np.float16
self.inputs = {
'Start': np.array([0]).astype(self.dtype),
'Stop': np.array([10]).astype(self.dtype),
'Num': np.array([11]).astype('int32'),
'Base': np.array([2]).astype(self.dtype),
}
self.attrs = {'dtype': int(paddle.float16)}
self.outputs = {'Out': np.power(2, np.arange(0, 11)).astype(self.dtype)}
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestLogspaceBF16Op(OpTest):
def setUp(self):
self.op_type = "logspace"
self.python_api = paddle.logspace
self.init_data()
def init_data(self):
self.dtype = np.uint16
self.np_dtype = np.float32
self.inputs = {
'Start': np.array([0]).astype(self.np_dtype),
'Stop': np.array([10]).astype(self.np_dtype),
'Num': np.array([11]).astype('int32'),
'Base': np.array([2]).astype(self.np_dtype),
}
self.attrs = {'dtype': int(paddle.bfloat16)}
self.outputs = {
'Out': np.power(2, np.arange(0, 11)).astype(self.np_dtype)
}
self.inputs["Start"] = convert_float_to_uint16(self.inputs["Start"])
self.inputs["Stop"] = convert_float_to_uint16(self.inputs["Stop"])
self.inputs["Base"] = convert_float_to_uint16(self.inputs["Base"])
self.outputs["Out"] = convert_float_to_uint16(self.outputs["Out"])
self.place = core.CUDAPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
class TestLogspaceOpReverseCase(TestLogspaceOpCommonCase):
def init_data(self):
dtype = 'float32'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册