Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
041f4ab8
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
041f4ab8
编写于
9月 06, 2020
作者:
W
wangchaochaohu
提交者:
GitHub
9月 06, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine linspace Op for dtype setting(#27071)
上级
92530ca4
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
58 addition
and
12 deletion
+58
-12
paddle/fluid/operators/linspace_op.cu
paddle/fluid/operators/linspace_op.cu
+19
-4
paddle/fluid/operators/linspace_op.h
paddle/fluid/operators/linspace_op.h
+20
-2
python/paddle/fluid/layers/tensor.py
python/paddle/fluid/layers/tensor.py
+15
-2
python/paddle/fluid/tests/unittests/test_linspace.py
python/paddle/fluid/tests/unittests/test_linspace.py
+4
-4
未找到文件。
paddle/fluid/operators/linspace_op.cu
浏览文件 @
041f4ab8
...
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/linspace_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
...
...
@@ -19,6 +20,8 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
__global__
void
LinspaceKernel
(
T
start
,
double
step
,
int64_t
size
,
T
*
out
)
{
CUDA_KERNEL_LOOP
(
index
,
size
)
{
...
...
@@ -35,15 +38,27 @@ template <typename T>
class
CUDALinspaceKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
start_
t
=
context
.
Input
<
framework
::
Tensor
>
(
"Start"
);
auto
*
stop_t
=
context
.
Input
<
framework
::
Tensor
>
(
"Stop"
);
auto
*
pre_star
t
=
context
.
Input
<
framework
::
Tensor
>
(
"Start"
);
auto
*
pre_stop
=
context
.
Input
<
framework
::
Tensor
>
(
"Stop"
);
auto
*
num_t
=
context
.
Input
<
framework
::
Tensor
>
(
"Num"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
dtype
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
context
.
Attr
<
int
>
(
"dtype"
));
Tensor
start_t
;
Tensor
stop_t
;
auto
start_dtype
=
framework
::
OpKernelType
(
pre_start
->
type
(),
context
.
GetPlace
());
auto
stop_dtype
=
framework
::
OpKernelType
(
pre_stop
->
type
(),
context
.
GetPlace
());
auto
out_dtype
=
framework
::
OpKernelType
(
dtype
,
context
.
GetPlace
());
framework
::
TransDataType
(
start_dtype
,
out_dtype
,
*
pre_start
,
&
start_t
);
framework
::
TransDataType
(
stop_dtype
,
out_dtype
,
*
pre_stop
,
&
stop_t
);
framework
::
Tensor
n
;
framework
::
TensorCopy
(
*
start_t
,
platform
::
CPUPlace
(),
&
n
);
framework
::
TensorCopy
(
start_t
,
platform
::
CPUPlace
(),
&
n
);
T
start
=
n
.
data
<
T
>
()[
0
];
framework
::
TensorCopy
(
*
stop_t
,
platform
::
CPUPlace
(),
&
n
);
framework
::
TensorCopy
(
stop_t
,
platform
::
CPUPlace
(),
&
n
);
T
stop
=
n
.
data
<
T
>
()[
0
];
framework
::
TensorCopy
(
*
num_t
,
platform
::
CPUPlace
(),
&
n
);
int32_t
num
=
n
.
data
<
int32_t
>
()[
0
];
...
...
paddle/fluid/operators/linspace_op.h
浏览文件 @
041f4ab8
...
...
@@ -14,20 +14,38 @@ limitations under the License. */
#pragma once
#include <functional>
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
class
CPULinspaceKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
T
start
=
context
.
Input
<
framework
::
Tensor
>
(
"Start"
)
->
data
<
T
>
()[
0
]
;
T
stop
=
context
.
Input
<
framework
::
Tensor
>
(
"Stop"
)
->
data
<
T
>
()[
0
]
;
auto
*
pre_start
=
context
.
Input
<
framework
::
Tensor
>
(
"Start"
)
;
auto
*
pre_stop
=
context
.
Input
<
framework
::
Tensor
>
(
"Stop"
)
;
int32_t
num
=
context
.
Input
<
framework
::
Tensor
>
(
"Num"
)
->
data
<
int32_t
>
()[
0
];
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
dtype
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
context
.
Attr
<
int
>
(
"dtype"
));
Tensor
start_t
;
Tensor
stop_t
;
auto
start_dtype
=
framework
::
OpKernelType
(
pre_start
->
type
(),
context
.
GetPlace
());
auto
stop_dtype
=
framework
::
OpKernelType
(
pre_stop
->
type
(),
context
.
GetPlace
());
auto
out_dtype
=
framework
::
OpKernelType
(
dtype
,
context
.
GetPlace
());
framework
::
TransDataType
(
start_dtype
,
out_dtype
,
*
pre_start
,
&
start_t
);
framework
::
TransDataType
(
stop_dtype
,
out_dtype
,
*
pre_stop
,
&
stop_t
);
T
start
=
start_t
.
data
<
T
>
()[
0
];
T
stop
=
stop_t
.
data
<
T
>
()[
0
];
PADDLE_ENFORCE
(
num
>
0
,
"The num of linspace op should be larger than 0."
);
out
->
Resize
(
framework
::
make_ddim
({
num
}));
...
...
python/paddle/fluid/layers/tensor.py
浏览文件 @
041f4ab8
...
...
@@ -1462,19 +1462,32 @@ def linspace(start, stop, num, dtype=None, name=None):
helper
=
LayerHelper
(
"linspace"
,
**
locals
())
start_dtype
=
convert_dtype
(
tensor_start
.
dtype
)
stop_dtype
=
convert_dtype
(
tensor_stop
.
dtype
)
out_dtype
=
convert_dtype
(
dtype
)
if
isinstance
(
start
,
Variable
):
check_dtype
(
start
.
dtype
,
'start'
,
(
convert_dtype
(
dtype
)),
'linspace'
)
check_dtype
(
start
.
dtype
,
'start'
,
[
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'linspace'
)
else
:
check_type
(
start
,
'start'
,
(
int
,
float
),
'linspace'
)
if
isinstance
(
stop
,
Variable
):
check_dtype
(
stop
.
dtype
,
'stop'
,
(
convert_dtype
(
dtype
)),
'linspace'
)
check_dtype
(
stop
.
dtype
,
'stop'
,
[
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'linspace'
)
else
:
check_type
(
stop
,
'stop'
,
(
int
,
float
),
'linspace'
)
if
isinstance
(
num
,
Variable
):
check_dtype
(
num
.
dtype
,
'num'
,
[
'int32'
],
'linspace'
)
check_dtype
(
dtype
,
'dtype'
,
[
'int32'
,
'int64'
,
'float32'
,
'float64'
],
'linspace'
)
if
((
stop_dtype
==
"float64"
or
start_dtype
==
"float64"
)
and
out_dtype
in
[
"float32"
,
"int32"
])
or
((
stop_dtype
==
"int64"
or
start_dtype
==
"int64"
)
and
out_dtype
==
"int32"
):
raise
ValueError
(
"The dtype of start/stop is {}/{} but the attr(dtype) of linspace is {}, "
"which may cause data type overflows. Please reset attr(dtype) of linspace."
.
format
(
start_dtype
,
stop_dtype
,
dtype
))
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
dtype
)
...
...
python/paddle/fluid/tests/unittests/test_linspace.py
浏览文件 @
041f4ab8
...
...
@@ -154,16 +154,16 @@ class TestLinspaceOpError(unittest.TestCase):
self
.
assertRaises
(
TypeError
,
test_step_dtype
)
def
test_start_dtype
():
start
=
fluid
.
data
(
shape
=
[
1
],
dtype
=
"
int32
"
,
name
=
"start"
)
start
=
fluid
.
data
(
shape
=
[
1
],
dtype
=
"
float64
"
,
name
=
"start"
)
fluid
.
layers
.
linspace
(
start
,
10
,
1
,
dtype
=
"float32"
)
self
.
assertRaises
(
Typ
eError
,
test_start_dtype
)
self
.
assertRaises
(
Valu
eError
,
test_start_dtype
)
def
test_end_dtype
():
end
=
fluid
.
data
(
shape
=
[
1
],
dtype
=
"
int32
"
,
name
=
"end"
)
end
=
fluid
.
data
(
shape
=
[
1
],
dtype
=
"
float64
"
,
name
=
"end"
)
fluid
.
layers
.
linspace
(
0
,
end
,
1
,
dtype
=
"float32"
)
self
.
assertRaises
(
Typ
eError
,
test_end_dtype
)
self
.
assertRaises
(
Valu
eError
,
test_end_dtype
)
def
test_num_dtype
():
num
=
fluid
.
data
(
shape
=
[
1
],
dtype
=
"int32"
,
name
=
"step"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录