Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
672578a7
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
672578a7
编写于
8月 17, 2020
作者:
L
Leo Chen
提交者:
GitHub
8月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Print user-friendly error message in core.ops (#26261)
* print user-friendly error message * adjust error sumary
上级
d7bdc9fe
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
91 addition
and
7 deletion
+91
-7
paddle/fluid/platform/enforce.h
paddle/fluid/platform/enforce.h
+1
-1
paddle/fluid/pybind/op_function.h
paddle/fluid/pybind/op_function.h
+59
-0
paddle/fluid/pybind/op_function_generator.cc
paddle/fluid/pybind/op_function_generator.cc
+31
-6
未找到文件。
paddle/fluid/platform/enforce.h
浏览文件 @
672578a7
...
...
@@ -266,7 +266,7 @@ inline std::string GetErrorSumaryString(StrType&& what, const char* file,
std
::
ostringstream
sout
;
sout
<<
"
\n
----------------------
\n
Error Message "
"Summary:
\n
----------------------
\n
"
;
sout
<<
string
::
Sprintf
(
"%s
at (
%s:%d)"
,
std
::
forward
<
StrType
>
(
what
),
file
,
sout
<<
string
::
Sprintf
(
"%s
(at
%s:%d)"
,
std
::
forward
<
StrType
>
(
what
),
file
,
line
)
<<
std
::
endl
;
return
sout
.
str
();
...
...
paddle/fluid/pybind/op_function.h
浏览文件 @
672578a7
...
...
@@ -18,9 +18,11 @@
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/variable.h"
...
...
@@ -31,6 +33,63 @@
namespace
py
=
pybind11
;
namespace
paddle
{
namespace
pybind
{
static
inline
std
::
shared_ptr
<
imperative
::
VarBase
>
CastPyHandleToVarBase
(
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
int
arg_idx
,
const
py
::
handle
&
handle
)
{
PyObject
*
py_obj
=
handle
.
ptr
();
// get underlying PyObject
if
(
!
py_obj
||
py_obj
==
Py_None
)
{
return
nullptr
;
}
try
{
return
py
::
cast
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
(
py
::
handle
(
py_obj
));
}
catch
(
py
::
cast_error
&
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be Tensor, but got "
"%s"
,
op_type
,
arg_name
,
arg_idx
,
Py_TYPE
(
py_obj
)
->
tp_name
));
}
}
static
inline
std
::
vector
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
CastPyHandleToVarBaseList
(
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
int
arg_idx
,
const
py
::
handle
&
handle
)
{
PyObject
*
py_obj
=
handle
.
ptr
();
// get underlying PyObject
if
(
!
py_obj
||
py_obj
==
Py_None
)
{
return
{};
}
std
::
vector
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
result
;
if
(
PyList_Check
(
py_obj
)
||
PyTuple_Check
(
py_obj
))
{
auto
size
=
PyTuple_Check
(
py_obj
)
?
PyTuple_GET_SIZE
(
py_obj
)
:
PyList_GET_SIZE
(
py_obj
);
for
(
auto
i
=
0
;
i
<
size
;
++
i
)
{
PyObject
*
item
=
PyTuple_Check
(
py_obj
)
?
PyTuple_GET_ITEM
(
py_obj
,
i
)
:
PyList_GET_ITEM
(
py_obj
,
i
);
if
(
!
item
||
item
==
Py_None
)
{
result
.
emplace_back
(
nullptr
);
continue
;
}
try
{
result
.
emplace_back
(
py
::
cast
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
(
py
::
handle
(
item
)));
}
catch
(
py
::
cast_error
&
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be list of "
"Tensors, but "
"got %s in list (item %d)"
,
op_type
,
arg_name
,
arg_idx
,
Py_TYPE
(
item
)
->
tp_name
,
i
));
}
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"%s"
,
op_type
,
arg_name
,
arg_idx
,
Py_TYPE
(
py_obj
)
->
tp_name
));
}
return
result
;
}
// namespace pybind
static
inline
void
ConstructAttrMapFromPyArgs
(
framework
::
AttributeMap
*
attrs
,
const
py
::
args
&
args
)
{
PADDLE_ENFORCE_EQ
(
...
...
paddle/fluid/pybind/op_function_generator.cc
浏览文件 @
672578a7
...
...
@@ -116,8 +116,19 @@ const char* OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"(
const
char
*
ARG_OUT_NUM
=
R"(%sNum)"
;
const
char
*
ARG_OUT_NUM_TYPE
=
R"(size_t )"
;
const
char
*
VAR_TYPE
=
R"(std::shared_ptr<imperative::VarBase>)"
;
const
char
*
VAR_LIST_TYPE
=
R"(std::vector<std::shared_ptr<imperative::VarBase>>)"
;
const
char
*
IN_VAR_TYPE
=
R"(py::handle)"
;
const
char
*
IN_VAR_LIST_TYPE
=
R"(py::handle)"
;
const
char
*
OUT_VAR_TYPE
=
R"(std::shared_ptr<imperative::VarBase>)"
;
const
char
*
OUT_VAR_LIST_TYPE
=
R"(std::vector<std::shared_ptr<imperative::VarBase>>)"
;
const
char
*
CAST_VAR_TEMPLATE
=
R"(
auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s);)"
;
const
char
*
CAST_VAR_LIST_TEMPLATE
=
R"(
auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s);)"
;
const
char
*
ARG_TEMPLATE
=
R"(const %s& %s)"
;
const
char
*
RETURN_TUPLE_TYPE
=
R"(std::tuple<%s>)"
;
...
...
@@ -133,6 +144,7 @@ const char* OP_FUNCTION_TEMPLATE =
R"(
%s %s(%s)
{
%s
framework::AttributeMap attrs;
ConstructAttrMapFromPyArgs(&attrs, args);
{
...
...
@@ -164,6 +176,10 @@ static inline bool FindPassingOutsMap(const std::string& op_type,
return
op_passing_outs_map
[
op_type
].
count
(
out_name
);
}
static
inline
std
::
string
TempName
(
const
std
::
string
&
name
)
{
return
name
+
'_'
;
}
static
std
::
tuple
<
std
::
vector
<
std
::
string
>
,
std
::
vector
<
std
::
string
>>
GenerateOpFunctions
(
const
std
::
string
&
module_name
)
{
auto
&
op_info_map
=
paddle
::
framework
::
OpInfoMap
::
Instance
().
map
();
...
...
@@ -187,16 +203,24 @@ GenerateOpFunctions(const std::string& module_name) {
std
::
string
ins_initializer
=
"{"
;
std
::
string
ins_initializer_with_null
=
""
;
std
::
string
py_arg
=
""
;
int
arg_idx
=
0
;
std
::
string
ins_cast_str
=
""
;
for
(
auto
&
input
:
op_proto
->
inputs
())
{
auto
&
in_name
=
input
.
name
();
// skip those dispensable inputs, like ResidualData in conv2d
if
(
input
.
dispensable
()
&&
!
FindInsMap
(
op_type
,
in_name
))
{
continue
;
}
const
auto
in_type
=
input
.
duplicable
()
?
VAR_LIST_TYPE
:
VAR_TYPE
;
auto
input_arg
=
paddle
::
string
::
Sprintf
(
ARG_TEMPLATE
,
in_type
,
in_name
);
const
auto
in_type
=
input
.
duplicable
()
?
IN_VAR_LIST_TYPE
:
IN_VAR_TYPE
;
auto
input_arg
=
paddle
::
string
::
Sprintf
(
ARG_TEMPLATE
,
in_type
,
TempName
(
in_name
));
input_args
+=
input_arg
;
input_args
+=
","
;
const
auto
in_cast_type
=
input
.
duplicable
()
?
CAST_VAR_LIST_TEMPLATE
:
CAST_VAR_TEMPLATE
;
ins_cast_str
+=
paddle
::
string
::
Sprintf
(
in_cast_type
,
in_name
,
op_type
,
in_name
,
arg_idx
++
,
TempName
(
in_name
));
if
(
input
.
dispensable
())
{
const
auto
in_template
=
input
.
duplicable
()
...
...
@@ -235,7 +259,8 @@ GenerateOpFunctions(const std::string& module_name) {
if
(
output
.
dispensable
()
&&
!
FindOutsMap
(
op_type
,
out_name
))
{
continue
;
}
const
auto
out_type
=
output
.
duplicable
()
?
VAR_LIST_TYPE
:
VAR_TYPE
;
const
auto
out_type
=
output
.
duplicable
()
?
OUT_VAR_LIST_TYPE
:
OUT_VAR_TYPE
;
const
auto
return_template
=
output
.
duplicable
()
?
RETURN_LIST_TEMPLATE
:
RETURN_TEMPLATE
;
if
(
FindPassingOutsMap
(
op_type
,
out_name
))
{
...
...
@@ -309,7 +334,7 @@ GenerateOpFunctions(const std::string& module_name) {
// generate op funtcion body
auto
op_function_str
=
paddle
::
string
::
Sprintf
(
OP_FUNCTION_TEMPLATE
,
return_type
,
func_name
,
function_args
,
outs_initializer
,
ins_initializer
,
ins_cast_str
,
outs_initializer
,
ins_initializer
,
ins_initializer_with_null
+
outs_initializer_with_null
,
op_type
,
return_str
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录