Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
08e81475
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
08e81475
编写于
6月 11, 2021
作者:
W
wanghuancoder
提交者:
GitHub
6月 11, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
use PYTHON_C_API in dygraph (#32524)
* use PYTHON_C_API in dygraph, test=develop
上级
022198c5
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
958 addition
and
71 deletion
+958
-71
paddle/fluid/pybind/imperative.cc
paddle/fluid/pybind/imperative.cc
+25
-28
paddle/fluid/pybind/op_function.h
paddle/fluid/pybind/op_function.h
+856
-0
paddle/fluid/pybind/op_function_generator.cc
paddle/fluid/pybind/op_function_generator.cc
+67
-38
paddle/fluid/pybind/protobuf.cc
paddle/fluid/pybind/protobuf.cc
+9
-4
python/paddle/fluid/layers/utils.py
python/paddle/fluid/layers/utils.py
+1
-1
未找到文件。
paddle/fluid/pybind/imperative.cc
浏览文件 @
08e81475
...
...
@@ -51,6 +51,8 @@ limitations under the License. */
namespace
paddle
{
namespace
pybind
{
PyTypeObject
*
g_varbase_pytype
=
nullptr
;
namespace
py
=
::
pybind11
;
class
Layer
:
public
imperative
::
Layer
{
...
...
@@ -470,9 +472,9 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
}
template
<
typename
P
>
static
void
VarBaseCopy
(
std
::
shared_ptr
<
imperative
::
VarBase
>
&
src
,
imperative
::
VarBase
&
dst
,
const
P
&
dst_device
,
const
bool
blocking
)
{
static
void
VarBaseCopy
(
std
::
shared_ptr
<
imperative
::
VarBase
>
&
src
,
// NOLINT
imperative
::
VarBase
&
dst
,
// NOLINT
const
P
&
dst_device
,
const
bool
blocking
)
{
if
(
dst
.
SharedVar
()
->
IsEmpty
())
{
VLOG
(
3
)
<<
"deep copy Variable from "
<<
src
->
Name
()
<<
" to "
<<
dst
.
Name
();
...
...
@@ -667,9 +669,10 @@ void BindImperative(py::module *m_ptr) {
imperative
::
SetCurrentTracer
(
tracer
);
});
py
::
class_
<
imperative
::
VarBase
,
std
::
shared_ptr
<
imperative
::
VarBase
>>
(
m
,
"VarBase"
,
R"DOC()DOC"
)
.
def_static
(
"_alive_vars"
,
&
imperative
::
VarBase
::
AliveVarNames
)
py
::
class_
<
imperative
::
VarBase
,
std
::
shared_ptr
<
imperative
::
VarBase
>>
varbase
(
m
,
"VarBase"
,
R"DOC()DOC"
);
g_varbase_pytype
=
(
PyTypeObject
*
)
varbase
.
ptr
();
// NOLINT
varbase
.
def_static
(
"_alive_vars"
,
&
imperative
::
VarBase
::
AliveVarNames
)
.
def
(
"__init__"
,
[](
imperative
::
VarBase
&
self
)
{
std
::
string
name
=
...
...
@@ -1468,21 +1471,15 @@ void BindImperative(py::module *m_ptr) {
&
imperative
::
VarBase
::
SetOverridedStopGradient
)
.
def_property
(
"persistable"
,
&
imperative
::
VarBase
::
Persistable
,
&
imperative
::
VarBase
::
SetPersistable
)
.
def_property_readonly
(
"shape"
,
.
def_property_readonly
(
"shape"
,
[](
imperative
::
VarBase
&
self
)
{
if
(
self
.
Var
().
IsType
<
framework
::
LoDTensor
>
())
{
return
framework
::
vectorize
<
int
>
(
self
.
Var
()
.
Get
<
framework
::
LoDTensor
>
()
.
dims
());
}
else
if
(
self
.
Var
()
.
IsType
<
framework
::
SelectedRows
>
())
{
self
.
Var
().
Get
<
framework
::
LoDTensor
>
().
dims
());
}
else
if
(
self
.
Var
().
IsType
<
framework
::
SelectedRows
>
())
{
return
framework
::
vectorize
<
int
>
(
self
.
Var
()
.
Get
<
framework
::
SelectedRows
>
()
.
value
()
.
dims
());
self
.
Var
().
Get
<
framework
::
SelectedRows
>
().
value
().
dims
());
}
else
{
VLOG
(
2
)
<<
"It is meaningless to get shape of "
"variable type "
...
...
paddle/fluid/pybind/op_function.h
浏览文件 @
08e81475
...
...
@@ -25,6 +25,7 @@
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
...
...
@@ -34,6 +35,28 @@ namespace py = pybind11;
namespace
paddle
{
namespace
pybind
{
class
OpAttrTypeMap
{
public:
static
OpAttrTypeMap
&
Instance
()
{
static
OpAttrTypeMap
g_op_attr_type_map
;
return
g_op_attr_type_map
;
}
std
::
unordered_map
<
std
::
string
,
std
::
unordered_map
<
std
::
string
,
paddle
::
framework
::
proto
::
AttrType
>>&
Map
()
{
return
ops_attrtype_map_
;
}
private:
OpAttrTypeMap
()
=
default
;
std
::
unordered_map
<
std
::
string
,
std
::
unordered_map
<
std
::
string
,
paddle
::
framework
::
proto
::
AttrType
>>
ops_attrtype_map_
;
};
static
inline
std
::
shared_ptr
<
imperative
::
VarBase
>
CastPyHandleToVarBase
(
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
int
arg_idx
,
const
py
::
handle
&
handle
,
bool
dispensable
=
false
)
{
...
...
@@ -173,6 +196,839 @@ static inline void HandleViewBetweenInputAndOutput(
<<
"), share allocation and inplace version."
;
}
}
extern
PyTypeObject
*
g_varbase_pytype
;
extern
PyTypeObject
*
g_vartype_pytype
;
extern
PyTypeObject
*
g_blockdesc_pytype
;
inline
bool
PyObject_CheckBool
(
PyObject
**
obj
)
{
return
PyBool_Check
(
*
obj
);
}
inline
bool
PyObject_CheckLongOrToLong
(
PyObject
**
obj
)
{
if
((
PyLong_Check
(
*
obj
)
&&
!
PyBool_Check
(
*
obj
))
||
PyObject_IsInstance
(
*
obj
,
(
PyObject
*
)
g_vartype_pytype
)
||
// NOLINT
PyObject_IsInstance
(
*
obj
,
(
PyObject
*
)
g_varbase_pytype
))
{
// NOLINT
return
true
;
}
auto
to
=
PyNumber_Long
(
*
obj
);
if
(
to
)
{
*
obj
=
to
;
return
true
;
}
return
false
;
}
inline
bool
PyObject_CheckFloatOrToFloat
(
PyObject
**
obj
)
{
// sometimes users provide PyLong or numpy.int64 but attr is float
if
(
PyFloat_Check
(
*
obj
)
||
PyLong_Check
(
*
obj
)
||
PyObject_IsInstance
(
*
obj
,
(
PyObject
*
)
g_varbase_pytype
))
{
// NOLINT
return
true
;
}
auto
to
=
PyNumber_Float
(
*
obj
);
if
(
to
)
{
*
obj
=
to
;
return
true
;
}
return
false
;
}
inline
bool
PyObject_CheckString
(
PyObject
*
obj
)
{
return
PyUnicode_Check
(
obj
);
}
static
inline
void
CastPyArg2AttrBoolean
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
obj
==
Py_None
)
{
attrs
[
key
]
=
false
;
// To be compatible with QA integration testing. Some
// test case pass in None.
}
else
if
(
obj
==
Py_True
)
{
attrs
[
key
]
=
true
;
}
else
if
(
obj
==
Py_False
)
{
attrs
[
key
]
=
false
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"bool, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
static
inline
void
CastPyArg2AttrInt
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyObject_CheckLongOrToLong
(
&
obj
))
{
attrs
[
key
]
=
(
int
)
PyLong_AsLong
(
obj
);
// NOLINT
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"int, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
static
inline
void
CastPyArg2AttrLong
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyObject_CheckLongOrToLong
(
&
obj
))
{
attrs
[
key
]
=
(
int64_t
)
PyLong_AsLong
(
obj
);
// NOLINT
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"long, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
static
inline
void
CastPyArg2AttrFloat
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyObject_CheckFloatOrToFloat
(
&
obj
))
{
attrs
[
key
]
=
(
float
)
PyFloat_AsDouble
(
obj
);
// NOLINT
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"float, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
static
inline
void
CastPyArg2AttrString
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyObject_CheckString
(
obj
))
{
Py_ssize_t
size
;
const
char
*
data
;
data
=
PyUnicode_AsUTF8AndSize
(
obj
,
&
size
);
attrs
[
key
]
=
std
::
string
(
data
,
(
size_t
)
size
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"str, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
static
inline
void
CastPyArg2AttrBooleans
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
bool
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckBool
(
&
item
))
{
value
.
emplace_back
(
PyLong_AsLong
(
item
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list of bool, but got %s at pos %d"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
,
// NOLINT
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
bool
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckBool
(
&
item
))
{
value
.
emplace_back
(
PyLong_AsLong
(
item
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list of bool, but got %s at pos %d"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
,
// NOLINT
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list or tuple, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
static
inline
void
CastPyArg2AttrInts
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
value
.
emplace_back
(
PyLong_AsLong
(
item
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
,
// NOLINT
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
value
.
emplace_back
(
PyLong_AsLong
(
item
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
,
// NOLINT
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PySequence_Check
(
obj
))
{
Py_ssize_t
len
=
PySequence_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PySequence_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
value
.
emplace_back
(
PyLong_AsLong
(
item
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
,
// NOLINT
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list or tuple, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
static
inline
void
CastPyArg2AttrLongs
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int64_t
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
value
.
emplace_back
(
PyLong_AsLong
(
item
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
,
// NOLINT
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int64_t
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
value
.
emplace_back
(
PyLong_AsLong
(
item
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
,
// NOLINT
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PySequence_Check
(
obj
))
{
Py_ssize_t
len
=
PySequence_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int64_t
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PySequence_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
value
.
emplace_back
(
PyLong_AsLong
(
item
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
,
// NOLINT
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list or tuple, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
static
inline
void
CastPyArg2AttrFloats
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
float
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
value
.
emplace_back
(
PyFloat_AsDouble
(
item
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
,
// NOLINT
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
float
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
value
.
emplace_back
(
PyFloat_AsDouble
(
item
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
,
// NOLINT
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PySequence_Check
(
obj
))
{
Py_ssize_t
len
=
PySequence_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
float
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PySequence_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
value
.
emplace_back
(
PyFloat_AsDouble
(
item
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
,
// NOLINT
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list or tuple, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
static
inline
void
CastPyArg2AttrFloat64s
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
double
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
value
.
emplace_back
(
PyFloat_AsDouble
(
item
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
,
// NOLINT
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
double
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
value
.
emplace_back
(
PyFloat_AsDouble
(
item
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
,
// NOLINT
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PySequence_Check
(
obj
))
{
Py_ssize_t
len
=
PySequence_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
double
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PySequence_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
value
.
emplace_back
(
PyFloat_AsDouble
(
item
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
,
// NOLINT
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list or tuple, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
static
inline
void
CastPyArg2AttrStrings
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
std
::
string
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckString
(
item
))
{
Py_ssize_t
size
;
const
char
*
data
;
data
=
PyUnicode_AsUTF8AndSize
(
item
,
&
size
);
value
.
emplace_back
(
std
::
string
(
data
,
(
size_t
)
size
));
// NOLINT
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list of str, but got %s at pos %d"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
,
// NOLINT
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
std
::
string
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckString
(
item
))
{
Py_ssize_t
size
;
const
char
*
data
;
data
=
PyUnicode_AsUTF8AndSize
(
item
,
&
size
);
value
.
emplace_back
(
std
::
string
(
data
,
(
size_t
)
size
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list of str, but got %s at pos %d"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
,
// NOLINT
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"list or tuple, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
static
inline
void
CastPyArg2AttrBlock
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
::
pybind11
::
detail
::
instance
*
inst
=
(
::
pybind11
::
detail
::
instance
*
)
obj
;
// NOLINT
if
(
!
PyObject_IsInstance
((
PyObject
*
)
inst
,
// NOLINT
(
PyObject
*
)
g_blockdesc_pytype
))
{
// NOLINT
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"BlockDesc, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
void
**
vh
=
inst
->
simple_layout
?
inst
->
simple_value_holder
:
&
inst
->
nonsimple
.
values_and_holders
[
0
];
attrs
[
key
]
=
reinterpret_cast
<
paddle
::
framework
::
BlockDesc
*&>
(
vh
[
0
]);
}
static
inline
void
ConstructAttrMapFromPyArgs
(
const
std
::
string
&
op_type
,
PyObject
*
args
,
ssize_t
attr_start
,
ssize_t
attr_end
,
paddle
::
framework
::
AttributeMap
&
attrs
)
{
// NOLINT
PADDLE_ENFORCE_EQ
(
(
attr_end
-
attr_start
)
%
2
,
0
,
platform
::
errors
::
InvalidArgument
(
"The number of arguments for attributes should be even."
));
auto
attr_type_map
=
&
(
OpAttrTypeMap
::
Instance
().
Map
()[
op_type
]);
PyObject
*
obj
=
nullptr
;
for
(
ssize_t
arg_pos
=
attr_start
;
arg_pos
<
attr_end
;
arg_pos
+=
2
)
{
Py_ssize_t
key_len
;
const
char
*
key_ptr
;
obj
=
PyTuple_GET_ITEM
(
args
,
arg_pos
);
if
(
PyObject_CheckString
(
obj
))
{
key_ptr
=
PyUnicode_AsUTF8AndSize
(
obj
,
&
key_len
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be str, but got "
"%s"
,
op_type
,
arg_pos
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
std
::
string
key
(
key_ptr
,
(
size_t
)
key_len
);
auto
iter
=
attr_type_map
->
find
(
key
);
if
(
iter
==
attr_type_map
->
end
())
{
continue
;
}
obj
=
PyTuple_GET_ITEM
(
args
,
arg_pos
+
1
);
switch
(
iter
->
second
)
{
case
paddle
::
framework
::
proto
::
AttrType
::
INT
:
CastPyArg2AttrInt
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
case
paddle
::
framework
::
proto
::
AttrType
::
FLOAT
:
CastPyArg2AttrFloat
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
case
paddle
::
framework
::
proto
::
AttrType
::
STRING
:
CastPyArg2AttrString
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
case
paddle
::
framework
::
proto
::
AttrType
::
INTS
:
CastPyArg2AttrInts
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
case
paddle
::
framework
::
proto
::
AttrType
::
FLOATS
:
CastPyArg2AttrFloats
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
case
paddle
::
framework
::
proto
::
AttrType
::
STRINGS
:
CastPyArg2AttrStrings
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
case
paddle
::
framework
::
proto
::
AttrType
::
BOOLEAN
:
CastPyArg2AttrBoolean
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
case
paddle
::
framework
::
proto
::
AttrType
::
BOOLEANS
:
CastPyArg2AttrBooleans
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
case
paddle
::
framework
::
proto
::
AttrType
::
LONG
:
CastPyArg2AttrLong
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
case
paddle
::
framework
::
proto
::
AttrType
::
LONGS
:
CastPyArg2AttrLongs
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
case
paddle
::
framework
::
proto
::
AttrType
::
FLOAT64S
:
CastPyArg2AttrFloat64s
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
case
paddle
::
framework
::
proto
::
AttrType
::
BLOCK
:
CastPyArg2AttrBlock
(
obj
,
attrs
,
key
,
op_type
,
arg_pos
);
break
;
default:
break
;
}
}
}
static
inline
std
::
shared_ptr
<
imperative
::
VarBase
>
GetVarBaseFromArgs
(
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
PyObject
*
args
,
ssize_t
arg_idx
,
bool
dispensable
=
false
)
{
::
pybind11
::
detail
::
instance
*
inst
=
(
::
pybind11
::
detail
::
instance
*
)
PyTuple_GET_ITEM
(
args
,
arg_idx
);
if
(
PyTuple_Check
((
PyObject
*
)
inst
))
{
// NOLINT
inst
=
(
::
pybind11
::
detail
::
instance
*
)
PyTuple_GET_ITEM
(
inst
,
0
);
}
if
(
inst
==
nullptr
||
(
PyObject
*
)
inst
==
Py_None
)
{
// NOLINT
if
(
!
dispensable
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be Tensor, but got None"
,
op_type
,
arg_name
,
arg_idx
));
}
return
nullptr
;
}
if
(
!
PyObject_IsInstance
((
PyObject
*
)
inst
,
// NOLINT
(
PyObject
*
)
g_varbase_pytype
))
{
// NOLINT
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be Tensor, but got "
"%s"
,
op_type
,
arg_name
,
arg_idx
,
((
PyTypeObject
*
)((
PyObject
*
)
inst
)
->
ob_type
)
->
tp_name
));
// NOLINT
}
void
**
vh
=
inst
->
simple_layout
?
inst
->
simple_value_holder
:
&
inst
->
nonsimple
.
values_and_holders
[
0
];
return
reinterpret_cast
<
std
::
shared_ptr
<
paddle
::
imperative
::
VarBase
>&>
(
vh
[
1
]);
}
static
inline
std
::
vector
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
GetVarBaseListFromArgs
(
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
PyObject
*
args
,
ssize_t
arg_idx
,
bool
dispensable
=
false
)
{
PyObject
*
list
=
PyTuple_GET_ITEM
(
args
,
arg_idx
);
if
(
list
==
nullptr
)
{
if
(
!
dispensable
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be list of Tensor, but got "
"None"
,
op_type
,
arg_name
,
arg_idx
));
// NOLINT
}
return
{};
}
std
::
vector
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
result
;
if
(
PyList_Check
(
list
))
{
Py_ssize_t
len
=
PyList_Size
(
list
);
if
(
len
==
0
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"empty list"
,
op_type
,
arg_name
,
arg_idx
));
}
::
pybind11
::
detail
::
instance
*
item
=
nullptr
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
(
::
pybind11
::
detail
::
instance
*
)
PyList_GetItem
(
list
,
i
);
if
(
!
PyObject_IsInstance
((
PyObject
*
)
item
,
// NOLINT
(
PyObject
*
)
g_varbase_pytype
))
{
// NOLINT
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be list of Tensors, but "
"got list of "
"%s"
,
op_type
,
arg_name
,
arg_idx
,
((
PyTypeObject
*
)((
PyObject
*
)
item
)
->
ob_type
)
->
tp_name
));
// NOLINT
}
void
**
vh
=
item
->
simple_layout
?
item
->
simple_value_holder
:
&
item
->
nonsimple
.
values_and_holders
[
0
];
result
.
emplace_back
(
reinterpret_cast
<
std
::
shared_ptr
<
paddle
::
imperative
::
VarBase
>&>
(
vh
[
1
]));
}
}
else
if
(
PyTuple_Check
(
list
))
{
Py_ssize_t
len
=
PyTuple_Size
(
list
);
if
(
len
==
0
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"empty list"
,
op_type
,
arg_name
,
arg_idx
));
}
::
pybind11
::
detail
::
instance
*
item
=
nullptr
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
(
::
pybind11
::
detail
::
instance
*
)
PyTuple_GetItem
(
list
,
i
);
// NOLINT
if
(
!
PyObject_IsInstance
((
PyObject
*
)
item
,
// NOLINT
(
PyObject
*
)
g_varbase_pytype
))
{
// NOLINT
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be list of Tensors, but "
"got list of "
"%s"
,
op_type
,
arg_name
,
arg_idx
,
((
PyTypeObject
*
)((
PyObject
*
)
item
)
->
ob_type
)
->
tp_name
));
// NOLINT
}
void
**
vh
=
item
->
simple_layout
?
item
->
simple_value_holder
:
&
item
->
nonsimple
.
values_and_holders
[
0
];
result
.
emplace_back
(
reinterpret_cast
<
std
::
shared_ptr
<
paddle
::
imperative
::
VarBase
>&>
(
vh
[
1
]));
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"%s"
,
op_type
,
arg_name
,
arg_idx
,
((
PyTypeObject
*
)
list
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
result
;
}
static
inline
unsigned
long
GetUnsignedLongFromArgs
(
// NOLINT
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
PyObject
*
args
,
ssize_t
arg_idx
,
bool
dispensable
=
false
)
{
PyObject
*
item
=
PyTuple_GET_ITEM
(
args
,
arg_idx
);
if
(
item
==
nullptr
)
{
if
(
!
dispensable
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be long, but got None"
,
op_type
,
arg_name
,
arg_idx
));
}
return
0
;
}
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
return
PyLong_AsUnsignedLong
(
item
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be "
"long, but got %s"
,
op_type
,
arg_name
,
arg_idx
,
((
PyTypeObject
*
)
item
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
static
inline
PyObject
*
MakeReturnPyObject
(
const
std
::
shared_ptr
<
paddle
::
imperative
::
VarBase
>&
out
)
{
return
::
pybind11
::
detail
::
type_caster_base
<
imperative
::
VarBase
>::
cast_holder
(
::
pybind11
::
detail
::
holder_helper
<
std
::
shared_ptr
<
imperative
::
VarBase
>>::
get
(
out
),
&
out
)
.
ptr
();
}
static
inline
PyObject
*
MakeReturnPyObject
(
const
std
::
vector
<
std
::
shared_ptr
<
imperative
::
VarBase
>>&
out
)
{
PyObject
*
result
=
PyList_New
((
Py_ssize_t
)
out
.
size
());
for
(
size_t
i
=
0
;
i
<
out
.
size
();
i
++
)
{
PyList_SET_ITEM
(
result
,
(
Py_ssize_t
)
i
,
::
pybind11
::
detail
::
type_caster_base
<
imperative
::
VarBase
>::
cast_holder
(
::
pybind11
::
detail
::
holder_helper
<
std
::
shared_ptr
<
imperative
::
VarBase
>>::
get
(
out
[
i
]),
&
out
[
i
])
.
ptr
());
// NOLINT
}
return
result
;
}
template
<
typename
Tuple
,
size_t
N
>
struct
TupleVarBasesResult
{
static
void
Run
(
const
Tuple
&
out
,
PyObject
*
result
)
{
TupleVarBasesResult
<
Tuple
,
N
-
1
>::
Run
(
out
,
result
);
PyTuple_SET_ITEM
(
result
,
N
-
1
,
MakeReturnPyObject
(
std
::
get
<
N
-
1
>
(
out
)));
}
};
template
<
typename
Tuple
>
struct
TupleVarBasesResult
<
Tuple
,
1
>
{
static
void
Run
(
const
Tuple
&
out
,
PyObject
*
result
)
{
PyTuple_SET_ITEM
(
result
,
0
,
MakeReturnPyObject
(
std
::
get
<
0
>
(
out
)));
}
};
template
<
typename
...
Args
>
static
inline
PyObject
*
MakeReturnPyObject
(
const
std
::
tuple
<
Args
...
>&
out
)
{
auto
len
=
sizeof
...(
Args
);
PyObject
*
result
=
PyTuple_New
(
len
);
TupleVarBasesResult
<
decltype
(
out
),
sizeof
...(
Args
)
>::
Run
(
out
,
result
);
return
result
;
}
void
InitOpsAttrTypeMap
()
{
auto
op_info_map
=
paddle
::
framework
::
OpInfoMap
::
Instance
().
map
();
for
(
auto
iter
=
op_info_map
.
begin
();
iter
!=
op_info_map
.
end
();
++
iter
)
{
auto
op_proto
=
iter
->
second
.
proto_
;
if
(
op_proto
==
nullptr
)
{
continue
;
}
auto
attrs_proto
=
op_proto
->
attrs
();
for
(
auto
&
attr
:
attrs_proto
)
{
OpAttrTypeMap
::
Instance
().
Map
()[
iter
->
first
][
attr
.
name
()]
=
attr
.
type
();
}
}
}
PyObject
*
EOFExceptionException
=
PyErr_NewException
(
"paddle.EOFException"
,
PyExc_Exception
,
NULL
);
PyObject
*
EnforceNotMetException
=
PyErr_NewException
(
"paddle.EnforceNotMet"
,
PyExc_Exception
,
NULL
);
void
ThrowExceptionToPython
(
std
::
exception_ptr
p
)
{
try
{
if
(
p
)
std
::
rethrow_exception
(
p
);
}
catch
(
const
platform
::
EOFException
&
e
)
{
PyErr_SetString
(
EOFExceptionException
,
e
.
what
());
}
catch
(
const
platform
::
EnforceNotMet
&
e
)
{
switch
(
e
.
code
())
{
case
paddle
::
platform
::
error
::
INVALID_ARGUMENT
:
PyErr_SetString
(
PyExc_ValueError
,
e
.
what
());
break
;
case
paddle
::
platform
::
error
::
NOT_FOUND
:
case
paddle
::
platform
::
error
::
ALREADY_EXISTS
:
case
paddle
::
platform
::
error
::
PRECONDITION_NOT_MET
:
case
paddle
::
platform
::
error
::
PERMISSION_DENIED
:
case
paddle
::
platform
::
error
::
EXECUTION_TIMEOUT
:
case
paddle
::
platform
::
error
::
UNAVAILABLE
:
PyErr_SetString
(
PyExc_RuntimeError
,
e
.
what
());
break
;
case
paddle
::
platform
::
error
::
OUT_OF_RANGE
:
PyErr_SetString
(
PyExc_IndexError
,
e
.
what
());
break
;
case
paddle
::
platform
::
error
::
RESOURCE_EXHAUSTED
:
PyErr_SetString
(
PyExc_MemoryError
,
e
.
what
());
break
;
case
paddle
::
platform
::
error
::
UNIMPLEMENTED
:
PyErr_SetString
(
PyExc_NotImplementedError
,
e
.
what
());
break
;
case
paddle
::
platform
::
error
::
FATAL
:
PyErr_SetString
(
PyExc_SystemError
,
e
.
what
());
break
;
case
paddle
::
platform
::
error
::
EXTERNAL
:
PyErr_SetString
(
PyExc_OSError
,
e
.
what
());
break
;
default:
PyErr_SetString
(
EnforceNotMetException
,
e
.
what
());
break
;
}
}
}
}
// namespace pybind
}
// namespace paddle
...
...
paddle/fluid/pybind/op_function_generator.cc
浏览文件 @
08e81475
...
...
@@ -212,16 +212,17 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
const
char
*
OUT_VAR_LIST_TYPE
=
R"(std::vector<std::shared_ptr<imperative::VarBase>>)"
;
const
char
*
CAST_VAR_TEMPLATE
=
R"(
auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s
, %s);)"
;
auto %s = GetVarBaseFromArgs("%s", "%s", args, %d
, %s);)"
;
const
char
*
CAST_VAR_LIST_TEMPLATE
=
R"(
auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s
, %s);)"
;
auto %s = GetVarBaseListFromArgs("%s", "%s", args, %d
, %s);)"
;
const
char
*
CAST_SIZE_T_TEMPLATE
=
R"(
auto %s = GetUnsignedLongFromArgs("%s", "%s", args, %d, %s);)"
;
const
char
*
ARG_TEMPLATE
=
R"(const %s& %s)"
;
const
char
*
RETURN_TUPLE_TYPE
=
R"(std::tuple<%s>)"
;
const
char
*
RETURN_TYPE
=
R"(%s)"
;
const
char
*
RETURN_TUPLE_TEMPLATE
=
R"(std::make_tuple(%s))"
;
const
char
*
RETURN_LIST_TEMPLATE
=
R"(outs["%s"])"
;
const
char
*
RETURN_TEMPLATE
=
R"(outs["%s"][0])"
;
...
...
@@ -251,23 +252,34 @@ const char* INPLACE_MAPPING_TEMPLATE = R"({"%s", "%s"})";
const
char
*
OP_FUNCTION_TEMPLATE
=
R"(
%s %s(%
s)
static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwarg
s)
{
PyThreadState *tstate = nullptr;
try
{
%s
framework::AttributeMap attrs;
ConstructAttrMapFromPyArgs("%s", %d, &attrs, args);
{
py::gil_scoped_release release;
ConstructAttrMapFromPyArgs("%s", args, %d, PyTuple_GET_SIZE(args) , attrs);
tstate = PyEval_SaveThread();
%s
imperative::NameVarBaseMap outs = %s;
imperative::NameVarBaseMap ins = %s;
%s
imperative::GetCurrentTracer()->TraceOp("%s", ins, outs, attrs, {%s});
PyEval_RestoreThread(tstate);
tstate = nullptr;
return %s;
}
catch(...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
})"
;
const
char
*
PYBIND_ITEM_TEMPLATE
=
R"(
%s.def("%s", &%s);
)"
;
const
char
*
PYBIND_ITEM_TEMPLATE
=
R"(
{"%s", (PyCFunction)(void(*)(void))%s, METH_VARARGS | METH_KEYWORDS, "C++ interface function for %s in dygraph."},
)"
;
// clang-format on
static
inline
bool
FindInsMap
(
const
std
::
string
&
op_type
,
...
...
@@ -326,9 +338,8 @@ std::string GenerateOpFunctionsBody(
const
auto
in_cast_type
=
input
.
duplicable
()
?
CAST_VAR_LIST_TEMPLATE
:
CAST_VAR_TEMPLATE
;
auto
dispensable
=
input
.
dispensable
()
?
"true"
:
"false"
;
ins_cast_str
+=
paddle
::
string
::
Sprintf
(
in_cast_type
,
in_name
,
op_type
,
in_name
,
arg_idx
++
,
TempName
(
in_name
),
dispensable
);
ins_cast_str
+=
paddle
::
string
::
Sprintf
(
in_cast_type
,
in_name
,
op_type
,
in_name
,
arg_idx
++
,
dispensable
);
if
(
input
.
dispensable
())
{
const
auto
in_template
=
input
.
duplicable
()
...
...
@@ -356,7 +367,6 @@ std::string GenerateOpFunctionsBody(
// Generate outs initializer
std
::
string
outs_initializer
=
"{"
;
std
::
string
outs_initializer_with_null
=
""
;
std
::
string
return_type
=
""
;
std
::
string
inplace_mapping_str
=
""
;
std
::
string
return_str
=
""
;
...
...
@@ -395,6 +405,12 @@ std::string GenerateOpFunctionsBody(
paddle
::
string
::
Sprintf
(
out_template
,
out_name
,
out_name
);
outs_initializer
+=
","
;
}
const
auto
in_cast_type
=
output
.
duplicable
()
?
CAST_VAR_LIST_TEMPLATE
:
CAST_VAR_TEMPLATE
;
auto
dispensable
=
output
.
dispensable
()
?
"true"
:
"false"
;
ins_cast_str
+=
paddle
::
string
::
Sprintf
(
in_cast_type
,
out_name
,
op_type
,
out_name
,
arg_idx
++
,
dispensable
);
}
else
if
(
use_inplace_strategy
&&
inplace_map
.
count
(
out_name
))
{
PADDLE_ENFORCE_NE
(
inplace_map
[
out_name
],
""
,
...
...
@@ -440,6 +456,11 @@ std::string GenerateOpFunctionsBody(
input_args_num
++
;
outs_initializer
+=
paddle
::
string
::
Sprintf
(
OUT_DUPLICABLE_INITIALIZER_TEMPLATE
,
out_name
,
out_num_str
);
auto
dispensable
=
output
.
dispensable
()
?
"true"
:
"false"
;
ins_cast_str
+=
paddle
::
string
::
Sprintf
(
CAST_SIZE_T_TEMPLATE
,
out_num_str
,
op_type
,
out_num_str
,
arg_idx
++
,
dispensable
);
}
else
{
outs_initializer
+=
paddle
::
string
::
Sprintf
(
OUT_INITIALIZER_TEMPLATE
,
out_name
);
...
...
@@ -447,15 +468,12 @@ std::string GenerateOpFunctionsBody(
outs_initializer
+=
","
;
}
return_type
+=
out_type
;
return_type
+=
","
;
return_str
+=
paddle
::
string
::
Sprintf
(
return_template
,
out_name
);
return_str
+=
","
;
outs_num
+=
1
;
}
if
(
outs_initializer
.
back
()
==
','
)
{
outs_initializer
.
pop_back
();
return_type
.
pop_back
();
return_str
.
pop_back
();
}
outs_initializer
+=
"}"
;
...
...
@@ -470,11 +488,13 @@ std::string GenerateOpFunctionsBody(
viwe_input_name
,
viwe_output_name
);
}
if
(
outs_num
==
0
)
{
return_type
=
"void"
;
}
if
(
outs_num
>
1
)
{
return_str
=
paddle
::
string
::
Sprintf
(
RETURN_TUPLE_TEMPLATE
,
return_str
);
return_type
=
paddle
::
string
::
Sprintf
(
RETURN_TUPLE_TYPE
,
return_type
);
return_str
=
"Py_None"
;
}
else
if
(
outs_num
==
1
)
{
return_str
=
"MakeReturnPyObject("
+
return_str
+
")"
;
}
else
{
return_str
=
"MakeReturnPyObject("
+
paddle
::
string
::
Sprintf
(
RETURN_TUPLE_TEMPLATE
,
return_str
)
+
")"
;
}
std
::
string
function_args
=
""
;
if
(
input_args
==
""
)
{
...
...
@@ -485,9 +505,9 @@ std::string GenerateOpFunctionsBody(
// generate op funtcion body
auto
op_function_str
=
paddle
::
string
::
Sprintf
(
OP_FUNCTION_TEMPLATE
,
return_type
,
func_name
,
function_args
,
ins_cast_str
,
op_type
,
input_args_num
,
inplace_strategy_str
,
out
s_initializer
,
ins_initializer
,
ins_initializer
_with_null
+
outs_initializer_with_null
+
OP_FUNCTION_TEMPLATE
,
func_name
,
ins_cast_str
,
op_type
,
input_args_num
,
inplace_strategy_str
,
outs_initializer
,
in
s_initializer
,
ins_initializer_with_null
+
outs_initializer_with_null
+
view_strategy_str
,
op_type
,
inplace_mapping_str
,
return_str
);
...
...
@@ -495,7 +515,7 @@ std::string GenerateOpFunctionsBody(
}
static
std
::
tuple
<
std
::
vector
<
std
::
string
>
,
std
::
vector
<
std
::
string
>>
GenerateOpFunctions
(
const
std
::
string
&
module_name
)
{
GenerateOpFunctions
()
{
auto
&
op_info_map
=
paddle
::
framework
::
OpInfoMap
::
Instance
().
map
();
std
::
vector
<
std
::
string
>
op_function_list
,
bind_function_list
;
...
...
@@ -536,7 +556,7 @@ GenerateOpFunctions(const std::string& module_name) {
// generate pybind item
auto
bind_function_str
=
paddle
::
string
::
Sprintf
(
PYBIND_ITEM_TEMPLATE
,
module_name
,
op_type
,
func_nam
e
);
PYBIND_ITEM_TEMPLATE
,
op_type
,
func_name
,
op_typ
e
);
op_function_list
.
emplace_back
(
std
::
move
(
op_function_str
));
bind_function_list
.
emplace_back
(
std
::
move
(
bind_function_str
));
...
...
@@ -551,8 +571,8 @@ GenerateOpFunctions(const std::string& module_name) {
// generate pybind item
auto
inplace_bind_function_str
=
paddle
::
string
::
Sprintf
(
PYBIND_ITEM_TEMPLATE
,
module_nam
e
,
inplace_
op_type
,
inplace_func_nam
e
);
paddle
::
string
::
Sprintf
(
PYBIND_ITEM_TEMPLATE
,
inplace_op_typ
e
,
inplace_
func_name
,
inplace_op_typ
e
);
op_function_list
.
emplace_back
(
std
::
move
(
inplace_op_function_str
));
bind_function_list
.
emplace_back
(
std
::
move
(
inplace_bind_function_str
));
...
...
@@ -572,7 +592,9 @@ int main(int argc, char* argv[]) {
ascend_ptr
->
InitGEForUT
();
#endif
std
::
vector
<
std
::
string
>
headers
{
"
\"
paddle/fluid/imperative/tracer.h
\"
"
};
std
::
vector
<
std
::
string
>
headers
{
"
\"
paddle/fluid/imperative/tracer.h
\"
"
,
"
\"
pybind11/detail/common.h
\"
"
,
"<Python.h>"
};
std
::
ofstream
out
(
argv
[
1
],
std
::
ios
::
out
);
...
...
@@ -582,22 +604,29 @@ int main(int argc, char* argv[]) {
out
<<
"#include "
+
header
+
"
\n
"
;
}
auto
op_funcs
=
GenerateOpFunctions
(
"m"
);
out
<<
"
\n\n
"
;
auto
op_funcs
=
GenerateOpFunctions
();
out
<<
"namespace py = pybind11;"
<<
"
\n
"
;
out
<<
"namespace paddle {
\n
"
<<
"namespace pybind {
\n\n
"
;
out
<<
"std::atomic<int> VarBaseUniqueNameID{0};
\n
"
;
out
<<
paddle
::
string
::
join_strings
(
std
::
get
<
0
>
(
op_funcs
),
'\n'
);
out
<<
"
\n\n
"
;
out
<<
"inline void BindOpFunctions(pybind11::module *module) {
\n
"
<<
" auto m = module->def_submodule(
\"
ops
\"
);
\n\n
"
;
out
<<
"static PyMethodDef ExtestMethods[] = {
\n
"
<<
paddle
::
string
::
join_strings
(
std
::
get
<
1
>
(
op_funcs
),
'\n'
)
<<
"
\n
{nullptr,nullptr,0,nullptr}"
<<
"};
\n\n
"
;
out
<<
paddle
::
string
::
join_strings
(
std
::
get
<
1
>
(
op_funcs
),
'\n'
);
out
<<
"
\n
"
;
out
<<
"}
\n\n
"
out
<<
"inline void BindOpFunctions(pybind11::module *module) {
\n
"
<<
" auto m = module->def_submodule(
\"
ops
\"
);
\n
"
<<
" if (PyModule_AddFunctions(m.ptr(), ExtestMethods) < 0) {
\n
"
<<
" PADDLE_THROW(platform::errors::Fatal (
\"
Add functions to "
"core.ops failed!
\"
));
\n
"
<<
" }
\n\n
"
<<
" InitOpsAttrTypeMap();"
<<
"}
\n\n
"
<<
"} // namespace pybind
\n
"
<<
"} // namespace paddle
\n
"
;
...
...
paddle/fluid/pybind/protobuf.cc
浏览文件 @
08e81475
...
...
@@ -29,6 +29,9 @@ limitations under the License. */
namespace
paddle
{
namespace
pybind
{
PyTypeObject
*
g_vartype_pytype
=
nullptr
;
PyTypeObject
*
g_blockdesc_pytype
=
nullptr
;
namespace
pd
=
paddle
::
framework
;
template
<
typename
T
>
...
...
@@ -82,8 +85,9 @@ void BindProgramDesc(pybind11::module *m) {
}
void
BindBlockDesc
(
pybind11
::
module
*
m
)
{
pybind11
::
class_
<
pd
::
BlockDesc
>
(
*
m
,
"BlockDesc"
,
""
)
.
def_property_readonly
(
"id"
,
&
pd
::
BlockDesc
::
ID
)
pybind11
::
class_
<
pd
::
BlockDesc
>
blockdesc
(
*
m
,
"BlockDesc"
,
""
);
g_blockdesc_pytype
=
(
PyTypeObject
*
)
blockdesc
.
ptr
();
// NOLINT
blockdesc
.
def_property_readonly
(
"id"
,
&
pd
::
BlockDesc
::
ID
)
.
def_property_readonly
(
"parent"
,
&
pd
::
BlockDesc
::
Parent
)
.
def
(
"get_forward_block_idx"
,
&
pd
::
BlockDesc
::
ForwardBlockID
)
.
def
(
"_set_forward_block_idx"
,
&
pd
::
BlockDesc
::
SetForwardBlockID
)
...
...
@@ -174,8 +178,9 @@ void BindVarDsec(pybind11::module *m) {
.
def
(
"need_check_feed"
,
&
pd
::
VarDesc
::
NeedCheckFeed
)
.
def
(
"set_need_check_feed"
,
&
pd
::
VarDesc
::
SetNeedCheckFeed
);
pybind11
::
enum_
<
pd
::
proto
::
VarType
::
Type
>
(
var_desc
,
"VarType"
,
""
)
.
value
(
"BOOL"
,
pd
::
proto
::
VarType
::
BOOL
)
pybind11
::
enum_
<
pd
::
proto
::
VarType
::
Type
>
vartype
(
var_desc
,
"VarType"
,
""
);
g_vartype_pytype
=
(
PyTypeObject
*
)
vartype
.
ptr
();
// NOLINT
vartype
.
value
(
"BOOL"
,
pd
::
proto
::
VarType
::
BOOL
)
.
value
(
"UINT8"
,
pd
::
proto
::
VarType
::
UINT8
)
.
value
(
"INT8"
,
pd
::
proto
::
VarType
::
INT8
)
.
value
(
"INT16"
,
pd
::
proto
::
VarType
::
INT16
)
...
...
python/paddle/fluid/layers/utils.py
浏览文件 @
08e81475
...
...
@@ -357,7 +357,7 @@ def convert_shape_to_list(shape):
map
(
lambda
x
:
x
.
numpy
()[
0
]
if
isinstance
(
x
,
Variable
)
else
x
,
shape
))
else
:
shape
=
list
(
shape
.
numpy
().
astype
(
int
)
)
shape
=
shape
.
numpy
().
astype
(
int
).
tolist
(
)
return
shape
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录