Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
041f4ab8
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
041f4ab8
编写于
9月 06, 2020
作者:
W
wangchaochaohu
提交者:
GitHub
9月 06, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine linspace Op for dtype setting(#27071)
上级
92530ca4
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
58 addition
and
12 deletion
+58
-12
paddle/fluid/operators/linspace_op.cu
paddle/fluid/operators/linspace_op.cu
+19
-4
paddle/fluid/operators/linspace_op.h
paddle/fluid/operators/linspace_op.h
+20
-2
python/paddle/fluid/layers/tensor.py
python/paddle/fluid/layers/tensor.py
+15
-2
python/paddle/fluid/tests/unittests/test_linspace.py
python/paddle/fluid/tests/unittests/test_linspace.py
+4
-4
未找到文件。
paddle/fluid/operators/linspace_op.cu
浏览文件 @
041f4ab8
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/linspace_op.h"
#include "paddle/fluid/operators/linspace_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/cuda_primitives.h"
...
@@ -19,6 +20,8 @@ limitations under the License. */
...
@@ -19,6 +20,8 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
template
<
typename
T
>
__global__
void
LinspaceKernel
(
T
start
,
double
step
,
int64_t
size
,
T
*
out
)
{
__global__
void
LinspaceKernel
(
T
start
,
double
step
,
int64_t
size
,
T
*
out
)
{
CUDA_KERNEL_LOOP
(
index
,
size
)
{
CUDA_KERNEL_LOOP
(
index
,
size
)
{
...
@@ -35,15 +38,27 @@ template <typename T>
...
@@ -35,15 +38,27 @@ template <typename T>
class
CUDALinspaceKernel
:
public
framework
::
OpKernel
<
T
>
{
class
CUDALinspaceKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
start_
t
=
context
.
Input
<
framework
::
Tensor
>
(
"Start"
);
auto
*
pre_star
t
=
context
.
Input
<
framework
::
Tensor
>
(
"Start"
);
auto
*
stop_t
=
context
.
Input
<
framework
::
Tensor
>
(
"Stop"
);
auto
*
pre_stop
=
context
.
Input
<
framework
::
Tensor
>
(
"Stop"
);
auto
*
num_t
=
context
.
Input
<
framework
::
Tensor
>
(
"Num"
);
auto
*
num_t
=
context
.
Input
<
framework
::
Tensor
>
(
"Num"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
dtype
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
context
.
Attr
<
int
>
(
"dtype"
));
Tensor
start_t
;
Tensor
stop_t
;
auto
start_dtype
=
framework
::
OpKernelType
(
pre_start
->
type
(),
context
.
GetPlace
());
auto
stop_dtype
=
framework
::
OpKernelType
(
pre_stop
->
type
(),
context
.
GetPlace
());
auto
out_dtype
=
framework
::
OpKernelType
(
dtype
,
context
.
GetPlace
());
framework
::
TransDataType
(
start_dtype
,
out_dtype
,
*
pre_start
,
&
start_t
);
framework
::
TransDataType
(
stop_dtype
,
out_dtype
,
*
pre_stop
,
&
stop_t
);
framework
::
Tensor
n
;
framework
::
Tensor
n
;
framework
::
TensorCopy
(
*
start_t
,
platform
::
CPUPlace
(),
&
n
);
framework
::
TensorCopy
(
start_t
,
platform
::
CPUPlace
(),
&
n
);
T
start
=
n
.
data
<
T
>
()[
0
];
T
start
=
n
.
data
<
T
>
()[
0
];
framework
::
TensorCopy
(
*
stop_t
,
platform
::
CPUPlace
(),
&
n
);
framework
::
TensorCopy
(
stop_t
,
platform
::
CPUPlace
(),
&
n
);
T
stop
=
n
.
data
<
T
>
()[
0
];
T
stop
=
n
.
data
<
T
>
()[
0
];
framework
::
TensorCopy
(
*
num_t
,
platform
::
CPUPlace
(),
&
n
);
framework
::
TensorCopy
(
*
num_t
,
platform
::
CPUPlace
(),
&
n
);
int32_t
num
=
n
.
data
<
int32_t
>
()[
0
];
int32_t
num
=
n
.
data
<
int32_t
>
()[
0
];
...
...
paddle/fluid/operators/linspace_op.h
浏览文件 @
041f4ab8
...
@@ -14,20 +14,38 @@ limitations under the License. */
...
@@ -14,20 +14,38 @@ limitations under the License. */
#pragma once
#pragma once
#include <functional>
#include <functional>
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
template
<
typename
T
>
class
CPULinspaceKernel
:
public
framework
::
OpKernel
<
T
>
{
class
CPULinspaceKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
T
start
=
context
.
Input
<
framework
::
Tensor
>
(
"Start"
)
->
data
<
T
>
()[
0
]
;
auto
*
pre_start
=
context
.
Input
<
framework
::
Tensor
>
(
"Start"
)
;
T
stop
=
context
.
Input
<
framework
::
Tensor
>
(
"Stop"
)
->
data
<
T
>
()[
0
]
;
auto
*
pre_stop
=
context
.
Input
<
framework
::
Tensor
>
(
"Stop"
)
;
int32_t
num
=
context
.
Input
<
framework
::
Tensor
>
(
"Num"
)
->
data
<
int32_t
>
()[
0
];
int32_t
num
=
context
.
Input
<
framework
::
Tensor
>
(
"Num"
)
->
data
<
int32_t
>
()[
0
];
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
dtype
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
context
.
Attr
<
int
>
(
"dtype"
));
Tensor
start_t
;
Tensor
stop_t
;
auto
start_dtype
=
framework
::
OpKernelType
(
pre_start
->
type
(),
context
.
GetPlace
());
auto
stop_dtype
=
framework
::
OpKernelType
(
pre_stop
->
type
(),
context
.
GetPlace
());
auto
out_dtype
=
framework
::
OpKernelType
(
dtype
,
context
.
GetPlace
());
framework
::
TransDataType
(
start_dtype
,
out_dtype
,
*
pre_start
,
&
start_t
);
framework
::
TransDataType
(
stop_dtype
,
out_dtype
,
*
pre_stop
,
&
stop_t
);
T
start
=
start_t
.
data
<
T
>
()[
0
];
T
stop
=
stop_t
.
data
<
T
>
()[
0
];
PADDLE_ENFORCE
(
num
>
0
,
"The num of linspace op should be larger than 0."
);
PADDLE_ENFORCE
(
num
>
0
,
"The num of linspace op should be larger than 0."
);
out
->
Resize
(
framework
::
make_ddim
({
num
}));
out
->
Resize
(
framework
::
make_ddim
({
num
}));
...
...
python/paddle/fluid/layers/tensor.py
浏览文件 @
041f4ab8
...
@@ -1462,19 +1462,32 @@ def linspace(start, stop, num, dtype=None, name=None):
...
@@ -1462,19 +1462,32 @@ def linspace(start, stop, num, dtype=None, name=None):
helper
=
LayerHelper
(
"linspace"
,
**
locals
())
helper
=
LayerHelper
(
"linspace"
,
**
locals
())
start_dtype
=
convert_dtype
(
tensor_start
.
dtype
)
stop_dtype
=
convert_dtype
(
tensor_stop
.
dtype
)
out_dtype
=
convert_dtype
(
dtype
)
if
isinstance
(
start
,
Variable
):
if
isinstance
(
start
,
Variable
):
check_dtype
(
start
.
dtype
,
'start'
,
(
convert_dtype
(
dtype
)),
'linspace'
)
check_dtype
(
start
.
dtype
,
'start'
,
[
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'linspace'
)
else
:
else
:
check_type
(
start
,
'start'
,
(
int
,
float
),
'linspace'
)
check_type
(
start
,
'start'
,
(
int
,
float
),
'linspace'
)
if
isinstance
(
stop
,
Variable
):
if
isinstance
(
stop
,
Variable
):
check_dtype
(
stop
.
dtype
,
'stop'
,
(
convert_dtype
(
dtype
)),
'linspace'
)
check_dtype
(
stop
.
dtype
,
'stop'
,
[
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'linspace'
)
else
:
else
:
check_type
(
stop
,
'stop'
,
(
int
,
float
),
'linspace'
)
check_type
(
stop
,
'stop'
,
(
int
,
float
),
'linspace'
)
if
isinstance
(
num
,
Variable
):
if
isinstance
(
num
,
Variable
):
check_dtype
(
num
.
dtype
,
'num'
,
[
'int32'
],
'linspace'
)
check_dtype
(
num
.
dtype
,
'num'
,
[
'int32'
],
'linspace'
)
check_dtype
(
dtype
,
'dtype'
,
[
'int32'
,
'int64'
,
'float32'
,
'float64'
],
check_dtype
(
dtype
,
'dtype'
,
[
'int32'
,
'int64'
,
'float32'
,
'float64'
],
'linspace'
)
'linspace'
)
if
((
stop_dtype
==
"float64"
or
start_dtype
==
"float64"
)
and
out_dtype
in
[
"float32"
,
"int32"
])
or
((
stop_dtype
==
"int64"
or
start_dtype
==
"int64"
)
and
out_dtype
==
"int32"
):
raise
ValueError
(
"The dtype of start/stop is {}/{} but the attr(dtype) of linspace is {}, "
"which may cause data type overflows. Please reset attr(dtype) of linspace."
.
format
(
start_dtype
,
stop_dtype
,
dtype
))
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
dtype
)
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
dtype
)
...
...
python/paddle/fluid/tests/unittests/test_linspace.py
浏览文件 @
041f4ab8
...
@@ -154,16 +154,16 @@ class TestLinspaceOpError(unittest.TestCase):
...
@@ -154,16 +154,16 @@ class TestLinspaceOpError(unittest.TestCase):
self
.
assertRaises
(
TypeError
,
test_step_dtype
)
self
.
assertRaises
(
TypeError
,
test_step_dtype
)
def
test_start_dtype
():
def
test_start_dtype
():
start
=
fluid
.
data
(
shape
=
[
1
],
dtype
=
"
int32
"
,
name
=
"start"
)
start
=
fluid
.
data
(
shape
=
[
1
],
dtype
=
"
float64
"
,
name
=
"start"
)
fluid
.
layers
.
linspace
(
start
,
10
,
1
,
dtype
=
"float32"
)
fluid
.
layers
.
linspace
(
start
,
10
,
1
,
dtype
=
"float32"
)
self
.
assertRaises
(
Typ
eError
,
test_start_dtype
)
self
.
assertRaises
(
Valu
eError
,
test_start_dtype
)
def
test_end_dtype
():
def
test_end_dtype
():
end
=
fluid
.
data
(
shape
=
[
1
],
dtype
=
"
int32
"
,
name
=
"end"
)
end
=
fluid
.
data
(
shape
=
[
1
],
dtype
=
"
float64
"
,
name
=
"end"
)
fluid
.
layers
.
linspace
(
0
,
end
,
1
,
dtype
=
"float32"
)
fluid
.
layers
.
linspace
(
0
,
end
,
1
,
dtype
=
"float32"
)
self
.
assertRaises
(
Typ
eError
,
test_end_dtype
)
self
.
assertRaises
(
Valu
eError
,
test_end_dtype
)
def
test_num_dtype
():
def
test_num_dtype
():
num
=
fluid
.
data
(
shape
=
[
1
],
dtype
=
"int32"
,
name
=
"step"
)
num
=
fluid
.
data
(
shape
=
[
1
],
dtype
=
"int32"
,
name
=
"step"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录