未验证 提交 34cda80b 编写于 作者: Y Yiqun Liu 提交者: GitHub

Optimize linspace to avoid GPU -> CPU copy. (#42750)

上级 5924458b
...@@ -38,8 +38,11 @@ class LinspaceOp : public framework::OperatorWithKernel { ...@@ -38,8 +38,11 @@ class LinspaceOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor, const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override { const framework::OpKernelType &expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_, if (platform::is_xpu_place(tensor.place())) {
tensor.place(), tensor.layout()); return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
return expected_kernel_type;
} }
}; };
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace phi { namespace phi {
...@@ -42,6 +41,47 @@ __global__ void LinspaceSpecialKernel(T start, T* out) { ...@@ -42,6 +41,47 @@ __global__ void LinspaceSpecialKernel(T start, T* out) {
out[0] = static_cast<T>(start); out[0] = static_cast<T>(start);
} }
template <typename T, typename Context>
T GetValue(const Context& ctx, const DenseTensor& x) {
T value = static_cast<T>(0);
if (x.place() != CPUPlace()) {
DenseTensor cpu_x;
Copy(ctx, x, CPUPlace(), true, &cpu_x);
value = cpu_x.data<T>()[0];
} else {
value = x.data<T>()[0];
}
return value;
}
template <typename T, typename Context>
T GetValueOfExpectedType(const Context& ctx, const DenseTensor& x) {
switch (x.dtype()) {
case DataType::FLOAT32:
return static_cast<T>(GetValue<float, Context>(ctx, x));
case DataType::FLOAT64:
return static_cast<T>(GetValue<double, Context>(ctx, x));
case DataType::INT32:
return static_cast<T>(GetValue<int32_t, Context>(ctx, x));
case DataType::INT64:
return static_cast<T>(GetValue<int64_t, Context>(ctx, x));
case DataType::FLOAT16:
return static_cast<T>(GetValue<phi::dtype::float16, Context>(ctx, x));
case DataType::BFLOAT16:
return static_cast<T>(GetValue<phi::dtype::bfloat16, Context>(ctx, x));
case DataType::BOOL:
return static_cast<T>(GetValue<bool, Context>(ctx, x));
case DataType::INT16:
return static_cast<T>(GetValue<int16_t, Context>(ctx, x));
case DataType::UINT8:
return static_cast<T>(GetValue<uint8_t, Context>(ctx, x));
default:
PADDLE_THROW(phi::errors::Unimplemented(
"Data type (%s) is not supported when casting data type.",
x.dtype()));
}
}
template <typename T, typename Context> template <typename T, typename Context>
void LinspaceKernel(const Context& ctx, void LinspaceKernel(const Context& ctx,
const DenseTensor& start, const DenseTensor& start,
...@@ -49,18 +89,9 @@ void LinspaceKernel(const Context& ctx, ...@@ -49,18 +89,9 @@ void LinspaceKernel(const Context& ctx,
const DenseTensor& number, const DenseTensor& number,
DataType dtype, DataType dtype,
DenseTensor* out) { DenseTensor* out) {
auto start_t = phi::funcs::TransDataType(ctx, start, dtype); T start_value = GetValueOfExpectedType<T, Context>(ctx, start);
auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype); T stop_value = GetValueOfExpectedType<T, Context>(ctx, stop);
int64_t num = GetValueOfExpectedType<int64_t, Context>(ctx, number);
DenseTensor n_start;
DenseTensor n_stop;
DenseTensor n_num;
phi::Copy(ctx, start_t, phi::CPUPlace(), false, &n_start);
T start_data = n_start.data<T>()[0];
phi::Copy(ctx, stop_t, phi::CPUPlace(), false, &n_stop);
T stop_data = n_stop.data<T>()[0];
phi::Copy(ctx, number, phi::CPUPlace(), false, &n_num);
int64_t num = static_cast<int64_t>(n_num.data<int32_t>()[0]);
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
num, num,
...@@ -72,16 +103,15 @@ void LinspaceKernel(const Context& ctx, ...@@ -72,16 +103,15 @@ void LinspaceKernel(const Context& ctx,
out->Resize(phi::make_ddim({num})); out->Resize(phi::make_ddim({num}));
T* out_data = ctx.template Alloc<T>(out); T* out_data = ctx.template Alloc<T>(out);
double step = 0;
auto stream = ctx.stream(); auto stream = ctx.stream();
int block = 512;
int grid = (num + block - 1) / block;
if (num != 1) { if (num != 1) {
step = (static_cast<double>(stop_data - start_data)) / (num - 1); int block = 512;
int grid = (num + block - 1) / block;
double step = (static_cast<double>(stop_value - start_value)) / (num - 1);
LinspaceKernelInner<T><<<grid, block, 0, stream>>>( LinspaceKernelInner<T><<<grid, block, 0, stream>>>(
start_data, stop_data, step, num, out_data); start_value, stop_value, step, num, out_data);
} else { } else {
LinspaceSpecialKernel<T><<<grid, block, 0, stream>>>(start_data, out_data); LinspaceSpecialKernel<T><<<1, 1, 0, stream>>>(start_value, out_data);
} }
} }
...@@ -94,4 +124,8 @@ PD_REGISTER_KERNEL(linspace, ...@@ -94,4 +124,8 @@ PD_REGISTER_KERNEL(linspace,
float, float,
int32_t, int32_t,
int64_t, int64_t,
double) {} double) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
}
...@@ -92,13 +92,13 @@ def linspace(start, stop, num, dtype=None, name=None): ...@@ -92,13 +92,13 @@ def linspace(start, stop, num, dtype=None, name=None):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
if not isinstance(start, Variable): if not isinstance(start, Variable):
with device_guard("cpu"): with device_guard("cpu"):
tensor_start = fill_constant([1], dtype, start) tensor_start = fill_constant([1], dtype, start, force_cpu=True)
if not isinstance(stop, Variable): if not isinstance(stop, Variable):
with device_guard("cpu"): with device_guard("cpu"):
tensor_stop = fill_constant([1], dtype, stop) tensor_stop = fill_constant([1], dtype, stop, force_cpu=True)
if not isinstance(num, Variable): if not isinstance(num, Variable):
with device_guard("cpu"): with device_guard("cpu"):
tensor_num = fill_constant([1], 'int32', num) tensor_num = fill_constant([1], 'int32', num, force_cpu=True)
if _non_static_mode(): if _non_static_mode():
return _C_ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype', return _C_ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype',
dtype) dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册