Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
62b15566
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
62b15566
编写于
1月 26, 2022
作者:
J
jim19930609
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added python-c code generation for final state Eager Dygraph
上级
ca743508
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
473 addition
and
112 deletion
+473
-112
paddle/fluid/eager/api/generated/eager_generated/backwards/CMakeLists.txt
...er/api/generated/eager_generated/backwards/CMakeLists.txt
+2
-2
paddle/fluid/eager/api/generated/eager_generated/forwards/CMakeLists.txt
...ger/api/generated/eager_generated/forwards/CMakeLists.txt
+2
-2
paddle/fluid/eager/auto_code_generator/CMakeLists.txt
paddle/fluid/eager/auto_code_generator/CMakeLists.txt
+1
-1
paddle/fluid/eager/auto_code_generator/final_state_generator/CMakeLists.txt
.../auto_code_generator/final_state_generator/CMakeLists.txt
+10
-0
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
...er/auto_code_generator/final_state_generator/eager_gen.py
+56
-20
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
...auto_code_generator/final_state_generator/python_c_gen.py
+230
-0
paddle/fluid/eager/utils.cc
paddle/fluid/eager/utils.cc
+6
-2
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+2
-2
paddle/fluid/pybind/eager_op_function_generator.cc
paddle/fluid/pybind/eager_op_function_generator.cc
+5
-0
paddle/fluid/pybind/op_function_common.cc
paddle/fluid/pybind/op_function_common.cc
+135
-83
paddle/fluid/pybind/op_function_common.h
paddle/fluid/pybind/op_function_common.h
+24
-0
未找到文件。
paddle/fluid/eager/api/generated/eager_generated/backwards/CMakeLists.txt
浏览文件 @
62b15566
cc_library
(
scale_node SRCS scale_node.cc DEPS global_utils pten pten_api grad_node_info
)
#
cc_library(final_dygraph_node SRCS nodes.cc DEPS ${eager_deps})
#
add_dependencies(final_dygraph_node eager_final_state_codegen)
cc_library
(
final_dygraph_node SRCS nodes.cc DEPS
${
eager_deps
}
)
add_dependencies
(
final_dygraph_node eager_final_state_codegen
)
paddle/fluid/eager/api/generated/eager_generated/forwards/CMakeLists.txt
浏览文件 @
62b15566
cc_library
(
eager_scale SRCS scale.cc DEPS pten_api pten autograd_meta scale_node
)
#
cc_library(final_dygraph_function SRCS dygraph_functions.cc DEPS ${eager_deps})
#
add_dependencies(final_dygraph_function eager_final_state_codegen)
cc_library
(
final_dygraph_function SRCS dygraph_functions.cc DEPS
${
eager_deps
}
)
add_dependencies
(
final_dygraph_function eager_final_state_codegen
)
paddle/fluid/eager/auto_code_generator/CMakeLists.txt
浏览文件 @
62b15566
#
add_subdirectory(final_state_generator)
add_subdirectory
(
final_state_generator
)
set
(
EAGER_GENERETOR_DEPS
${
GLOB_OP_LIB
}
${
GLOB_OPERATOR_DEPS
}
pybind proto_desc executor layer tracer engine imperative_profiler imperative_flag
)
...
...
paddle/fluid/eager/auto_code_generator/final_state_generator/CMakeLists.txt
浏览文件 @
62b15566
...
...
@@ -24,3 +24,13 @@ add_custom_target(eager_final_state_codegen
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
tmp_nodes_h_path
}
${
nodes_h_path
}
VERBATIM
)
set
(
tmp_python_c_output_path
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h"
)
set
(
python_c_output_path
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/pybind/eager_final_state_op_function_impl.h"
)
add_custom_target
(
eager_final_state_python_c_codegen
COMMAND
"
${
PYTHON_EXECUTABLE
}
"
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py"
"--api_yaml_path=
${
api_yaml_path
}
"
"--output_path=
${
tmp_python_c_output_path
}
"
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
tmp_python_c_output_path
}
${
python_c_output_path
}
VERBATIM
)
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
浏览文件 @
62b15566
...
...
@@ -82,6 +82,14 @@ def RemoveConstAndReference(string):
return
ret
def
GetGradNodeName
(
string
):
return
f
"FinalGradNode
{
string
}
"
def
GetForwardFunctionName
(
string
):
return
f
"
{
string
}
_final_state_dygraph_function"
def
GetAutoGradMetaName
(
string
):
return
f
"
{
string
}
_autograd_meta"
...
...
@@ -145,13 +153,13 @@ def ParseYamlArgs(string):
def
ParseYamlReturns
(
string
):
# Example: Tensor, Tensor
# list = [ [ret_type, orig_position], ...]
# list = [ [
"",
ret_type, orig_position], ...]
returns_list
=
[]
returns
=
[
x
.
strip
()
for
x
in
string
.
strip
().
split
(
","
)]
for
i
in
range
(
len
(
returns
)):
ret
=
returns
[
i
]
returns_list
.
append
([
ret
,
i
])
returns_list
.
append
([
""
,
ret
,
i
])
return
returns_list
...
...
@@ -260,8 +268,8 @@ def ForwardsValidationCheck(forward_inputs_list, forward_attrs_list,
assert
orig_attr_pos
==
forward_attr_pos
for
i
in
range
(
len
(
forward_returns_list
)):
orig_return_type
=
orig_forward_returns_list
[
i
][
0
]
orig_return_pos
=
orig_forward_returns_list
[
i
][
1
]
orig_return_type
=
orig_forward_returns_list
[
i
][
1
]
orig_return_pos
=
orig_forward_returns_list
[
i
][
2
]
forward_return_type
=
forward_returns_list
[
i
][
1
]
forward_return_pos
=
forward_returns_list
[
i
][
2
]
...
...
@@ -452,13 +460,14 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
RemoveConstAndReference
(
atype
),
saved_attr_name
,
default_val
)
# End: SetAttributes & Attribute Members
grad_node_name
=
GetGradNodeName
(
fwd_api_name
)
NODE_DECLARATION_TEMPLATE
=
"""
class
GradNode
{} : public egr::GradNodeBase {{
class {} : public egr::GradNodeBase {{
public:
GradNode
{}() : egr::GradNodeBase() {{}}
GradNode
{}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
{}() : egr::GradNodeBase() {{}}
{}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}}
~
GradNode
{}() override = default;
~{}() override = default;
virtual std::vector<std::vector<egr::EagerTensor>> operator()(
const std::vector<std::vector<egr::EagerTensor>>& grads) override;
...
...
@@ -476,7 +485,7 @@ class GradNode{} : public egr::GradNodeBase {{
}};
"""
node_declaration_str
=
NODE_DECLARATION_TEMPLATE
.
format
(
forward_op_name
,
forward_op_name
,
forward_op_name
,
forward_op
_name
,
grad_node_name
,
grad_node_name
,
grad_node_name
,
grad_node
_name
,
set_tensor_wrapper_methods_str
,
set_attribute_methods_str
,
tensor_wrapper_members_str
,
attribute_members_str
)
...
...
@@ -503,8 +512,13 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
grad_api_args
[
grad_api_position
]
=
f
"egr::EagerUtils::SyncToPtenTensors( egr::EagerUtils::RecoverTensorWrapper(&this->
{
tensor_wrapper_name
}
, nullptr) )"
for
_
,
(
_
,
fwd_position
,
for
_
,
(
ttype
,
fwd_position
,
grad_api_position
)
in
backward_grad_input_map
.
items
():
if
IsPlainTensorType
(
ttype
):
grad_api_args
[
grad_api_position
]
=
f
"egr::EagerUtils::SyncToPtenTensors( grads[
{
fwd_position
}
][0] )"
else
:
assert
IsVectorTensorType
(
ttype
)
grad_api_args
[
grad_api_position
]
=
f
"egr::EagerUtils::SyncToPtenTensors( grads[
{
fwd_position
}
] )"
...
...
@@ -531,8 +545,9 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
returns_str
+=
f
"returns[
{
fwd_position
}
] = egr::EagerUtils::CreateEagerTensorFromTensor( grad_api_returns[
{
grad_api_position
}
] );
\n
"
returns_str
+=
f
"return returns;
\n
"
grad_node_name
=
GetGradNodeName
(
fwd_api_name
)
FUNCTION_TEMPLATE
=
"""
std::vector<std::vector<egr::EagerTensor>>
GradNode
{}::operator()(const std::vector<std::vector<egr::EagerTensor>>& grads) {{
std::vector<std::vector<egr::EagerTensor>> {}::operator()(const std::vector<std::vector<egr::EagerTensor>>& grads) {{
// Call grad_api function
auto grad_api_returns = paddle::experimental::{}({});
{}
...
...
@@ -540,7 +555,7 @@ std::vector<std::vector<egr::EagerTensor>> GradNode{}::operator()(const std::vec
"""
node_definition_str
=
FUNCTION_TEMPLATE
.
format
(
fwd_api
_name
,
bwd_api_name
,
grad_api_args_str
,
returns_str
)
grad_node
_name
,
bwd_api_name
,
grad_api_args_str
,
returns_str
)
return
node_definition_str
...
...
@@ -610,7 +625,8 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
# Node Construction
num_bwd_inputs
=
len
(
backward_grad_input_map
.
keys
())
num_bwd_outputs
=
len
(
backward_grad_output_map
.
keys
())
node_construction_str
=
f
" auto grad_node = std::make_shared<GradNode
{
fwd_api_name
}
>(
{
num_bwd_inputs
}
,
{
num_bwd_outputs
}
);"
grad_node_name
=
GetGradNodeName
(
fwd_api_name
)
node_construction_str
=
f
" auto grad_node = std::make_shared<
{
grad_node_name
}
>(
{
num_bwd_inputs
}
,
{
num_bwd_outputs
}
);"
# SetAttributes
set_attributes_list
=
[]
...
...
@@ -786,7 +802,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
backward_grad_output_map
,
backward_attrs_list
)
FORWARD_FUNCTION_TEMPLATE
=
"""
{} {}
_dygraph_function
({}) {{
{} {}({}) {{
// Forward API Call
{}
...
...
@@ -799,15 +815,34 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
}}
"""
forward_function_name
=
GetForwardFunctionName
(
fwd_api_name
)
forward_function_str
=
FORWARD_FUNCTION_TEMPLATE
.
format
(
returns_type_str
,
fwd_api_name
,
inputs_args_str
,
forward_call_str
,
returns_str
,
node_creation_str
)
forward_function_declaration_str
=
f
"
{
returns_type_str
}
{
fwd_api_name
}
_dygraph_function(
{
inputs_args_str
}
);"
returns_type_str
,
forward_function_name
,
inputs_args_str
,
forward_call_str
,
returns_str
,
node_creation_str
)
forward_function_declaration_str
=
f
"
{
returns_type_str
}
{
forward_function_name
}
(
{
inputs_args_str
}
);"
return
forward_function_str
,
forward_function_declaration_str
def
FakeMatmulGradAPI
():
fake_matmul_grad_str
=
"""
namespace paddle {
namespace experimental {
std::vector<std::vector<Tensor>> matmul_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
bool transpose_x,
bool transpose_y) {
std::vector<std::vector<Tensor>> ret;
return ret;
}
}
}
"""
return
fake_matmul_grad_str
def
GenerateNodeCCFile
(
filepath
,
node_definition_str
):
file_contents
=
"""
#include "glog/logging.h"
...
...
@@ -819,6 +854,7 @@ def GenerateNodeCCFile(filepath, node_definition_str):
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
"""
file_contents
+=
FakeMatmulGradAPI
()
file_contents
+=
node_definition_str
with
open
(
filepath
,
'a'
)
as
f
:
f
.
write
(
file_contents
)
...
...
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
0 → 100644
浏览文件 @
62b15566
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
argparse
from
eager_gen
import
ReadFwdFile
,
GetForwardFunctionName
,
ParseYamlForward
,
DetermineForwardPositionMap
atype_to_parsing_function
=
{
"bool"
:
"CastPyArg2Boolean"
,
"int"
:
"CastPyArg2Int"
,
"long"
:
"CastPyArg2Long"
,
"float"
:
"CastPyArg2Float"
,
"string"
:
"CastPyArg2String"
,
"bool[]"
:
"CastPyArg2Booleans"
,
"int[]"
:
"CastPyArg2Ints"
,
"long[]"
:
"CastPyArg2Longs"
,
"float[]"
:
"CastPyArg2Floats"
,
"double[]"
:
"CastPyArg2Float64s"
,
"string[]"
:
"CastPyArg2Strings"
}
atype_to_cxx_type
=
{
"bool"
:
"bool"
,
"int"
:
"int"
,
"long"
:
"long"
,
"float"
:
"float"
,
"string"
:
"std::string"
,
"bool[]"
:
"std::vector<bool>"
,
"int[]"
:
"std::vector<int>"
,
"long[]"
:
"std::vector<long>"
,
"float[]"
:
"std::vector<float>"
,
"double[]"
:
"std::vector<double>"
,
"string[]"
:
"std::vector<std::string>"
}
def
ParseArguments
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Eager Code Generator Args Parser'
)
parser
.
add_argument
(
'--api_yaml_path'
,
type
=
str
)
parser
.
add_argument
(
'--output_path'
,
type
=
str
)
args
=
parser
.
parse_args
()
return
args
def
GetCxxType
(
atype
):
if
atype
not
in
atype_to_cxx_type
.
keys
():
assert
False
return
atype_to_cxx_type
[
atype
]
def
FindParsingFunctionFromAttributeType
(
atype
):
if
atype
not
in
atype_to_parsing_function
.
keys
():
assert
False
return
atype_to_parsing_function
[
atype
]
def
GeneratePythonCFunction
(
fwd_api_name
,
forward_inputs_position_map
,
forward_attrs_list
,
forward_outputs_position_map
):
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
# forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# Get EagerTensor from args
# Get dygraph function call args
num_args
=
len
(
forward_inputs_position_map
.
keys
())
+
len
(
forward_attrs_list
)
num_input_tensors
=
len
(
forward_inputs_position_map
.
keys
())
dygraph_function_call_list
=
[
""
for
i
in
range
(
num_args
)]
get_eager_tensor_str
=
""
for
name
,
(
ttype
,
pos
)
in
forward_inputs_position_map
.
items
():
get_eager_tensor_str
+=
f
" auto&
{
name
}
= GetEagerTensorFromArgs(
\"
{
fwd_api_name
}
\"
,
\"
{
name
}
\"
, args,
{
pos
}
, false);
\n
"
dygraph_function_call_list
[
pos
]
=
f
"
{
name
}
"
parse_attributes_str
=
" paddle::framework::AttributeMap attrs;
\n
"
# Get Attributes
for
name
,
atype
,
_
,
pos
in
forward_attrs_list
:
parsing_function
=
FindParsingFunctionFromAttributeType
(
atype
)
cxx_type
=
GetCxxType
(
atype
)
key
=
f
"
{
name
}
"
parse_attributes_str
+=
f
" PyObject*
{
name
}
_obj = PyTuple_GET_ITEM(args,
{
pos
}
);
\n
"
parse_attributes_str
+=
f
"
{
cxx_type
}
{
name
}
=
{
parsing_function
}
(
{
name
}
_obj,
\"
{
fwd_api_name
}
\"
,
{
pos
}
);
\n
"
dygraph_function_call_list
[
pos
]
=
f
"
{
name
}
"
dygraph_function_call_str
=
","
.
join
(
dygraph_function_call_list
)
PYTHON_C_FUNCTION_TEMPLATE
=
"""
static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObject *kwargs)
{{
PyThreadState *tstate = nullptr;
try
{{
// Get EagerTensors from args
{}
// Parse Attributes
{}
tstate = PyEval_SaveThread();
auto out = {}({});
PyEval_RestoreThread(tstate);
tstate = nullptr;
return ToPyObject(out);
}}
catch(...) {{
if (tstate) {{
PyEval_RestoreThread(tstate);
}}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}}
}}
"""
python_c_function_str
=
PYTHON_C_FUNCTION_TEMPLATE
.
format
(
fwd_api_name
,
get_eager_tensor_str
,
parse_attributes_str
,
GetForwardFunctionName
(
fwd_api_name
),
dygraph_function_call_str
)
python_c_function_reg_str
=
f
"{{
\"
final_state_
{
fwd_api_name
}
\"
, (PyCFunction)(void(*)(void))eager_final_state_api_
{
fwd_api_name
}
, METH_VARARGS | METH_KEYWORDS,
\"
C++ interface function for
{
fwd_api_name
}
in dygraph.
\"
}}"
return
python_c_function_str
,
python_c_function_reg_str
def
GeneratePythonCWrappers
(
python_c_function_str
,
python_c_function_reg_str
):
PYTHON_C_WRAPPER_TEMPLATE
=
"""
#pragma once
#include "pybind11/detail/common.h"
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/pybind/exception.h"
#include <Python.h>
namespace paddle {{
namespace pybind {{
{}
static PyMethodDef EagerFinalStateMethods[] = {{
{}
}};
}} // namespace pybind
}} // namespace paddle
"""
python_c_str
=
PYTHON_C_WRAPPER_TEMPLATE
.
format
(
python_c_function_str
,
python_c_function_reg_str
)
return
python_c_str
def
GeneratePythonCFile
(
filepath
,
python_c_str
):
with
open
(
filepath
,
'a'
)
as
f
:
f
.
write
(
python_c_str
)
if
__name__
==
"__main__"
:
args
=
ParseArguments
()
api_yaml_path
=
args
.
api_yaml_path
fwd_api_list
=
ReadFwdFile
(
api_yaml_path
)
python_c_function_list
=
[]
python_c_function_reg_list
=
[]
for
fwd_api
in
fwd_api_list
:
# We only generate Ops with grad
if
'backward'
not
in
fwd_api
.
keys
():
continue
assert
'api'
in
fwd_api
.
keys
()
assert
'args'
in
fwd_api
.
keys
()
assert
'output'
in
fwd_api
.
keys
()
assert
'backward'
in
fwd_api
.
keys
()
fwd_api_name
=
fwd_api
[
'api'
]
fwd_args_str
=
fwd_api
[
'args'
]
fwd_returns_str
=
fwd_api
[
'output'
]
# Collect Original Forward Inputs/Outputs and then perform validation checks
forward_inputs_list
,
forward_attrs_list
,
forward_returns_list
=
ParseYamlForward
(
fwd_args_str
,
fwd_returns_str
)
print
(
"Parsed Original Forward Inputs List: "
,
forward_inputs_list
)
print
(
"Prased Original Forward Attrs List: "
,
forward_attrs_list
)
print
(
"Parsed Original Forward Returns List: "
,
forward_returns_list
)
forward_inputs_position_map
,
forward_outputs_position_map
=
DetermineForwardPositionMap
(
forward_inputs_list
,
forward_returns_list
)
print
(
"Generated Forward Input Position Map: "
,
forward_inputs_position_map
)
print
(
"Generated Forward Output Position Map: "
,
forward_outputs_position_map
)
python_c_function_str
,
python_c_function_reg_str
=
GeneratePythonCFunction
(
fwd_api_name
,
forward_inputs_position_map
,
forward_attrs_list
,
forward_outputs_position_map
)
python_c_function_list
.
append
(
python_c_function_str
)
python_c_function_reg_list
.
append
(
python_c_function_reg_str
)
print
(
"Generated Python-C Function: "
,
python_c_function_str
)
python_c_functions_str
=
"
\n
"
.
join
(
python_c_function_list
)
python_c_functions_reg_str
=
",
\n
"
.
join
(
python_c_function_reg_list
)
python_c_str
=
GeneratePythonCWrappers
(
python_c_functions_str
,
python_c_functions_reg_str
)
print
(
"Generated Python-C Codes: "
,
python_c_str
)
output_path
=
args
.
output_path
for
path
in
[
output_path
]:
if
os
.
path
.
exists
(
path
):
os
.
remove
(
path
)
GeneratePythonCFile
(
output_path
,
python_c_str
)
paddle/fluid/eager/utils.cc
浏览文件 @
62b15566
...
...
@@ -288,7 +288,9 @@ void EagerUtils::CheckAndRetainGrad(
paddle
::
experimental
::
Tensor
EagerUtils
::
SyncToPtenTensors
(
const
egr
::
EagerTensor
&
tensor
)
{
if
(
!
tensor
.
initialized
())
{
const_cast
<
EagerTensor
*>
(
&
tensor
)
->
SyncToTensor
();
}
return
*
tensor
.
Tensor
().
get
();
}
...
...
@@ -298,7 +300,9 @@ std::vector<paddle::experimental::Tensor> EagerUtils::SyncToPtenTensors(
size_t
num
=
tensors
.
size
();
res
.
reserve
(
num
);
for
(
size_t
i
=
0
;
i
<
num
;
i
++
)
{
if
(
!
tensors
[
i
].
initialized
())
{
const_cast
<
EagerTensor
*>
(
&
(
tensors
[
i
]))
->
SyncToTensor
();
}
res
.
push_back
(
*
tensors
[
i
].
Tensor
().
get
());
}
return
res
;
...
...
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
62b15566
...
...
@@ -151,7 +151,7 @@ if(WITH_PYTHON)
set
(
tmp_eager_impl_file
${
eager_impl_file
}
.tmp
)
set
(
OP_IMPL_DEPS op_function_generator
)
set
(
EAGER_OP_IMPL_DEPS eager_op_function_generator
)
set
(
EAGER_OP_IMPL_DEPS eager_op_function_generator
eager_final_state_python_c_codegen
)
if
(
WIN32
)
if
(
"
${
CMAKE_GENERATOR
}
"
STREQUAL
"Ninja"
)
...
...
@@ -275,7 +275,7 @@ if(WITH_PYTHON)
if
(
NOT ON_INFER
)
cc_library
(
paddle_eager
SRCS eager.cc eager_functions.cc eager_method.cc eager_properties.cc eager_utils.cc
DEPS eager_api autograd_meta backward grad_node_info pten op_function_common dygraph_function dygraph_node accumulation_node global_utils utils python
)
DEPS eager_api autograd_meta backward grad_node_info pten op_function_common
final_dygraph_function final_dygraph_node
dygraph_function dygraph_node accumulation_node global_utils utils python
)
add_dependencies
(
paddle_eager eager_codegen
)
add_dependencies
(
paddle_eager eager_op_function_generator_cmd
)
list
(
APPEND PYBIND_DEPS paddle_eager
)
...
...
paddle/fluid/pybind/eager_op_function_generator.cc
浏览文件 @
62b15566
...
...
@@ -393,6 +393,7 @@ int main(int argc, char* argv[]) {
std
::
vector
<
std
::
string
>
headers
{
"
\"
pybind11/detail/common.h
\"
"
,
"
\"
paddle/fluid/pybind/eager_final_state_op_function_impl.h
\"
"
,
"
\"
paddle/fluid/pybind/op_function_common.h
\"
"
,
"
\"
paddle/fluid/eager/api/generated/fluid_generated/"
"dygraph_forward_api.h
\"
"
,
...
...
@@ -441,6 +442,10 @@ int main(int argc, char* argv[]) {
<<
" PADDLE_THROW(platform::errors::Fatal (
\"
Add functions to "
"core.eager.ops failed!
\"
));
\n
"
<<
" }
\n\n
"
<<
" if (PyModule_AddFunctions(m.ptr(), EagerFinalStateMethods) < 0) {
\n
"
<<
" PADDLE_THROW(platform::errors::Fatal (
\"
Add functions to "
"core.eager.ops failed!
\"
));
\n
"
<<
" }
\n\n
"
<<
"}
\n\n
"
<<
"} // namespace pybind
\n
"
<<
"} // namespace paddle
\n
"
;
...
...
paddle/fluid/pybind/op_function_common.cc
浏览文件 @
62b15566
...
...
@@ -100,17 +100,15 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
bool
PyObject_CheckString
(
PyObject
*
obj
)
{
return
PyUnicode_Check
(
obj
);
}
void
CastPyArg2AttrBoolean
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
bool
CastPyArg2Boolean
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
obj
==
Py_None
)
{
attrs
[
key
]
=
false
;
// To be compatible with QA integration testing. Some
return
false
;
// To be compatible with QA integration testing. Some
// test case pass in None.
}
else
if
(
obj
==
Py_True
)
{
attrs
[
key
]
=
true
;
return
true
;
}
else
if
(
obj
==
Py_False
)
{
attrs
[
key
]
=
false
;
return
false
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -118,14 +116,20 @@ void CastPyArg2AttrBoolean(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
false
;
}
void
CastPyArg2Attr
Int
(
PyObject
*
obj
,
void
CastPyArg2Attr
Boolean
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Boolean
(
obj
,
op_type
,
arg_pos
);
}
int
CastPyArg2Int
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyObject_CheckLongOrToLong
(
&
obj
))
{
attrs
[
key
]
=
(
int
)
PyLong_AsLong
(
obj
);
// NOLINT
return
(
int
)
PyLong_AsLong
(
obj
);
// NOLINT
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -133,14 +137,21 @@ void CastPyArg2AttrInt(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
0
;
}
void
CastPyArg2Attr
Long
(
PyObject
*
obj
,
void
CastPyArg2Attr
Int
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Int
(
obj
,
op_type
,
arg_pos
);
}
int64_t
CastPyArg2Long
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyObject_CheckLongOrToLong
(
&
obj
))
{
attrs
[
key
]
=
(
int64_t
)
PyLong_AsLong
(
obj
);
// NOLINT
return
(
int64_t
)
PyLong_AsLong
(
obj
);
// NOLINT
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -148,14 +159,21 @@ void CastPyArg2AttrLong(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
0
;
}
void
CastPyArg2Attr
Float
(
PyObject
*
obj
,
void
CastPyArg2Attr
Long
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Long
(
obj
,
op_type
,
arg_pos
);
}
float
CastPyArg2Float
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyObject_CheckFloatOrToFloat
(
&
obj
))
{
attrs
[
key
]
=
(
float
)
PyFloat_AsDouble
(
obj
);
// NOLINT
return
(
float
)
PyFloat_AsDouble
(
obj
);
// NOLINT
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -163,17 +181,24 @@ void CastPyArg2AttrFloat(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
0.0
;
}
void
CastPyArg2Attr
String
(
PyObject
*
obj
,
void
CastPyArg2Attr
Float
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Float
(
obj
,
op_type
,
arg_pos
);
}
std
::
string
CastPyArg2String
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyObject_CheckString
(
obj
))
{
Py_ssize_t
size
;
const
char
*
data
;
data
=
PyUnicode_AsUTF8AndSize
(
obj
,
&
size
);
attrs
[
key
]
=
std
::
string
(
data
,
(
size_t
)
size
);
// NOLINT
return
std
::
string
(
data
,
(
size_t
)
size
);
// NOLINT
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -181,16 +206,23 @@ void CastPyArg2AttrString(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
""
;
}
void
CastPyArg2Attr
Booleans
(
PyObject
*
obj
,
void
CastPyArg2Attr
String
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2String
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
bool
>
CastPyArg2Booleans
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
bool
>
value
;
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
bool
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckBool
(
&
item
))
{
...
...
@@ -204,11 +236,9 @@ void CastPyArg2AttrBooleans(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
bool
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckBool
(
&
item
))
{
...
...
@@ -222,7 +252,6 @@ void CastPyArg2AttrBooleans(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -230,16 +259,23 @@ void CastPyArg2AttrBooleans(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
value
;
}
void
CastPyArg2Attr
Int
s
(
PyObject
*
obj
,
void
CastPyArg2Attr
Boolean
s
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Booleans
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
int
>
CastPyArg2Ints
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
int
>
value
;
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
...
@@ -253,11 +289,9 @@ void CastPyArg2AttrInts(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
...
@@ -271,11 +305,9 @@ void CastPyArg2AttrInts(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PySequence_Check
(
obj
))
{
Py_ssize_t
len
=
PySequence_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PySequence_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
...
@@ -289,7 +321,6 @@ void CastPyArg2AttrInts(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -297,16 +328,23 @@ void CastPyArg2AttrInts(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
value
;
}
void
CastPyArg2Attr
Long
s
(
PyObject
*
obj
,
void
CastPyArg2Attr
Int
s
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Ints
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
int64_t
>
CastPyArg2Longs
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
int64_t
>
value
;
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int64_t
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
...
@@ -320,11 +358,9 @@ void CastPyArg2AttrLongs(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int64_t
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
...
@@ -338,11 +374,9 @@ void CastPyArg2AttrLongs(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PySequence_Check
(
obj
))
{
Py_ssize_t
len
=
PySequence_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int64_t
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PySequence_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
...
@@ -356,7 +390,6 @@ void CastPyArg2AttrLongs(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -364,16 +397,23 @@ void CastPyArg2AttrLongs(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
value
;
}
void
CastPyArg2Attr
Float
s
(
PyObject
*
obj
,
void
CastPyArg2Attr
Long
s
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Longs
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
float
>
CastPyArg2Floats
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
float
>
value
;
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
float
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
...
@@ -387,11 +427,9 @@ void CastPyArg2AttrFloats(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
float
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
...
@@ -405,11 +443,9 @@ void CastPyArg2AttrFloats(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PySequence_Check
(
obj
))
{
Py_ssize_t
len
=
PySequence_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
float
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PySequence_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
...
@@ -423,7 +459,6 @@ void CastPyArg2AttrFloats(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -431,16 +466,24 @@ void CastPyArg2AttrFloats(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
value
;
}
void
CastPyArg2AttrFloat
64
s
(
PyObject
*
obj
,
void
CastPyArg2AttrFloats
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Floats
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
double
>
CastPyArg2Float64s
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
double
>
value
;
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
double
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
...
@@ -454,11 +497,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
double
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
...
@@ -472,11 +513,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PySequence_Check
(
obj
))
{
Py_ssize_t
len
=
PySequence_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
double
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PySequence_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
...
@@ -490,7 +529,6 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -498,16 +536,24 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
value
;
}
void
CastPyArg2Attr
String
s
(
PyObject
*
obj
,
void
CastPyArg2Attr
Float64
s
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Float64s
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
std
::
string
>
CastPyArg2Strings
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
std
::
string
>
value
;
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
std
::
string
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckString
(
item
))
{
...
...
@@ -524,11 +570,9 @@ void CastPyArg2AttrStrings(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
std
::
string
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckString
(
item
))
{
...
...
@@ -545,7 +589,6 @@ void CastPyArg2AttrStrings(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -553,6 +596,15 @@ void CastPyArg2AttrStrings(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
value
;
}
void
CastPyArg2AttrStrings
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Strings
(
obj
,
op_type
,
arg_pos
);
}
void
CastPyArg2AttrBlock
(
PyObject
*
obj
,
...
...
paddle/fluid/pybind/op_function_common.h
浏览文件 @
62b15566
...
...
@@ -43,6 +43,30 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj);
bool
PyObject_CheckString
(
PyObject
*
obj
);
bool
CastPyArg2Boolean
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
int
CastPyArg2Int
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
int64_t
CastPyArg2Long
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
float
CastPyArg2Float
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
string
CastPyArg2String
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
bool
>
CastPyArg2Booleans
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
int
>
CastPyArg2Ints
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
int64_t
>
CastPyArg2Longs
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
float
>
CastPyArg2Floats
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
double
>
CastPyArg2Float64s
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
std
::
string
>
CastPyArg2Strings
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
void
CastPyArg2AttrBoolean
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录