Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
40b30f50
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
40b30f50
编写于
3月 30, 2023
作者:
R
Roc
提交者:
GitHub
3月 30, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[AMP OP&Test] add fp16 test for linspace (#52161)
上级
73544322
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
103 addition
and
33 deletion
+103
-33
paddle/phi/kernels/gpu/linspace_kernel.cu
paddle/phi/kernels/gpu/linspace_kernel.cu
+6
-3
python/paddle/fluid/tests/unittests/test_linspace.py
python/paddle/fluid/tests/unittests/test_linspace.py
+91
-27
python/paddle/tensor/creation.py
python/paddle/tensor/creation.py
+6
-3
未找到文件。
paddle/phi/kernels/gpu/linspace_kernel.cu
浏览文件 @
40b30f50
...
...
@@ -29,9 +29,10 @@ __global__ void LinspaceKernelInner(
for
(;
index
<
size
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
if
(
index
<
size
/
2
)
{
out
[
index
]
=
static_cast
<
T
>
(
sta
rt
+
step
*
index
);
out
[
index
]
=
static_cast
<
T
>
(
sta
tic_cast
<
double
>
(
start
)
+
step
*
index
);
}
else
{
out
[
index
]
=
static_cast
<
T
>
(
stop
-
step
*
(
size
-
index
-
1
));
out
[
index
]
=
static_cast
<
T
>
(
static_cast
<
double
>
(
stop
)
-
step
*
(
size
-
index
-
1
));
}
}
}
...
...
@@ -111,7 +112,9 @@ PD_REGISTER_KERNEL(linspace,
float
,
int32_t
,
int64_t
,
double
)
{
double
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{
kernel
->
InputAt
(
0
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
1
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
...
...
python/paddle/fluid/tests/unittests/test_linspace.py
浏览文件 @
40b30f50
...
...
@@ -15,7 +15,7 @@
import
unittest
import
numpy
as
np
from
eager_op_test
import
OpTest
,
paddle_static_guard
from
eager_op_test
import
OpTest
,
convert_float_to_uint16
,
paddle_static_guard
import
paddle
from
paddle
import
fluid
...
...
@@ -26,56 +26,120 @@ class TestLinspaceOpCommonCase(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"linspace"
self
.
python_api
=
paddle
.
linspace
dtype
=
'float32'
self
.
_set_dtype
()
self
.
_set_data
()
self
.
attrs
=
{
'dtype'
:
self
.
attr_dtype
}
def
_set_dtype
(
self
):
self
.
dtype
=
"float32"
self
.
attr_dtype
=
int
(
core
.
VarDesc
.
VarType
.
FP32
)
def
_set_data
(
self
):
self
.
outputs
=
{
'Out'
:
np
.
arange
(
0
,
11
).
astype
(
self
.
dtype
)}
self
.
inputs
=
{
'Start'
:
np
.
array
([
0
]).
astype
(
dtype
),
'Stop'
:
np
.
array
([
10
]).
astype
(
dtype
),
'Start'
:
np
.
array
([
0
]).
astype
(
self
.
dtype
),
'Stop'
:
np
.
array
([
10
]).
astype
(
self
.
dtype
),
'Num'
:
np
.
array
([
11
]).
astype
(
'int32'
),
}
self
.
attrs
=
{
'dtype'
:
int
(
core
.
VarDesc
.
VarType
.
FP32
)}
self
.
outputs
=
{
'Out'
:
np
.
arange
(
0
,
11
).
astype
(
dtype
)}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestLinspaceOpReverseCase
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"linspace"
self
.
python_api
=
paddle
.
linspace
dtype
=
'float32'
class
TestLinspaceOpReverseCase
(
TestLinspaceOpCommonCase
):
def
_set_data
(
self
):
self
.
inputs
=
{
'Start'
:
np
.
array
([
10
]).
astype
(
dtype
),
'Stop'
:
np
.
array
([
0
]).
astype
(
dtype
),
'Start'
:
np
.
array
([
10
]).
astype
(
self
.
dtype
),
'Stop'
:
np
.
array
([
0
]).
astype
(
self
.
dtype
),
'Num'
:
np
.
array
([
11
]).
astype
(
'int32'
),
}
self
.
attrs
=
{
'dtype'
:
int
(
core
.
VarDesc
.
VarType
.
FP32
)}
self
.
outputs
=
{
'Out'
:
np
.
arange
(
10
,
-
1
,
-
1
).
astype
(
dtype
)}
self
.
outputs
=
{
'Out'
:
np
.
arange
(
10
,
-
1
,
-
1
).
astype
(
self
.
dtype
)}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestLinspaceOpNumOneCase
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"linspace"
self
.
python_api
=
paddle
.
linspace
dtype
=
'float32'
class
TestLinspaceOpNumOneCase
(
TestLinspaceOpCommonCase
):
def
_set_data
(
self
):
self
.
inputs
=
{
'Start'
:
np
.
array
([
10
]).
astype
(
dtype
),
'Stop'
:
np
.
array
([
0
]).
astype
(
dtype
),
'Start'
:
np
.
array
([
10
]).
astype
(
self
.
dtype
),
'Stop'
:
np
.
array
([
0
]).
astype
(
self
.
dtype
),
'Num'
:
np
.
array
([
1
]).
astype
(
'int32'
),
}
self
.
attrs
=
{
'dtype'
:
int
(
core
.
VarDesc
.
VarType
.
FP32
)}
self
.
outputs
=
{
'Out'
:
np
.
array
(
10
,
dtype
=
dtype
)}
self
.
outputs
=
{
'Out'
:
np
.
array
(
10
,
dtype
=
self
.
dtype
)}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestLinspaceOpCommonCaseFP16
(
TestLinspaceOpCommonCase
):
def
_set_dtype
(
self
):
self
.
dtype
=
np
.
float16
self
.
attr_dtype
=
int
(
core
.
VarDesc
.
VarType
.
FP16
)
class
TestLinspaceOpReverseCaseFP16
(
TestLinspaceOpReverseCase
):
def
_set_dtype
(
self
):
self
.
dtype
=
np
.
float16
self
.
attr_dtype
=
int
(
core
.
VarDesc
.
VarType
.
FP16
)
class
TestLinspaceOpNumOneCaseFP16
(
TestLinspaceOpNumOneCase
):
def
_set_dtype
(
self
):
self
.
dtype
=
np
.
float16
self
.
attr_dtype
=
int
(
core
.
VarDesc
.
VarType
.
FP16
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
()
or
not
core
.
is_bfloat16_supported
(
core
.
CUDAPlace
(
0
)),
'not supported bf16'
,
)
class
TestLinspaceOpCommonCaseBF16
(
TestLinspaceOpCommonCaseFP16
):
def
_set_dtype
(
self
):
self
.
dtype
=
np
.
uint16
self
.
attr_dtype
=
int
(
core
.
VarDesc
.
VarType
.
BF16
)
def
_set_data
(
self
):
self
.
outputs
=
{
'Out'
:
convert_float_to_uint16
(
np
.
arange
(
0
,
11
).
astype
(
"float32"
))
}
self
.
inputs
=
{
'Start'
:
convert_float_to_uint16
(
np
.
array
([
0
]).
astype
(
"float32"
)),
'Stop'
:
convert_float_to_uint16
(
np
.
array
([
10
]).
astype
(
"float32"
)),
'Num'
:
np
.
array
([
11
]).
astype
(
'int32'
),
}
def
test_check_output
(
self
):
return
self
.
check_output_with_place
(
core
.
CUDAPlace
(
0
))
class
TestLinspaceOpReverseCaseBF16
(
TestLinspaceOpCommonCaseBF16
):
def
_set_data
(
self
):
self
.
inputs
=
{
'Start'
:
convert_float_to_uint16
(
np
.
array
([
10
]).
astype
(
"float32"
)),
'Stop'
:
convert_float_to_uint16
(
np
.
array
([
0
]).
astype
(
"float32"
)),
'Num'
:
np
.
array
([
11
]).
astype
(
'int32'
),
}
self
.
outputs
=
{
'Out'
:
convert_float_to_uint16
(
np
.
arange
(
10
,
-
1
,
-
1
).
astype
(
"float32"
)
)
}
class
TestLinspaceOpNumOneCaseBF16
(
TestLinspaceOpCommonCaseBF16
):
def
_set_data
(
self
):
self
.
inputs
=
{
'Start'
:
convert_float_to_uint16
(
np
.
array
([
10
]).
astype
(
"float32"
)),
'Stop'
:
convert_float_to_uint16
(
np
.
array
([
0
]).
astype
(
"float32"
)),
'Num'
:
np
.
array
([
1
]).
astype
(
'int32'
),
}
self
.
outputs
=
{
'Out'
:
convert_float_to_uint16
(
np
.
array
(
10
,
dtype
=
"float32"
))
}
class
TestLinspaceAPI
(
unittest
.
TestCase
):
def
test_variable_input1
(
self
):
with
paddle_static_guard
():
...
...
python/paddle/tensor/creation.py
浏览文件 @
40b30f50
...
...
@@ -332,7 +332,7 @@ def linspace(start, stop, num, dtype=None, name=None):
check_dtype
(
start
.
dtype
,
'start'
,
[
'float32'
,
'float64'
,
'int32'
,
'int64'
],
[
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'float16'
,
'bfloat16'
],
'linspace'
,
)
else
:
...
...
@@ -342,7 +342,7 @@ def linspace(start, stop, num, dtype=None, name=None):
check_dtype
(
stop
.
dtype
,
'stop'
,
[
'float32'
,
'float64'
,
'int32'
,
'int64'
],
[
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'float16'
,
'bfloat16'
],
'linspace'
,
)
else
:
...
...
@@ -350,7 +350,10 @@ def linspace(start, stop, num, dtype=None, name=None):
if
isinstance
(
num
,
Variable
):
check_dtype
(
num
.
dtype
,
'num'
,
[
'int32'
],
'linspace'
)
check_dtype
(
dtype
,
'dtype'
,
[
'int32'
,
'int64'
,
'float32'
,
'float64'
],
'linspace'
dtype
,
'dtype'
,
[
'int32'
,
'int64'
,
'float32'
,
'float64'
,
'float16'
,
'bfloat16'
],
'linspace'
,
)
if
(
(
stop_dtype
==
"float64"
or
start_dtype
==
"float64"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录