diff --git a/paddle/fluid/operators/linspace_op.cc b/paddle/fluid/operators/linspace_op.cc index 1cd59672f97fc36203ee01c17a8c81ea82e0ab12..e9375be1706eb462f4bbc12fafc034e9e5cdd68b 100644 --- a/paddle/fluid/operators/linspace_op.cc +++ b/paddle/fluid/operators/linspace_op.cc @@ -38,8 +38,11 @@ class LinspaceOp : public framework::OperatorWithKernel { framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const framework::Tensor &tensor, const framework::OpKernelType &expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place(), tensor.layout()); + if (platform::is_xpu_place(tensor.place())) { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } + return expected_kernel_type; } }; diff --git a/paddle/phi/kernels/gpu/linspace_kernel.cu b/paddle/phi/kernels/gpu/linspace_kernel.cu index 3a6ff365c11db8fa4940cacb5fc75c5ebe50ebbb..66a3f833d276a9c1644f513f92af809140aa48a5 100644 --- a/paddle/phi/kernels/gpu/linspace_kernel.cu +++ b/paddle/phi/kernels/gpu/linspace_kernel.cu @@ -18,7 +18,6 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/copy_kernel.h" -#include "paddle/phi/kernels/funcs/data_type_transform.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace phi { @@ -42,6 +41,47 @@ __global__ void LinspaceSpecialKernel(T start, T* out) { out[0] = static_cast(start); } +template +T GetValue(const Context& ctx, const DenseTensor& x) { + T value = static_cast(0); + if (x.place() != CPUPlace()) { + DenseTensor cpu_x; + Copy(ctx, x, CPUPlace(), true, &cpu_x); + value = cpu_x.data()[0]; + } else { + value = x.data()[0]; + } + return value; +} + +template +T GetValueOfExpectedType(const Context& ctx, const DenseTensor& x) { + switch (x.dtype()) { + case DataType::FLOAT32: + return static_cast(GetValue(ctx, x)); + case DataType::FLOAT64: + return static_cast(GetValue(ctx, x)); + case DataType::INT32: + return static_cast(GetValue(ctx, x)); + case DataType::INT64: + return static_cast(GetValue(ctx, x)); + case DataType::FLOAT16: + return static_cast(GetValue(ctx, x)); + case DataType::BFLOAT16: + return static_cast(GetValue(ctx, x)); + case DataType::BOOL: + return static_cast(GetValue(ctx, x)); + case DataType::INT16: + return static_cast(GetValue(ctx, x)); + case DataType::UINT8: + return static_cast(GetValue(ctx, x)); + default: + PADDLE_THROW(phi::errors::Unimplemented( + "Data type (%s) is not supported when casting data type.", + x.dtype())); + } +} + template void LinspaceKernel(const Context& ctx, const DenseTensor& start, @@ -49,18 +89,9 @@ void LinspaceKernel(const Context& ctx, const DenseTensor& number, DataType dtype, DenseTensor* out) { - auto start_t = phi::funcs::TransDataType(ctx, start, dtype); - auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype); - - DenseTensor n_start; - DenseTensor n_stop; - DenseTensor n_num; - phi::Copy(ctx, start_t, phi::CPUPlace(), false, &n_start); - T start_data = n_start.data()[0]; - phi::Copy(ctx, stop_t, phi::CPUPlace(), false, &n_stop); - T stop_data = n_stop.data()[0]; - phi::Copy(ctx, number, phi::CPUPlace(), false, &n_num); - int64_t num = static_cast(n_num.data()[0]); + T start_value = GetValueOfExpectedType(ctx, start); + T stop_value = GetValueOfExpectedType(ctx, stop); + int64_t num = GetValueOfExpectedType(ctx, number); PADDLE_ENFORCE_GT( num, @@ -72,16 +103,15 @@ void LinspaceKernel(const Context& ctx, out->Resize(phi::make_ddim({num})); T* out_data = ctx.template Alloc(out); - double step = 0; auto stream = ctx.stream(); - int block = 512; - int grid = (num + block - 1) / block; if (num != 1) { - step = (static_cast(stop_data - start_data)) / (num - 1); + int block = 512; + int grid = (num + block - 1) / block; + double step = (static_cast(stop_value - start_value)) / (num - 1); LinspaceKernelInner<<>>( - start_data, stop_data, step, num, out_data); + start_value, stop_value, step, num, out_data); } else { - LinspaceSpecialKernel<<>>(start_data, out_data); + LinspaceSpecialKernel<<<1, 1, 0, stream>>>(start_value, out_data); } } @@ -94,4 +124,8 @@ PD_REGISTER_KERNEL(linspace, float, int32_t, int64_t, - double) {} + double) { + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); +} diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index a5a4df6571b77f75dbf5d9c6723375bdf5a8c3d7..5163e6e5395bde1b1971538288c07c4613c85d0b 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -92,13 +92,13 @@ def linspace(start, stop, num, dtype=None, name=None): dtype = convert_np_dtype_to_dtype_(dtype) if not isinstance(start, Variable): with device_guard("cpu"): - tensor_start = fill_constant([1], dtype, start) + tensor_start = fill_constant([1], dtype, start, force_cpu=True) if not isinstance(stop, Variable): with device_guard("cpu"): - tensor_stop = fill_constant([1], dtype, stop) + tensor_stop = fill_constant([1], dtype, stop, force_cpu=True) if not isinstance(num, Variable): with device_guard("cpu"): - tensor_num = fill_constant([1], 'int32', num) + tensor_num = fill_constant([1], 'int32', num, force_cpu=True) if _non_static_mode(): return _C_ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype', dtype)