Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
983fcb56
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
983fcb56
编写于
4月 26, 2022
作者:
W
Weilong Wu
提交者:
GitHub
4月 26, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Eager] Support numpy.ndarry in CastNumpy2Scalar (#42136) (#42213)
上级
42297995
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
21 addition
and
2 deletion
+21
-2
paddle/fluid/pybind/eager_utils.cc
paddle/fluid/pybind/eager_utils.cc
+14
-1
python/paddle/fluid/tests/unittests/test_bfgs.py
python/paddle/fluid/tests/unittests/test_bfgs.py
+7
-1
未找到文件。
paddle/fluid/pybind/eager_utils.cc
浏览文件 @
983fcb56
...
@@ -1019,7 +1019,20 @@ paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj,
...
@@ -1019,7 +1019,20 @@ paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj,
PyTypeObject
*
type
=
obj
->
ob_type
;
PyTypeObject
*
type
=
obj
->
ob_type
;
auto
type_name
=
std
::
string
(
type
->
tp_name
);
auto
type_name
=
std
::
string
(
type
->
tp_name
);
VLOG
(
1
)
<<
"type_name: "
<<
type_name
;
VLOG
(
1
)
<<
"type_name: "
<<
type_name
;
if
(
type_name
==
"numpy.float64"
)
{
if
(
type_name
==
"numpy.ndarray"
&&
PySequence_Check
(
obj
))
{
PyObject
*
item
=
nullptr
;
item
=
PySequence_GetItem
(
obj
,
0
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
float
value
=
static_cast
<
float
>
(
PyFloat_AsDouble
(
item
));
return
paddle
::
experimental
::
Scalar
(
value
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) is numpy.ndarry, the inner elements "
"must be "
"numpy.float32/float64 now, but got %s"
,
op_type
,
arg_pos
+
1
,
type_name
));
// NOLINT
}
}
else
if
(
type_name
==
"numpy.float64"
)
{
double
value
=
CastPyArg2Double
(
obj
,
op_type
,
arg_pos
);
double
value
=
CastPyArg2Double
(
obj
,
op_type
,
arg_pos
);
return
paddle
::
experimental
::
Scalar
(
value
);
return
paddle
::
experimental
::
Scalar
(
value
);
}
else
if
(
type_name
==
"numpy.float32"
)
{
}
else
if
(
type_name
==
"numpy.float32"
)
{
...
...
python/paddle/fluid/tests/unittests/test_bfgs.py
浏览文件 @
983fcb56
...
@@ -20,6 +20,7 @@ import paddle
...
@@ -20,6 +20,7 @@ import paddle
import
paddle.nn.functional
as
F
import
paddle.nn.functional
as
F
from
paddle.incubate.optimizer.functional.bfgs
import
minimize_bfgs
from
paddle.incubate.optimizer.functional.bfgs
import
minimize_bfgs
from
paddle.fluid.framework
import
_test_eager_guard
np
.
random
.
seed
(
123
)
np
.
random
.
seed
(
123
)
...
@@ -117,7 +118,7 @@ class TestBfgs(unittest.TestCase):
...
@@ -117,7 +118,7 @@ class TestBfgs(unittest.TestCase):
results
=
test_static_graph
(
func
,
x0
,
dtype
=
'float64'
)
results
=
test_static_graph
(
func
,
x0
,
dtype
=
'float64'
)
self
.
assertTrue
(
np
.
allclose
(
0.8
,
results
[
2
]))
self
.
assertTrue
(
np
.
allclose
(
0.8
,
results
[
2
]))
def
test
_rosenbrock
(
self
):
def
func
_rosenbrock
(
self
):
# The Rosenbrock function is a standard optimization test case.
# The Rosenbrock function is a standard optimization test case.
a
=
np
.
random
.
random
(
size
=
[
1
]).
astype
(
'float32'
)
a
=
np
.
random
.
random
(
size
=
[
1
]).
astype
(
'float32'
)
minimum
=
[
a
.
item
(),
(
a
**
2
).
item
()]
minimum
=
[
a
.
item
(),
(
a
**
2
).
item
()]
...
@@ -136,6 +137,11 @@ class TestBfgs(unittest.TestCase):
...
@@ -136,6 +137,11 @@ class TestBfgs(unittest.TestCase):
results
=
test_dynamic_graph
(
func
,
x0
)
results
=
test_dynamic_graph
(
func
,
x0
)
self
.
assertTrue
(
np
.
allclose
(
minimum
,
results
[
2
]))
self
.
assertTrue
(
np
.
allclose
(
minimum
,
results
[
2
]))
def
test_rosenbrock
(
self
):
with
_test_eager_guard
():
self
.
func_rosenbrock
()
self
.
func_rosenbrock
()
def
test_exception
(
self
):
def
test_exception
(
self
):
def
func
(
x
):
def
func
(
x
):
return
paddle
.
dot
(
x
,
x
)
return
paddle
.
dot
(
x
,
x
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录