Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
34cda80b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
34cda80b
编写于
5月 16, 2022
作者:
Y
Yiqun Liu
提交者:
GitHub
5月 16, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize linspace to avoid GPU -> CPU copy. (#42750)
上级
5924458b
变更
3
显示空白变更内容
内联
并排
Showing
3 changed files
with
62 additions
and
25 deletions
+62
-25
paddle/fluid/operators/linspace_op.cc
paddle/fluid/operators/linspace_op.cc
+5
-2
paddle/phi/kernels/gpu/linspace_kernel.cu
paddle/phi/kernels/gpu/linspace_kernel.cu
+54
-20
python/paddle/tensor/creation.py
python/paddle/tensor/creation.py
+3
-3
未找到文件。
paddle/fluid/operators/linspace_op.cc
浏览文件 @
34cda80b
...
@@ -38,9 +38,12 @@ class LinspaceOp : public framework::OperatorWithKernel {
...
@@ -38,9 +38,12 @@ class LinspaceOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetKernelTypeForVar
(
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
framework
::
Tensor
&
tensor
,
const
std
::
string
&
var_name
,
const
framework
::
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
override
{
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
override
{
if
(
platform
::
is_xpu_place
(
tensor
.
place
()))
{
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
tensor
.
place
(),
tensor
.
layout
());
}
}
return
expected_kernel_type
;
}
};
};
class
LinspaceOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
LinspaceOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
paddle/phi/kernels/gpu/linspace_kernel.cu
浏览文件 @
34cda80b
...
@@ -18,7 +18,6 @@
...
@@ -18,7 +18,6 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
phi
{
namespace
phi
{
...
@@ -42,6 +41,47 @@ __global__ void LinspaceSpecialKernel(T start, T* out) {
...
@@ -42,6 +41,47 @@ __global__ void LinspaceSpecialKernel(T start, T* out) {
out
[
0
]
=
static_cast
<
T
>
(
start
);
out
[
0
]
=
static_cast
<
T
>
(
start
);
}
}
template
<
typename
T
,
typename
Context
>
T
GetValue
(
const
Context
&
ctx
,
const
DenseTensor
&
x
)
{
T
value
=
static_cast
<
T
>
(
0
);
if
(
x
.
place
()
!=
CPUPlace
())
{
DenseTensor
cpu_x
;
Copy
(
ctx
,
x
,
CPUPlace
(),
true
,
&
cpu_x
);
value
=
cpu_x
.
data
<
T
>
()[
0
];
}
else
{
value
=
x
.
data
<
T
>
()[
0
];
}
return
value
;
}
template
<
typename
T
,
typename
Context
>
T
GetValueOfExpectedType
(
const
Context
&
ctx
,
const
DenseTensor
&
x
)
{
switch
(
x
.
dtype
())
{
case
DataType
::
FLOAT32
:
return
static_cast
<
T
>
(
GetValue
<
float
,
Context
>
(
ctx
,
x
));
case
DataType
::
FLOAT64
:
return
static_cast
<
T
>
(
GetValue
<
double
,
Context
>
(
ctx
,
x
));
case
DataType
::
INT32
:
return
static_cast
<
T
>
(
GetValue
<
int32_t
,
Context
>
(
ctx
,
x
));
case
DataType
::
INT64
:
return
static_cast
<
T
>
(
GetValue
<
int64_t
,
Context
>
(
ctx
,
x
));
case
DataType
::
FLOAT16
:
return
static_cast
<
T
>
(
GetValue
<
phi
::
dtype
::
float16
,
Context
>
(
ctx
,
x
));
case
DataType
::
BFLOAT16
:
return
static_cast
<
T
>
(
GetValue
<
phi
::
dtype
::
bfloat16
,
Context
>
(
ctx
,
x
));
case
DataType
::
BOOL
:
return
static_cast
<
T
>
(
GetValue
<
bool
,
Context
>
(
ctx
,
x
));
case
DataType
::
INT16
:
return
static_cast
<
T
>
(
GetValue
<
int16_t
,
Context
>
(
ctx
,
x
));
case
DataType
::
UINT8
:
return
static_cast
<
T
>
(
GetValue
<
uint8_t
,
Context
>
(
ctx
,
x
));
default:
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
"Data type (%s) is not supported when casting data type."
,
x
.
dtype
()));
}
}
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
LinspaceKernel
(
const
Context
&
ctx
,
void
LinspaceKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
start
,
const
DenseTensor
&
start
,
...
@@ -49,18 +89,9 @@ void LinspaceKernel(const Context& ctx,
...
@@ -49,18 +89,9 @@ void LinspaceKernel(const Context& ctx,
const
DenseTensor
&
number
,
const
DenseTensor
&
number
,
DataType
dtype
,
DataType
dtype
,
DenseTensor
*
out
)
{
DenseTensor
*
out
)
{
auto
start_t
=
phi
::
funcs
::
TransDataType
(
ctx
,
start
,
dtype
);
T
start_value
=
GetValueOfExpectedType
<
T
,
Context
>
(
ctx
,
start
);
auto
stop_t
=
phi
::
funcs
::
TransDataType
(
ctx
,
stop
,
dtype
);
T
stop_value
=
GetValueOfExpectedType
<
T
,
Context
>
(
ctx
,
stop
);
int64_t
num
=
GetValueOfExpectedType
<
int64_t
,
Context
>
(
ctx
,
number
);
DenseTensor
n_start
;
DenseTensor
n_stop
;
DenseTensor
n_num
;
phi
::
Copy
(
ctx
,
start_t
,
phi
::
CPUPlace
(),
false
,
&
n_start
);
T
start_data
=
n_start
.
data
<
T
>
()[
0
];
phi
::
Copy
(
ctx
,
stop_t
,
phi
::
CPUPlace
(),
false
,
&
n_stop
);
T
stop_data
=
n_stop
.
data
<
T
>
()[
0
];
phi
::
Copy
(
ctx
,
number
,
phi
::
CPUPlace
(),
false
,
&
n_num
);
int64_t
num
=
static_cast
<
int64_t
>
(
n_num
.
data
<
int32_t
>
()[
0
]);
PADDLE_ENFORCE_GT
(
PADDLE_ENFORCE_GT
(
num
,
num
,
...
@@ -72,16 +103,15 @@ void LinspaceKernel(const Context& ctx,
...
@@ -72,16 +103,15 @@ void LinspaceKernel(const Context& ctx,
out
->
Resize
(
phi
::
make_ddim
({
num
}));
out
->
Resize
(
phi
::
make_ddim
({
num
}));
T
*
out_data
=
ctx
.
template
Alloc
<
T
>(
out
);
T
*
out_data
=
ctx
.
template
Alloc
<
T
>(
out
);
double
step
=
0
;
auto
stream
=
ctx
.
stream
();
auto
stream
=
ctx
.
stream
();
if
(
num
!=
1
)
{
int
block
=
512
;
int
block
=
512
;
int
grid
=
(
num
+
block
-
1
)
/
block
;
int
grid
=
(
num
+
block
-
1
)
/
block
;
if
(
num
!=
1
)
{
double
step
=
(
static_cast
<
double
>
(
stop_value
-
start_value
))
/
(
num
-
1
);
step
=
(
static_cast
<
double
>
(
stop_data
-
start_data
))
/
(
num
-
1
);
LinspaceKernelInner
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
LinspaceKernelInner
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
start_
data
,
stop_data
,
step
,
num
,
out_data
);
start_
value
,
stop_value
,
step
,
num
,
out_data
);
}
else
{
}
else
{
LinspaceSpecialKernel
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
start_data
,
out_data
);
LinspaceSpecialKernel
<
T
><<<
1
,
1
,
0
,
stream
>>>
(
start_value
,
out_data
);
}
}
}
}
...
@@ -94,4 +124,8 @@ PD_REGISTER_KERNEL(linspace,
...
@@ -94,4 +124,8 @@ PD_REGISTER_KERNEL(linspace,
float
,
float
,
int32_t
,
int32_t
,
int64_t
,
int64_t
,
double
)
{}
double
)
{
kernel
->
InputAt
(
0
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
1
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
}
python/paddle/tensor/creation.py
浏览文件 @
34cda80b
...
@@ -92,13 +92,13 @@ def linspace(start, stop, num, dtype=None, name=None):
...
@@ -92,13 +92,13 @@ def linspace(start, stop, num, dtype=None, name=None):
dtype
=
convert_np_dtype_to_dtype_
(
dtype
)
dtype
=
convert_np_dtype_to_dtype_
(
dtype
)
if
not
isinstance
(
start
,
Variable
):
if
not
isinstance
(
start
,
Variable
):
with
device_guard
(
"cpu"
):
with
device_guard
(
"cpu"
):
tensor_start
=
fill_constant
([
1
],
dtype
,
start
)
tensor_start
=
fill_constant
([
1
],
dtype
,
start
,
force_cpu
=
True
)
if
not
isinstance
(
stop
,
Variable
):
if
not
isinstance
(
stop
,
Variable
):
with
device_guard
(
"cpu"
):
with
device_guard
(
"cpu"
):
tensor_stop
=
fill_constant
([
1
],
dtype
,
stop
)
tensor_stop
=
fill_constant
([
1
],
dtype
,
stop
,
force_cpu
=
True
)
if
not
isinstance
(
num
,
Variable
):
if
not
isinstance
(
num
,
Variable
):
with
device_guard
(
"cpu"
):
with
device_guard
(
"cpu"
):
tensor_num
=
fill_constant
([
1
],
'int32'
,
num
)
tensor_num
=
fill_constant
([
1
],
'int32'
,
num
,
force_cpu
=
True
)
if
_non_static_mode
():
if
_non_static_mode
():
return
_C_ops
.
linspace
(
tensor_start
,
tensor_stop
,
tensor_num
,
'dtype'
,
return
_C_ops
.
linspace
(
tensor_start
,
tensor_stop
,
tensor_num
,
'dtype'
,
dtype
)
dtype
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录