Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
3d09929b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3d09929b
编写于
11月 18, 2020
作者:
L
Leo Chen
提交者:
GitHub
11月 18, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add check for non-dispensable input (#28666)
* Add check for non-dispensable input * fix typo
上级
19226ba8
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
18 addition
and
5 deletion
+18
-5
paddle/fluid/pybind/op_function.h
paddle/fluid/pybind/op_function.h
+14
-2
paddle/fluid/pybind/op_function_generator.cc
paddle/fluid/pybind/op_function_generator.cc
+4
-3
未找到文件。
paddle/fluid/pybind/op_function.h
浏览文件 @
3d09929b
...
@@ -36,9 +36,15 @@ namespace pybind {
...
@@ -36,9 +36,15 @@ namespace pybind {
static
inline
std
::
shared_ptr
<
imperative
::
VarBase
>
CastPyHandleToVarBase
(
static
inline
std
::
shared_ptr
<
imperative
::
VarBase
>
CastPyHandleToVarBase
(
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
int
arg_idx
,
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
int
arg_idx
,
const
py
::
handle
&
handle
)
{
const
py
::
handle
&
handle
,
bool
dispensable
=
false
)
{
PyObject
*
py_obj
=
handle
.
ptr
();
// get underlying PyObject
PyObject
*
py_obj
=
handle
.
ptr
();
// get underlying PyObject
if
(
!
py_obj
||
py_obj
==
Py_None
)
{
if
(
!
py_obj
||
py_obj
==
Py_None
)
{
if
(
!
dispensable
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be Tensor, but got "
"%s"
,
op_type
,
arg_name
,
arg_idx
,
Py_TYPE
(
py_obj
)
->
tp_name
));
}
return
nullptr
;
return
nullptr
;
}
}
try
{
try
{
...
@@ -54,9 +60,15 @@ static inline std::shared_ptr<imperative::VarBase> CastPyHandleToVarBase(
...
@@ -54,9 +60,15 @@ static inline std::shared_ptr<imperative::VarBase> CastPyHandleToVarBase(
static
inline
std
::
vector
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
static
inline
std
::
vector
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
CastPyHandleToVarBaseList
(
const
std
::
string
&
op_type
,
CastPyHandleToVarBaseList
(
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
int
arg_idx
,
const
std
::
string
&
arg_name
,
int
arg_idx
,
const
py
::
handle
&
handle
)
{
const
py
::
handle
&
handle
,
bool
dispensable
=
false
)
{
PyObject
*
py_obj
=
handle
.
ptr
();
// get underlying PyObject
PyObject
*
py_obj
=
handle
.
ptr
();
// get underlying PyObject
if
(
!
py_obj
||
py_obj
==
Py_None
)
{
if
(
!
py_obj
||
py_obj
==
Py_None
)
{
if
(
!
dispensable
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be Tensor, but got "
"%s"
,
op_type
,
arg_name
,
arg_idx
,
Py_TYPE
(
py_obj
)
->
tp_name
));
}
return
{};
return
{};
}
}
std
::
vector
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
result
;
std
::
vector
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
result
;
...
...
paddle/fluid/pybind/op_function_generator.cc
浏览文件 @
3d09929b
...
@@ -166,10 +166,10 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
...
@@ -166,10 +166,10 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
const
char
*
OUT_VAR_LIST_TYPE
=
R"(std::vector<std::shared_ptr<imperative::VarBase>>)"
;
const
char
*
OUT_VAR_LIST_TYPE
=
R"(std::vector<std::shared_ptr<imperative::VarBase>>)"
;
const
char
*
CAST_VAR_TEMPLATE
=
R"(
const
char
*
CAST_VAR_TEMPLATE
=
R"(
auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s);)"
;
auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s
, %s
);)"
;
const
char
*
CAST_VAR_LIST_TEMPLATE
=
R"(
const
char
*
CAST_VAR_LIST_TEMPLATE
=
R"(
auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s);)"
;
auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s
, %s
);)"
;
const
char
*
ARG_TEMPLATE
=
R"(const %s& %s)"
;
const
char
*
ARG_TEMPLATE
=
R"(const %s& %s)"
;
...
@@ -263,9 +263,10 @@ GenerateOpFunctions(const std::string& module_name) {
...
@@ -263,9 +263,10 @@ GenerateOpFunctions(const std::string& module_name) {
input_args_num
++
;
input_args_num
++
;
const
auto
in_cast_type
=
const
auto
in_cast_type
=
input
.
duplicable
()
?
CAST_VAR_LIST_TEMPLATE
:
CAST_VAR_TEMPLATE
;
input
.
duplicable
()
?
CAST_VAR_LIST_TEMPLATE
:
CAST_VAR_TEMPLATE
;
auto
dispensable
=
input
.
dispensable
()
?
"true"
:
"false"
;
ins_cast_str
+=
ins_cast_str
+=
paddle
::
string
::
Sprintf
(
in_cast_type
,
in_name
,
op_type
,
in_name
,
paddle
::
string
::
Sprintf
(
in_cast_type
,
in_name
,
op_type
,
in_name
,
arg_idx
++
,
TempName
(
in_name
));
arg_idx
++
,
TempName
(
in_name
)
,
dispensable
);
if
(
input
.
dispensable
())
{
if
(
input
.
dispensable
())
{
const
auto
in_template
=
input
.
duplicable
()
const
auto
in_template
=
input
.
duplicable
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录