Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
34cda80b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
34cda80b
编写于
5月 16, 2022
作者:
Y
Yiqun Liu
提交者:
GitHub
5月 16, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize linspace to avoid GPU -> CPU copy. (#42750)
上级
5924458b
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
62 addition
and
25 deletion
+62
-25
paddle/fluid/operators/linspace_op.cc
paddle/fluid/operators/linspace_op.cc
+5
-2
paddle/phi/kernels/gpu/linspace_kernel.cu
paddle/phi/kernels/gpu/linspace_kernel.cu
+54
-20
python/paddle/tensor/creation.py
python/paddle/tensor/creation.py
+3
-3
未找到文件。
paddle/fluid/operators/linspace_op.cc
浏览文件 @
34cda80b
...
...
@@ -38,8 +38,11 @@ class LinspaceOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
framework
::
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
override
{
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
if
(
platform
::
is_xpu_place
(
tensor
.
place
()))
{
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
return
expected_kernel_type
;
}
};
...
...
paddle/phi/kernels/gpu/linspace_kernel.cu
浏览文件 @
34cda80b
...
...
@@ -18,7 +18,6 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
phi
{
...
...
@@ -42,6 +41,47 @@ __global__ void LinspaceSpecialKernel(T start, T* out) {
out
[
0
]
=
static_cast
<
T
>
(
start
);
}
template
<
typename
T
,
typename
Context
>
T
GetValue
(
const
Context
&
ctx
,
const
DenseTensor
&
x
)
{
T
value
=
static_cast
<
T
>
(
0
);
if
(
x
.
place
()
!=
CPUPlace
())
{
DenseTensor
cpu_x
;
Copy
(
ctx
,
x
,
CPUPlace
(),
true
,
&
cpu_x
);
value
=
cpu_x
.
data
<
T
>
()[
0
];
}
else
{
value
=
x
.
data
<
T
>
()[
0
];
}
return
value
;
}
template
<
typename
T
,
typename
Context
>
T
GetValueOfExpectedType
(
const
Context
&
ctx
,
const
DenseTensor
&
x
)
{
switch
(
x
.
dtype
())
{
case
DataType
::
FLOAT32
:
return
static_cast
<
T
>
(
GetValue
<
float
,
Context
>
(
ctx
,
x
));
case
DataType
::
FLOAT64
:
return
static_cast
<
T
>
(
GetValue
<
double
,
Context
>
(
ctx
,
x
));
case
DataType
::
INT32
:
return
static_cast
<
T
>
(
GetValue
<
int32_t
,
Context
>
(
ctx
,
x
));
case
DataType
::
INT64
:
return
static_cast
<
T
>
(
GetValue
<
int64_t
,
Context
>
(
ctx
,
x
));
case
DataType
::
FLOAT16
:
return
static_cast
<
T
>
(
GetValue
<
phi
::
dtype
::
float16
,
Context
>
(
ctx
,
x
));
case
DataType
::
BFLOAT16
:
return
static_cast
<
T
>
(
GetValue
<
phi
::
dtype
::
bfloat16
,
Context
>
(
ctx
,
x
));
case
DataType
::
BOOL
:
return
static_cast
<
T
>
(
GetValue
<
bool
,
Context
>
(
ctx
,
x
));
case
DataType
::
INT16
:
return
static_cast
<
T
>
(
GetValue
<
int16_t
,
Context
>
(
ctx
,
x
));
case
DataType
::
UINT8
:
return
static_cast
<
T
>
(
GetValue
<
uint8_t
,
Context
>
(
ctx
,
x
));
default:
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
"Data type (%s) is not supported when casting data type."
,
x
.
dtype
()));
}
}
template
<
typename
T
,
typename
Context
>
void
LinspaceKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
start
,
...
...
@@ -49,18 +89,9 @@ void LinspaceKernel(const Context& ctx,
const
DenseTensor
&
number
,
DataType
dtype
,
DenseTensor
*
out
)
{
auto
start_t
=
phi
::
funcs
::
TransDataType
(
ctx
,
start
,
dtype
);
auto
stop_t
=
phi
::
funcs
::
TransDataType
(
ctx
,
stop
,
dtype
);
DenseTensor
n_start
;
DenseTensor
n_stop
;
DenseTensor
n_num
;
phi
::
Copy
(
ctx
,
start_t
,
phi
::
CPUPlace
(),
false
,
&
n_start
);
T
start_data
=
n_start
.
data
<
T
>
()[
0
];
phi
::
Copy
(
ctx
,
stop_t
,
phi
::
CPUPlace
(),
false
,
&
n_stop
);
T
stop_data
=
n_stop
.
data
<
T
>
()[
0
];
phi
::
Copy
(
ctx
,
number
,
phi
::
CPUPlace
(),
false
,
&
n_num
);
int64_t
num
=
static_cast
<
int64_t
>
(
n_num
.
data
<
int32_t
>
()[
0
]);
T
start_value
=
GetValueOfExpectedType
<
T
,
Context
>
(
ctx
,
start
);
T
stop_value
=
GetValueOfExpectedType
<
T
,
Context
>
(
ctx
,
stop
);
int64_t
num
=
GetValueOfExpectedType
<
int64_t
,
Context
>
(
ctx
,
number
);
PADDLE_ENFORCE_GT
(
num
,
...
...
@@ -72,16 +103,15 @@ void LinspaceKernel(const Context& ctx,
out
->
Resize
(
phi
::
make_ddim
({
num
}));
T
*
out_data
=
ctx
.
template
Alloc
<
T
>(
out
);
double
step
=
0
;
auto
stream
=
ctx
.
stream
();
int
block
=
512
;
int
grid
=
(
num
+
block
-
1
)
/
block
;
if
(
num
!=
1
)
{
step
=
(
static_cast
<
double
>
(
stop_data
-
start_data
))
/
(
num
-
1
);
int
block
=
512
;
int
grid
=
(
num
+
block
-
1
)
/
block
;
double
step
=
(
static_cast
<
double
>
(
stop_value
-
start_value
))
/
(
num
-
1
);
LinspaceKernelInner
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
start_
data
,
stop_data
,
step
,
num
,
out_data
);
start_
value
,
stop_value
,
step
,
num
,
out_data
);
}
else
{
LinspaceSpecialKernel
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
start_data
,
out_data
);
LinspaceSpecialKernel
<
T
><<<
1
,
1
,
0
,
stream
>>>
(
start_value
,
out_data
);
}
}
...
...
@@ -94,4 +124,8 @@ PD_REGISTER_KERNEL(linspace,
float
,
int32_t
,
int64_t
,
double
)
{}
double
)
{
kernel
->
InputAt
(
0
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
1
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
}
python/paddle/tensor/creation.py
浏览文件 @
34cda80b
...
...
@@ -92,13 +92,13 @@ def linspace(start, stop, num, dtype=None, name=None):
dtype
=
convert_np_dtype_to_dtype_
(
dtype
)
if
not
isinstance
(
start
,
Variable
):
with
device_guard
(
"cpu"
):
tensor_start
=
fill_constant
([
1
],
dtype
,
start
)
tensor_start
=
fill_constant
([
1
],
dtype
,
start
,
force_cpu
=
True
)
if
not
isinstance
(
stop
,
Variable
):
with
device_guard
(
"cpu"
):
tensor_stop
=
fill_constant
([
1
],
dtype
,
stop
)
tensor_stop
=
fill_constant
([
1
],
dtype
,
stop
,
force_cpu
=
True
)
if
not
isinstance
(
num
,
Variable
):
with
device_guard
(
"cpu"
):
tensor_num
=
fill_constant
([
1
],
'int32'
,
num
)
tensor_num
=
fill_constant
([
1
],
'int32'
,
num
,
force_cpu
=
True
)
if
_non_static_mode
():
return
_C_ops
.
linspace
(
tensor_start
,
tensor_stop
,
tensor_num
,
'dtype'
,
dtype
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录