机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Unverified commit 08e81475
Authored on Jun 11, 2021 by wanghuancoder; committed by GitHub on Jun 11, 2021
use PYTHON_C_API in dygraph (#32524)
* use PYTHON_C_API in dygraph, test=develop
Parent commit: 022198c5
Showing 5 changed files with 958 additions and 71 deletions (+958 -71)
paddle/fluid/pybind/imperative.cc              +25   -28
paddle/fluid/pybind/op_function.h              +856  -0
paddle/fluid/pybind/op_function_generator.cc   +67   -38
paddle/fluid/pybind/protobuf.cc                +9    -4
python/paddle/fluid/layers/utils.py            +1    -1
paddle/fluid/pybind/imperative.cc

@@ -51,6 +51,8 @@ limitations under the License. */
 namespace paddle {
 namespace pybind {
 
+PyTypeObject *g_varbase_pytype = nullptr;
+
 namespace py = ::pybind11;
 
 class Layer : public imperative::Layer {

@@ -470,9 +472,9 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
 }
 
 template <typename P>
-static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src,
-                        imperative::VarBase &dst, const P &dst_device,
-                        const bool blocking) {
+static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src,  // NOLINT
+                        imperative::VarBase &dst,                   // NOLINT
+                        const P &dst_device, const bool blocking) {
   if (dst.SharedVar()->IsEmpty()) {
     VLOG(3) << "deep copy Variable from " << src->Name() << " to "
             << dst.Name();

@@ -667,9 +669,10 @@ void BindImperative(py::module *m_ptr) {
     imperative::SetCurrentTracer(tracer);
   });
 
-  py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
-      m, "VarBase", R"DOC()DOC")
-      .def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
+  py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>> varbase(
+      m, "VarBase", R"DOC()DOC");
+  g_varbase_pytype = (PyTypeObject *)varbase.ptr();  // NOLINT
+  varbase.def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
       .def("__init__",
            [](imperative::VarBase &self) {
              std::string name =
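The newly cached `g_varbase_pytype` (assigned above from `varbase.ptr()`) gives the generated C functions a cheap way to recognize `VarBase` arguments without a pybind11 round trip. A minimal sketch of that kind of check, using only the stock CPython API (the helper name is hypothetical, not part of this commit):

```cpp
#include <Python.h>

// Filled in by BindImperative(): g_varbase_pytype = (PyTypeObject *)varbase.ptr();
extern PyTypeObject *g_varbase_pytype;

// Hypothetical helper: does `obj` wrap a core.VarBase (or a subclass of it)?
static bool IsVarBaseInstance(PyObject *obj) {
  // PyObject_TypeCheck matches Python's isinstance(), i.e. subclasses count too.
  return g_varbase_pytype != nullptr && PyObject_TypeCheck(obj, g_varbase_pytype);
}
```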

@@ -1468,28 +1471,22 @@ void BindImperative(py::module *m_ptr) {
            &imperative::VarBase::SetOverridedStopGradient)
       .def_property("persistable", &imperative::VarBase::Persistable,
                     &imperative::VarBase::SetPersistable)
-      .def_property_readonly("shape",
-                             [](imperative::VarBase &self) {
-                               if (self.Var().IsType<framework::LoDTensor>()) {
-                                 return framework::vectorize<int>(
-                                     self.Var()
-                                         .Get<framework::LoDTensor>()
-                                         .dims());
-                               } else if (self.Var()
-                                              .IsType<
-                                                  framework::SelectedRows>()) {
-                                 return framework::vectorize<int>(
-                                     self.Var()
-                                         .Get<framework::SelectedRows>()
-                                         .value()
-                                         .dims());
-                               } else {
-                                 VLOG(2) << "It is meaningless to get shape of "
-                                            "variable type "
-                                         << GetTypeName(self);
-                                 return std::vector<int>();
-                               }
-                             })
+      .def_property_readonly(
+          "shape",
+          [](imperative::VarBase &self) {
+            if (self.Var().IsType<framework::LoDTensor>()) {
+              return framework::vectorize<int>(
+                  self.Var().Get<framework::LoDTensor>().dims());
+            } else if (self.Var().IsType<framework::SelectedRows>()) {
+              return framework::vectorize<int>(
+                  self.Var().Get<framework::SelectedRows>().value().dims());
+            } else {
+              VLOG(2) << "It is meaningless to get shape of "
+                         "variable type "
+                      << GetTypeName(self);
+              return std::vector<int>();
+            }
+          })
       .def_property_readonly("is_leaf", &imperative::VarBase::IsLeaf,
                              R"DOC(
       Whether a Tensor is leaf Tensor.
paddle/fluid/pybind/op_function.h

(This diff is collapsed and not shown on this page.)
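The collapsed header is where the C-API plumbing used by the generated code lives: helpers such as `GetVarBaseFromArgs`, `GetVarBaseListFromArgs`, `GetUnsignedLongFromArgs`, `ConstructAttrMapFromPyArgs`, `MakeReturnPyObject` and `ThrowExceptionToPython` are all referenced from op_function_generator.cc below. Since the diff is not expanded here, the following is only a rough sketch of what an argument-unwrapping helper of that shape could look like; the Paddle type names are real, the body is an assumption:

```cpp
// Sketch only -- the actual GetVarBaseFromArgs in op_function.h is not visible
// in this collapsed diff and will differ in detail. Assumes the usual Paddle
// pybind headers (imperative::VarBase, pybind11, platform::errors) are included.
static std::shared_ptr<imperative::VarBase> GetVarBaseFromArgsSketch(
    const std::string &op_type, const std::string &arg_name, PyObject *args,
    Py_ssize_t arg_idx, bool dispensable) {
  PyObject *obj = nullptr;
  if (arg_idx < PyTuple_GET_SIZE(args)) {
    obj = PyTuple_GET_ITEM(args, arg_idx);  // borrowed reference, no DECREF
  }
  if (obj == nullptr || obj == Py_None) {
    if (!dispensable) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument '%s' (position %d) must be Tensor, but got None",
          op_type, arg_name, arg_idx));
    }
    return nullptr;
  }
  // A fast path could first compare Py_TYPE(obj) against g_varbase_pytype;
  // here we simply let pybind11 unwrap the shared_ptr held by the instance.
  return py::cast<std::shared_ptr<imperative::VarBase>>(py::handle(obj));
}
```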
paddle/fluid/pybind/op_function_generator.cc

@@ -212,16 +212,17 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
 const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";
 
 const char* CAST_VAR_TEMPLATE = R"(
-  auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s, %s);)";
+  auto %s = GetVarBaseFromArgs("%s", "%s", args, %d, %s);)";
 
 const char* CAST_VAR_LIST_TEMPLATE = R"(
-  auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s, %s);)";
+  auto %s = GetVarBaseListFromArgs("%s", "%s", args, %d, %s);)";
+
+const char* CAST_SIZE_T_TEMPLATE = R"(
+  auto %s = GetUnsignedLongFromArgs("%s", "%s", args, %d, %s);)";
 
 const char* ARG_TEMPLATE = R"(const %s& %s)";
 
 const char* RETURN_TUPLE_TYPE = R"(std::tuple<%s>)";
-const char* RETURN_TYPE = R"(%s)";
 const char* RETURN_TUPLE_TEMPLATE = R"(std::make_tuple(%s))";
 const char* RETURN_LIST_TEMPLATE = R"(outs["%s"])";
 const char* RETURN_TEMPLATE = R"(outs["%s"][0])";

@@ -251,23 +252,34 @@ const char* INPLACE_MAPPING_TEMPLATE = R"({"%s", "%s"})";
 const char* OP_FUNCTION_TEMPLATE =
 R"(
-%s %s(%s)
+static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
 {
-  %s
-  framework::AttributeMap attrs;
-  ConstructAttrMapFromPyArgs("%s", %d, &attrs, args);
+  PyThreadState *tstate = nullptr;
+  try
   {
-    py::gil_scoped_release release;
+    %s
+    framework::AttributeMap attrs;
+    ConstructAttrMapFromPyArgs("%s", args, %d, PyTuple_GET_SIZE(args), attrs);
+    tstate = PyEval_SaveThread();
     %s
     imperative::NameVarBaseMap outs = %s;
     imperative::NameVarBaseMap ins = %s;
     %s
     imperative::GetCurrentTracer()->TraceOp("%s", ins, outs, attrs, {%s});
+    PyEval_RestoreThread(tstate);
+    tstate = nullptr;
     return %s;
   }
+  catch(...) {
+    if (tstate) {
+      PyEval_RestoreThread(tstate);
+    }
+    ThrowExceptionToPython(std::current_exception());
+    return nullptr;
+  }
 })";
 
 const char* PYBIND_ITEM_TEMPLATE = R"(
-  %s.def("%s", &%s);
+  {"%s", (PyCFunction)(void(*)(void))%s, METH_VARARGS | METH_KEYWORDS, "C++ interface function for %s in dygraph."},
 )";
 
 // clang-format on
 static inline bool FindInsMap(const std::string& op_type,
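Filled in for a concrete operator, the new template produces an ordinary `PyCFunction`: it pulls its inputs straight out of `args`, drops the GIL around the actual `TraceOp` call with `PyEval_SaveThread`/`PyEval_RestoreThread`, and converts any C++ exception into a Python one instead of returning through pybind11. A hand-written approximation for a hypothetical one-input, one-output op (`my_relu`, `imperative_my_relu` and the variable names are illustrative, not real generated output):

```cpp
// Illustrative sketch only: roughly what OP_FUNCTION_TEMPLATE expands to for a
// hypothetical op "my_relu" with one input X and one output Out. The real code
// is emitted by op_function_generator.cc and differs in detail.
static PyObject* imperative_my_relu(PyObject* self, PyObject* args,
                                    PyObject* kwargs) {
  PyThreadState* tstate = nullptr;
  try {
    auto X = GetVarBaseFromArgs("my_relu", "X", args, 0, false);
    framework::AttributeMap attrs;
    ConstructAttrMapFromPyArgs("my_relu", args, 1, PyTuple_GET_SIZE(args),
                               attrs);
    tstate = PyEval_SaveThread();  // release the GIL while the op is traced
    auto tracer = imperative::GetCurrentTracer();
    imperative::NameVarBaseMap ins = {{"X", {X}}};
    imperative::NameVarBaseMap outs = {
        {"Out",
         {std::shared_ptr<imperative::VarBase>(
             new imperative::VarBase(tracer->GenerateUniqueName()))}}};
    tracer->TraceOp("my_relu", ins, outs, attrs, {});
    PyEval_RestoreThread(tstate);  // re-acquire the GIL before touching Python
    tstate = nullptr;
    return MakeReturnPyObject(outs["Out"][0]);
  } catch (...) {
    if (tstate) {
      PyEval_RestoreThread(tstate);
    }
    ThrowExceptionToPython(std::current_exception());
    return nullptr;
  }
}
```

Because an error can be thrown before or after the GIL was released, the generated catch block re-acquires it only when `tstate` is still set, then hands the C++ exception to `ThrowExceptionToPython` and returns `nullptr` as the C API requires.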

@@ -326,9 +338,8 @@ std::string GenerateOpFunctionsBody(
       const auto in_cast_type =
           input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
       auto dispensable = input.dispensable() ? "true" : "false";
       ins_cast_str +=
-          paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name,
-                                  arg_idx++, TempName(in_name), dispensable);
+          paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name,
+                                  arg_idx++, dispensable);
 
       if (input.dispensable()) {
         const auto in_template = input.duplicable()

@@ -356,7 +367,6 @@ std::string GenerateOpFunctionsBody(
   // Generate outs initializer
   std::string outs_initializer = "{";
   std::string outs_initializer_with_null = "";
-  std::string return_type = "";
   std::string inplace_mapping_str = "";
   std::string return_str = "";
 

@@ -395,6 +405,12 @@ std::string GenerateOpFunctionsBody(
           paddle::string::Sprintf(out_template, out_name, out_name);
       outs_initializer += ",";
     }
+      const auto in_cast_type =
+          output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
+      auto dispensable = output.dispensable() ? "true" : "false";
+      ins_cast_str +=
+          paddle::string::Sprintf(in_cast_type, out_name, op_type, out_name,
+                                  arg_idx++, dispensable);
     } else if (use_inplace_strategy && inplace_map.count(out_name)) {
       PADDLE_ENFORCE_NE(
           inplace_map[out_name], "",

@@ -440,6 +456,11 @@ std::string GenerateOpFunctionsBody(
       input_args_num++;
       outs_initializer += paddle::string::Sprintf(
           OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str);
+      auto dispensable = output.dispensable() ? "true" : "false";
+      ins_cast_str +=
+          paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE, out_num_str, op_type,
+                                  out_num_str, arg_idx++, dispensable);
+
     } else {
       outs_initializer +=
           paddle::string::Sprintf(OUT_INITIALIZER_TEMPLATE, out_name);

@@ -447,15 +468,12 @@ std::string GenerateOpFunctionsBody(
       outs_initializer += ",";
     }
 
-    return_type += out_type;
-    return_type += ",";
     return_str += paddle::string::Sprintf(return_template, out_name);
     return_str += ",";
     outs_num += 1;
   }
   if (outs_initializer.back() == ',') {
     outs_initializer.pop_back();
-    return_type.pop_back();
     return_str.pop_back();
   }
   outs_initializer += "}";

@@ -470,11 +488,13 @@ std::string GenerateOpFunctionsBody(
                                           viwe_input_name, viwe_output_name);
   }
   if (outs_num == 0) {
-    return_type = "void";
-  }
-  if (outs_num > 1) {
-    return_str = paddle::string::Sprintf(RETURN_TUPLE_TEMPLATE, return_str);
-    return_type = paddle::string::Sprintf(RETURN_TUPLE_TYPE, return_type);
+    return_str = "Py_None";
+  } else if (outs_num == 1) {
+    return_str = "MakeReturnPyObject(" + return_str + ")";
+  } else {
+    return_str = "MakeReturnPyObject(" +
+                 paddle::string::Sprintf(RETURN_TUPLE_TEMPLATE, return_str) +
+                 ")";
   }
   std::string function_args = "";
   if (input_args == "") {
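Since every generated function now returns a `PyObject*`, the generator maps zero outputs to `Py_None`, wraps a single output in `MakeReturnPyObject(...)`, and for multiple outputs wraps a `std::make_tuple(...)` of them. `MakeReturnPyObject` itself lives in the collapsed op_function.h; conceptually it converts the C++ result into a new Python reference, roughly like this hedged sketch (name reused for illustration, body assumed):

```cpp
// Sketch only: convert a traced output back into a Python object. A PyCFunction
// must hand a *new* reference to the interpreter, so ownership of the reference
// created by py::cast is released to the caller rather than dropped here.
static PyObject *MakeReturnPyObjectSketch(
    const std::shared_ptr<imperative::VarBase> &out) {
  return py::cast(out).release().ptr();
}
```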

@@ -485,17 +505,17 @@ std::string GenerateOpFunctionsBody(
   // generate op funtcion body
   auto op_function_str = paddle::string::Sprintf(
-      OP_FUNCTION_TEMPLATE, return_type, func_name, function_args, ins_cast_str,
+      OP_FUNCTION_TEMPLATE, func_name, ins_cast_str,
       op_type, input_args_num,
       inplace_strategy_str, outs_initializer,
       ins_initializer,
       ins_initializer_with_null + outs_initializer_with_null +
          view_strategy_str,
       op_type, inplace_mapping_str, return_str);
 
   return op_function_str;
 }
 
 static std::tuple<std::vector<std::string>, std::vector<std::string>>
-GenerateOpFunctions(const std::string& module_name) {
+GenerateOpFunctions() {
   auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();
   std::vector<std::string> op_function_list, bind_function_list;

@@ -536,7 +556,7 @@ GenerateOpFunctions(const std::string& module_name) {
     // generate pybind item
     auto bind_function_str = paddle::string::Sprintf(
-        PYBIND_ITEM_TEMPLATE, module_name, op_type, func_name);
+        PYBIND_ITEM_TEMPLATE, op_type, func_name, op_type);
 
     op_function_list.emplace_back(std::move(op_function_str));
     bind_function_list.emplace_back(std::move(bind_function_str));

@@ -551,8 +571,8 @@ GenerateOpFunctions(const std::string& module_name) {
       // generate pybind item
       auto inplace_bind_function_str =
-          paddle::string::Sprintf(PYBIND_ITEM_TEMPLATE, module_name,
-                                  inplace_op_type, inplace_func_name);
+          paddle::string::Sprintf(PYBIND_ITEM_TEMPLATE, inplace_op_type,
+                                  inplace_func_name, inplace_op_type);
 
       op_function_list.emplace_back(std::move(inplace_op_function_str));
       bind_function_list.emplace_back(std::move(inplace_bind_function_str));

@@ -572,7 +592,9 @@ int main(int argc, char* argv[]) {
   ascend_ptr->InitGEForUT();
 #endif
 
-  std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\""};
+  std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\"",
+                                   "\"pybind11/detail/common.h\"",
+                                   "<Python.h>"};
 
   std::ofstream out(argv[1], std::ios::out);

@@ -582,22 +604,29 @@ int main(int argc, char* argv[]) {
     out << "#include " + header + "\n";
   }
 
-  auto op_funcs = GenerateOpFunctions("m");
+  out << "\n\n";
+
+  auto op_funcs = GenerateOpFunctions();
+  out << "namespace py = pybind11;"
+      << "\n";
   out << "namespace paddle {\n"
       << "namespace pybind {\n\n";
   out << "std::atomic<int> VarBaseUniqueNameID{0};\n";
   out << paddle::string::join_strings(std::get<0>(op_funcs), '\n');
   out << "\n\n";
 
-  out << "inline void BindOpFunctions(pybind11::module *module) {\n"
-      << "  auto m = module->def_submodule(\"ops\");\n\n";
-  out << paddle::string::join_strings(std::get<1>(op_funcs), '\n');
-  out << "\n";
-  out << "}\n\n"
+  out << "static PyMethodDef ExtestMethods[] = {\n"
+      << paddle::string::join_strings(std::get<1>(op_funcs), '\n')
+      << "\n  {nullptr,nullptr,0,nullptr}"
+      << "};\n\n";
+
+  out << "inline void BindOpFunctions(pybind11::module *module) {\n"
+      << "  auto m = module->def_submodule(\"ops\");\n"
+      << "  if (PyModule_AddFunctions(m.ptr(), ExtestMethods) < 0) {\n"
+      << "    PADDLE_THROW(platform::errors::Fatal(\"Add functions to "
+         "core.ops failed!\"));\n"
+      << "  }\n\n"
+      << "  InitOpsAttrTypeMap();"
+      << "}\n\n"
       << "} // namespace pybind\n"
       << "} // namespace paddle\n";
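Taken together, the generated header now contains one static `PyObject*`-returning function per operator, a `PyMethodDef` table listing them, and a `BindOpFunctions` that attaches the whole table to the `core.ops` submodule with a single `PyModule_AddFunctions` call instead of one pybind11 `.def` per op. A miniature of that generated output for a single hypothetical op (the real file is emitted by this generator into `argv[1]`):

```cpp
// Miniature of the generator's output (illustrative; imperative_my_relu stands
// in for one of the generated op functions sketched earlier).
static PyMethodDef ExtestMethods[] = {
    {"my_relu", (PyCFunction)(void (*)(void))imperative_my_relu,
     METH_VARARGS | METH_KEYWORDS,
     "C++ interface function for my_relu in dygraph."},
    {nullptr, nullptr, 0, nullptr}};

inline void BindOpFunctions(pybind11::module *module) {
  auto m = module->def_submodule("ops");
  if (PyModule_AddFunctions(m.ptr(), ExtestMethods) < 0) {
    PADDLE_THROW(platform::errors::Fatal("Add functions to core.ops failed!"));
  }
  InitOpsAttrTypeMap();
}
```

Registering through one method table keeps each dygraph op call on the raw Python C API path, presumably to trim the per-call binding overhead that motivated this change.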
paddle/fluid/pybind/protobuf.cc

@@ -29,6 +29,9 @@ limitations under the License. */
 namespace paddle {
 namespace pybind {
 
+PyTypeObject *g_vartype_pytype = nullptr;
+PyTypeObject *g_blockdesc_pytype = nullptr;
+
 namespace pd = paddle::framework;
 
 template <typename T>

@@ -82,8 +85,9 @@ void BindProgramDesc(pybind11::module *m) {
 }
 
 void BindBlockDesc(pybind11::module *m) {
-  pybind11::class_<pd::BlockDesc>(*m, "BlockDesc", "")
-      .def_property_readonly("id", &pd::BlockDesc::ID)
+  pybind11::class_<pd::BlockDesc> blockdesc(*m, "BlockDesc", "");
+  g_blockdesc_pytype = (PyTypeObject *)blockdesc.ptr();  // NOLINT
+  blockdesc.def_property_readonly("id", &pd::BlockDesc::ID)
       .def_property_readonly("parent", &pd::BlockDesc::Parent)
       .def("get_forward_block_idx", &pd::BlockDesc::ForwardBlockID)
       .def("_set_forward_block_idx", &pd::BlockDesc::SetForwardBlockID)

@@ -174,8 +178,9 @@ void BindVarDsec(pybind11::module *m) {
       .def("need_check_feed", &pd::VarDesc::NeedCheckFeed)
       .def("set_need_check_feed", &pd::VarDesc::SetNeedCheckFeed);
 
-  pybind11::enum_<pd::proto::VarType::Type>(var_desc, "VarType", "")
-      .value("BOOL", pd::proto::VarType::BOOL)
+  pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", "");
+  g_vartype_pytype = (PyTypeObject *)vartype.ptr();  // NOLINT
+  vartype.value("BOOL", pd::proto::VarType::BOOL)
       .value("UINT8", pd::proto::VarType::UINT8)
       .value("INT8", pd::proto::VarType::INT8)
       .value("INT16", pd::proto::VarType::INT16)
python/paddle/fluid/layers/utils.py

@@ -357,7 +357,7 @@ def convert_shape_to_list(shape):
             map(lambda x: x.numpy()[0] if isinstance(x, Variable) else x,
                 shape))
     else:
-        shape = list(shape.numpy().astype(int))
+        shape = shape.numpy().astype(int).tolist()
     return shape