Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b23801a2
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b23801a2
编写于
6月 19, 2020
作者:
C
Chen Weihang
提交者:
GitHub
6月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish tensor set error messag, test=develop (#25113)
上级
542a226c
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
22 addition
and
5 deletion
+22
-5
paddle/fluid/pybind/tensor_py.h
paddle/fluid/pybind/tensor_py.h
+6
-5
python/paddle/fluid/tests/unittests/test_tensor.py
python/paddle/fluid/tests/unittests/test_tensor.py
+16
-0
未找到文件。
paddle/fluid/pybind/tensor_py.h
浏览文件 @
b23801a2
...
...
@@ -246,12 +246,13 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
}
else
if
(
py
::
isinstance
<
py
::
array_t
<
bool
>>
(
array
))
{
SetTensorFromPyArrayT
<
bool
,
P
>
(
self
,
array
,
place
,
zero_copy
);
}
else
{
// obj may be any type, obj.cast<py::array>() may be failed,
// then the array.dtype will be string of unknown meaning,
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Incompatible data type: tensor.set() supports bool, float16, "
"float32, "
"float64, "
"int8, int16, int32, int64 and uint8, uint16, but got %s!"
,
array
.
dtype
()));
"Input object type error or incompatible array data type. "
"tensor.set() supports array with bool, float16, float32, "
"float64, int8, int16, int32, int64, uint8 or uint16, "
"please check your input or input array data type."
));
}
}
...
...
python/paddle/fluid/tests/unittests/test_tensor.py
浏览文件 @
b23801a2
...
...
@@ -345,6 +345,22 @@ class TestTensor(unittest.TestCase):
self
.
assertEqual
([
2
,
200
,
300
],
tensor
.
shape
())
self
.
assertTrue
(
numpy
.
array_equal
(
numpy
.
array
(
tensor
),
list_array
))
def
test_tensor_set_error
(
self
):
scope
=
core
.
Scope
()
var
=
scope
.
var
(
"test_tensor"
)
place
=
core
.
CPUPlace
()
tensor
=
var
.
get_tensor
()
exception
=
None
try
:
error_array
=
[
"1"
,
"2"
]
tensor
.
set
(
error_array
,
place
)
except
core
.
EnforceNotMet
as
ex
:
exception
=
ex
self
.
assertIsNotNone
(
exception
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录