Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
62b15566
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
62b15566
编写于
1月 26, 2022
作者:
J
jim19930609
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added python-c code generation for final state Eager Dygraph
上级
ca743508
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
473 addition
and
112 deletion
+473
-112
paddle/fluid/eager/api/generated/eager_generated/backwards/CMakeLists.txt
...er/api/generated/eager_generated/backwards/CMakeLists.txt
+2
-2
paddle/fluid/eager/api/generated/eager_generated/forwards/CMakeLists.txt
...ger/api/generated/eager_generated/forwards/CMakeLists.txt
+2
-2
paddle/fluid/eager/auto_code_generator/CMakeLists.txt
paddle/fluid/eager/auto_code_generator/CMakeLists.txt
+1
-1
paddle/fluid/eager/auto_code_generator/final_state_generator/CMakeLists.txt
.../auto_code_generator/final_state_generator/CMakeLists.txt
+10
-0
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
...er/auto_code_generator/final_state_generator/eager_gen.py
+56
-20
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
...auto_code_generator/final_state_generator/python_c_gen.py
+230
-0
paddle/fluid/eager/utils.cc
paddle/fluid/eager/utils.cc
+6
-2
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+2
-2
paddle/fluid/pybind/eager_op_function_generator.cc
paddle/fluid/pybind/eager_op_function_generator.cc
+5
-0
paddle/fluid/pybind/op_function_common.cc
paddle/fluid/pybind/op_function_common.cc
+135
-83
paddle/fluid/pybind/op_function_common.h
paddle/fluid/pybind/op_function_common.h
+24
-0
未找到文件。
paddle/fluid/eager/api/generated/eager_generated/backwards/CMakeLists.txt
浏览文件 @
62b15566
cc_library
(
scale_node SRCS scale_node.cc DEPS global_utils pten pten_api grad_node_info
)
#
cc_library(final_dygraph_node SRCS nodes.cc DEPS ${eager_deps})
#
add_dependencies(final_dygraph_node eager_final_state_codegen)
cc_library
(
final_dygraph_node SRCS nodes.cc DEPS
${
eager_deps
}
)
add_dependencies
(
final_dygraph_node eager_final_state_codegen
)
paddle/fluid/eager/api/generated/eager_generated/forwards/CMakeLists.txt
浏览文件 @
62b15566
cc_library
(
eager_scale SRCS scale.cc DEPS pten_api pten autograd_meta scale_node
)
#
cc_library(final_dygraph_function SRCS dygraph_functions.cc DEPS ${eager_deps})
#
add_dependencies(final_dygraph_function eager_final_state_codegen)
cc_library
(
final_dygraph_function SRCS dygraph_functions.cc DEPS
${
eager_deps
}
)
add_dependencies
(
final_dygraph_function eager_final_state_codegen
)
paddle/fluid/eager/auto_code_generator/CMakeLists.txt
浏览文件 @
62b15566
#
add_subdirectory(final_state_generator)
add_subdirectory
(
final_state_generator
)
set
(
EAGER_GENERETOR_DEPS
${
GLOB_OP_LIB
}
${
GLOB_OPERATOR_DEPS
}
pybind proto_desc executor layer tracer engine imperative_profiler imperative_flag
)
...
...
paddle/fluid/eager/auto_code_generator/final_state_generator/CMakeLists.txt
浏览文件 @
62b15566
...
...
@@ -24,3 +24,13 @@ add_custom_target(eager_final_state_codegen
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
tmp_nodes_h_path
}
${
nodes_h_path
}
VERBATIM
)
set
(
tmp_python_c_output_path
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h"
)
set
(
python_c_output_path
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/pybind/eager_final_state_op_function_impl.h"
)
add_custom_target
(
eager_final_state_python_c_codegen
COMMAND
"
${
PYTHON_EXECUTABLE
}
"
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py"
"--api_yaml_path=
${
api_yaml_path
}
"
"--output_path=
${
tmp_python_c_output_path
}
"
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
tmp_python_c_output_path
}
${
python_c_output_path
}
VERBATIM
)
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
浏览文件 @
62b15566
...
...
@@ -82,6 +82,14 @@ def RemoveConstAndReference(string):
return
ret
def
GetGradNodeName
(
string
):
return
f
"FinalGradNode
{
string
}
"
def
GetForwardFunctionName
(
string
):
return
f
"
{
string
}
_final_state_dygraph_function"
def
GetAutoGradMetaName
(
string
):
return
f
"
{
string
}
_autograd_meta"
...
...
@@ -145,13 +153,13 @@ def ParseYamlArgs(string):
def
ParseYamlReturns
(
string
):
# Example: Tensor, Tensor
# list = [ [ret_type, orig_position], ...]
# list = [ [
"",
ret_type, orig_position], ...]
returns_list
=
[]
returns
=
[
x
.
strip
()
for
x
in
string
.
strip
().
split
(
","
)]
for
i
in
range
(
len
(
returns
)):
ret
=
returns
[
i
]
returns_list
.
append
([
ret
,
i
])
returns_list
.
append
([
""
,
ret
,
i
])
return
returns_list
...
...
@@ -260,8 +268,8 @@ def ForwardsValidationCheck(forward_inputs_list, forward_attrs_list,
assert
orig_attr_pos
==
forward_attr_pos
for
i
in
range
(
len
(
forward_returns_list
)):
orig_return_type
=
orig_forward_returns_list
[
i
][
0
]
orig_return_pos
=
orig_forward_returns_list
[
i
][
1
]
orig_return_type
=
orig_forward_returns_list
[
i
][
1
]
orig_return_pos
=
orig_forward_returns_list
[
i
][
2
]
forward_return_type
=
forward_returns_list
[
i
][
1
]
forward_return_pos
=
forward_returns_list
[
i
][
2
]
...
...
@@ -452,13 +460,14 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
RemoveConstAndReference
(
atype
),
saved_attr_name
,
default_val
)
# End: SetAttributes & Attribute Members
grad_node_name
=
GetGradNodeName
(
fwd_api_name
)
NODE_DECLARATION_TEMPLATE
=
"""
class
GradNode
{} : public egr::GradNodeBase {{
class {} : public egr::GradNodeBase {{
public:
GradNode
{}() : egr::GradNodeBase() {{}}
GradNode
{}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
{}() : egr::GradNodeBase() {{}}
{}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}}
~
GradNode
{}() override = default;
~{}() override = default;
virtual std::vector<std::vector<egr::EagerTensor>> operator()(
const std::vector<std::vector<egr::EagerTensor>>& grads) override;
...
...
@@ -476,7 +485,7 @@ class GradNode{} : public egr::GradNodeBase {{
}};
"""
node_declaration_str
=
NODE_DECLARATION_TEMPLATE
.
format
(
forward_op_name
,
forward_op_name
,
forward_op_name
,
forward_op
_name
,
grad_node_name
,
grad_node_name
,
grad_node_name
,
grad_node
_name
,
set_tensor_wrapper_methods_str
,
set_attribute_methods_str
,
tensor_wrapper_members_str
,
attribute_members_str
)
...
...
@@ -503,8 +512,13 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
grad_api_args
[
grad_api_position
]
=
f
"egr::EagerUtils::SyncToPtenTensors( egr::EagerUtils::RecoverTensorWrapper(&this->
{
tensor_wrapper_name
}
, nullptr) )"
for
_
,
(
_
,
fwd_position
,
for
_
,
(
ttype
,
fwd_position
,
grad_api_position
)
in
backward_grad_input_map
.
items
():
if
IsPlainTensorType
(
ttype
):
grad_api_args
[
grad_api_position
]
=
f
"egr::EagerUtils::SyncToPtenTensors( grads[
{
fwd_position
}
][0] )"
else
:
assert
IsVectorTensorType
(
ttype
)
grad_api_args
[
grad_api_position
]
=
f
"egr::EagerUtils::SyncToPtenTensors( grads[
{
fwd_position
}
] )"
...
...
@@ -531,8 +545,9 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
returns_str
+=
f
"returns[
{
fwd_position
}
] = egr::EagerUtils::CreateEagerTensorFromTensor( grad_api_returns[
{
grad_api_position
}
] );
\n
"
returns_str
+=
f
"return returns;
\n
"
grad_node_name
=
GetGradNodeName
(
fwd_api_name
)
FUNCTION_TEMPLATE
=
"""
std::vector<std::vector<egr::EagerTensor>>
GradNode
{}::operator()(const std::vector<std::vector<egr::EagerTensor>>& grads) {{
std::vector<std::vector<egr::EagerTensor>> {}::operator()(const std::vector<std::vector<egr::EagerTensor>>& grads) {{
// Call grad_api function
auto grad_api_returns = paddle::experimental::{}({});
{}
...
...
@@ -540,7 +555,7 @@ std::vector<std::vector<egr::EagerTensor>> GradNode{}::operator()(const std::vec
"""
node_definition_str
=
FUNCTION_TEMPLATE
.
format
(
fwd_api
_name
,
bwd_api_name
,
grad_api_args_str
,
returns_str
)
grad_node
_name
,
bwd_api_name
,
grad_api_args_str
,
returns_str
)
return
node_definition_str
...
...
@@ -610,7 +625,8 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
# Node Construction
num_bwd_inputs
=
len
(
backward_grad_input_map
.
keys
())
num_bwd_outputs
=
len
(
backward_grad_output_map
.
keys
())
node_construction_str
=
f
" auto grad_node = std::make_shared<GradNode
{
fwd_api_name
}
>(
{
num_bwd_inputs
}
,
{
num_bwd_outputs
}
);"
grad_node_name
=
GetGradNodeName
(
fwd_api_name
)
node_construction_str
=
f
" auto grad_node = std::make_shared<
{
grad_node_name
}
>(
{
num_bwd_inputs
}
,
{
num_bwd_outputs
}
);"
# SetAttributes
set_attributes_list
=
[]
...
...
@@ -786,7 +802,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
backward_grad_output_map
,
backward_attrs_list
)
FORWARD_FUNCTION_TEMPLATE
=
"""
{} {}
_dygraph_function
({}) {{
{} {}({}) {{
// Forward API Call
{}
...
...
@@ -799,15 +815,34 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
}}
"""
forward_function_name
=
GetForwardFunctionName
(
fwd_api_name
)
forward_function_str
=
FORWARD_FUNCTION_TEMPLATE
.
format
(
returns_type_str
,
fwd_api_name
,
inputs_args_str
,
forward_call_str
,
returns_str
,
node_creation_str
)
forward_function_declaration_str
=
f
"
{
returns_type_str
}
{
fwd_api_name
}
_dygraph_function(
{
inputs_args_str
}
);"
returns_type_str
,
forward_function_name
,
inputs_args_str
,
forward_call_str
,
returns_str
,
node_creation_str
)
forward_function_declaration_str
=
f
"
{
returns_type_str
}
{
forward_function_name
}
(
{
inputs_args_str
}
);"
return
forward_function_str
,
forward_function_declaration_str
def
FakeMatmulGradAPI
():
fake_matmul_grad_str
=
"""
namespace paddle {
namespace experimental {
std::vector<std::vector<Tensor>> matmul_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
bool transpose_x,
bool transpose_y) {
std::vector<std::vector<Tensor>> ret;
return ret;
}
}
}
"""
return
fake_matmul_grad_str
def
GenerateNodeCCFile
(
filepath
,
node_definition_str
):
file_contents
=
"""
#include "glog/logging.h"
...
...
@@ -819,6 +854,7 @@ def GenerateNodeCCFile(filepath, node_definition_str):
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
"""
file_contents
+=
FakeMatmulGradAPI
()
file_contents
+=
node_definition_str
with
open
(
filepath
,
'a'
)
as
f
:
f
.
write
(
file_contents
)
...
...
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
0 → 100644
浏览文件 @
62b15566
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
argparse
from
eager_gen
import
ReadFwdFile
,
GetForwardFunctionName
,
ParseYamlForward
,
DetermineForwardPositionMap
atype_to_parsing_function
=
{
"bool"
:
"CastPyArg2Boolean"
,
"int"
:
"CastPyArg2Int"
,
"long"
:
"CastPyArg2Long"
,
"float"
:
"CastPyArg2Float"
,
"string"
:
"CastPyArg2String"
,
"bool[]"
:
"CastPyArg2Booleans"
,
"int[]"
:
"CastPyArg2Ints"
,
"long[]"
:
"CastPyArg2Longs"
,
"float[]"
:
"CastPyArg2Floats"
,
"double[]"
:
"CastPyArg2Float64s"
,
"string[]"
:
"CastPyArg2Strings"
}
atype_to_cxx_type
=
{
"bool"
:
"bool"
,
"int"
:
"int"
,
"long"
:
"long"
,
"float"
:
"float"
,
"string"
:
"std::string"
,
"bool[]"
:
"std::vector<bool>"
,
"int[]"
:
"std::vector<int>"
,
"long[]"
:
"std::vector<long>"
,
"float[]"
:
"std::vector<float>"
,
"double[]"
:
"std::vector<double>"
,
"string[]"
:
"std::vector<std::string>"
}
def
ParseArguments
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Eager Code Generator Args Parser'
)
parser
.
add_argument
(
'--api_yaml_path'
,
type
=
str
)
parser
.
add_argument
(
'--output_path'
,
type
=
str
)
args
=
parser
.
parse_args
()
return
args
def
GetCxxType
(
atype
):
if
atype
not
in
atype_to_cxx_type
.
keys
():
assert
False
return
atype_to_cxx_type
[
atype
]
def
FindParsingFunctionFromAttributeType
(
atype
):
if
atype
not
in
atype_to_parsing_function
.
keys
():
assert
False
return
atype_to_parsing_function
[
atype
]
def
GeneratePythonCFunction
(
fwd_api_name
,
forward_inputs_position_map
,
forward_attrs_list
,
forward_outputs_position_map
):
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
# forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# Get EagerTensor from args
# Get dygraph function call args
num_args
=
len
(
forward_inputs_position_map
.
keys
())
+
len
(
forward_attrs_list
)
num_input_tensors
=
len
(
forward_inputs_position_map
.
keys
())
dygraph_function_call_list
=
[
""
for
i
in
range
(
num_args
)]
get_eager_tensor_str
=
""
for
name
,
(
ttype
,
pos
)
in
forward_inputs_position_map
.
items
():
get_eager_tensor_str
+=
f
" auto&
{
name
}
= GetEagerTensorFromArgs(
\"
{
fwd_api_name
}
\"
,
\"
{
name
}
\"
, args,
{
pos
}
, false);
\n
"
dygraph_function_call_list
[
pos
]
=
f
"
{
name
}
"
parse_attributes_str
=
" paddle::framework::AttributeMap attrs;
\n
"
# Get Attributes
for
name
,
atype
,
_
,
pos
in
forward_attrs_list
:
parsing_function
=
FindParsingFunctionFromAttributeType
(
atype
)
cxx_type
=
GetCxxType
(
atype
)
key
=
f
"
{
name
}
"
parse_attributes_str
+=
f
" PyObject*
{
name
}
_obj = PyTuple_GET_ITEM(args,
{
pos
}
);
\n
"
parse_attributes_str
+=
f
"
{
cxx_type
}
{
name
}
=
{
parsing_function
}
(
{
name
}
_obj,
\"
{
fwd_api_name
}
\"
,
{
pos
}
);
\n
"
dygraph_function_call_list
[
pos
]
=
f
"
{
name
}
"
dygraph_function_call_str
=
","
.
join
(
dygraph_function_call_list
)
PYTHON_C_FUNCTION_TEMPLATE
=
"""
static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObject *kwargs)
{{
PyThreadState *tstate = nullptr;
try
{{
// Get EagerTensors from args
{}
// Parse Attributes
{}
tstate = PyEval_SaveThread();
auto out = {}({});
PyEval_RestoreThread(tstate);
tstate = nullptr;
return ToPyObject(out);
}}
catch(...) {{
if (tstate) {{
PyEval_RestoreThread(tstate);
}}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}}
}}
"""
python_c_function_str
=
PYTHON_C_FUNCTION_TEMPLATE
.
format
(
fwd_api_name
,
get_eager_tensor_str
,
parse_attributes_str
,
GetForwardFunctionName
(
fwd_api_name
),
dygraph_function_call_str
)
python_c_function_reg_str
=
f
"{{
\"
final_state_
{
fwd_api_name
}
\"
, (PyCFunction)(void(*)(void))eager_final_state_api_
{
fwd_api_name
}
, METH_VARARGS | METH_KEYWORDS,
\"
C++ interface function for
{
fwd_api_name
}
in dygraph.
\"
}}"
return
python_c_function_str
,
python_c_function_reg_str
def
GeneratePythonCWrappers
(
python_c_function_str
,
python_c_function_reg_str
):
PYTHON_C_WRAPPER_TEMPLATE
=
"""
#pragma once
#include "pybind11/detail/common.h"
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/pybind/exception.h"
#include <Python.h>
namespace paddle {{
namespace pybind {{
{}
static PyMethodDef EagerFinalStateMethods[] = {{
{}
}};
}} // namespace pybind
}} // namespace paddle
"""
python_c_str
=
PYTHON_C_WRAPPER_TEMPLATE
.
format
(
python_c_function_str
,
python_c_function_reg_str
)
return
python_c_str
def
GeneratePythonCFile
(
filepath
,
python_c_str
):
with
open
(
filepath
,
'a'
)
as
f
:
f
.
write
(
python_c_str
)
if
__name__
==
"__main__"
:
args
=
ParseArguments
()
api_yaml_path
=
args
.
api_yaml_path
fwd_api_list
=
ReadFwdFile
(
api_yaml_path
)
python_c_function_list
=
[]
python_c_function_reg_list
=
[]
for
fwd_api
in
fwd_api_list
:
# We only generate Ops with grad
if
'backward'
not
in
fwd_api
.
keys
():
continue
assert
'api'
in
fwd_api
.
keys
()
assert
'args'
in
fwd_api
.
keys
()
assert
'output'
in
fwd_api
.
keys
()
assert
'backward'
in
fwd_api
.
keys
()
fwd_api_name
=
fwd_api
[
'api'
]
fwd_args_str
=
fwd_api
[
'args'
]
fwd_returns_str
=
fwd_api
[
'output'
]
# Collect Original Forward Inputs/Outputs and then perform validation checks
forward_inputs_list
,
forward_attrs_list
,
forward_returns_list
=
ParseYamlForward
(
fwd_args_str
,
fwd_returns_str
)
print
(
"Parsed Original Forward Inputs List: "
,
forward_inputs_list
)
print
(
"Prased Original Forward Attrs List: "
,
forward_attrs_list
)
print
(
"Parsed Original Forward Returns List: "
,
forward_returns_list
)
forward_inputs_position_map
,
forward_outputs_position_map
=
DetermineForwardPositionMap
(
forward_inputs_list
,
forward_returns_list
)
print
(
"Generated Forward Input Position Map: "
,
forward_inputs_position_map
)
print
(
"Generated Forward Output Position Map: "
,
forward_outputs_position_map
)
python_c_function_str
,
python_c_function_reg_str
=
GeneratePythonCFunction
(
fwd_api_name
,
forward_inputs_position_map
,
forward_attrs_list
,
forward_outputs_position_map
)
python_c_function_list
.
append
(
python_c_function_str
)
python_c_function_reg_list
.
append
(
python_c_function_reg_str
)
print
(
"Generated Python-C Function: "
,
python_c_function_str
)
python_c_functions_str
=
"
\n
"
.
join
(
python_c_function_list
)
python_c_functions_reg_str
=
",
\n
"
.
join
(
python_c_function_reg_list
)
python_c_str
=
GeneratePythonCWrappers
(
python_c_functions_str
,
python_c_functions_reg_str
)
print
(
"Generated Python-C Codes: "
,
python_c_str
)
output_path
=
args
.
output_path
for
path
in
[
output_path
]:
if
os
.
path
.
exists
(
path
):
os
.
remove
(
path
)
GeneratePythonCFile
(
output_path
,
python_c_str
)
paddle/fluid/eager/utils.cc
浏览文件 @
62b15566
...
...
@@ -288,7 +288,9 @@ void EagerUtils::CheckAndRetainGrad(
paddle
::
experimental
::
Tensor
EagerUtils
::
SyncToPtenTensors
(
const
egr
::
EagerTensor
&
tensor
)
{
if
(
!
tensor
.
initialized
())
{
const_cast
<
EagerTensor
*>
(
&
tensor
)
->
SyncToTensor
();
}
return
*
tensor
.
Tensor
().
get
();
}
...
...
@@ -298,7 +300,9 @@ std::vector<paddle::experimental::Tensor> EagerUtils::SyncToPtenTensors(
size_t
num
=
tensors
.
size
();
res
.
reserve
(
num
);
for
(
size_t
i
=
0
;
i
<
num
;
i
++
)
{
if
(
!
tensors
[
i
].
initialized
())
{
const_cast
<
EagerTensor
*>
(
&
(
tensors
[
i
]))
->
SyncToTensor
();
}
res
.
push_back
(
*
tensors
[
i
].
Tensor
().
get
());
}
return
res
;
...
...
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
62b15566
...
...
@@ -151,7 +151,7 @@ if(WITH_PYTHON)
set
(
tmp_eager_impl_file
${
eager_impl_file
}
.tmp
)
set
(
OP_IMPL_DEPS op_function_generator
)
set
(
EAGER_OP_IMPL_DEPS eager_op_function_generator
)
set
(
EAGER_OP_IMPL_DEPS eager_op_function_generator
eager_final_state_python_c_codegen
)
if
(
WIN32
)
if
(
"
${
CMAKE_GENERATOR
}
"
STREQUAL
"Ninja"
)
...
...
@@ -275,7 +275,7 @@ if(WITH_PYTHON)
if
(
NOT ON_INFER
)
cc_library
(
paddle_eager
SRCS eager.cc eager_functions.cc eager_method.cc eager_properties.cc eager_utils.cc
DEPS eager_api autograd_meta backward grad_node_info pten op_function_common dygraph_function dygraph_node accumulation_node global_utils utils python
)
DEPS eager_api autograd_meta backward grad_node_info pten op_function_common
final_dygraph_function final_dygraph_node
dygraph_function dygraph_node accumulation_node global_utils utils python
)
add_dependencies
(
paddle_eager eager_codegen
)
add_dependencies
(
paddle_eager eager_op_function_generator_cmd
)
list
(
APPEND PYBIND_DEPS paddle_eager
)
...
...
paddle/fluid/pybind/eager_op_function_generator.cc
浏览文件 @
62b15566
...
...
@@ -393,6 +393,7 @@ int main(int argc, char* argv[]) {
std
::
vector
<
std
::
string
>
headers
{
"
\"
pybind11/detail/common.h
\"
"
,
"
\"
paddle/fluid/pybind/eager_final_state_op_function_impl.h
\"
"
,
"
\"
paddle/fluid/pybind/op_function_common.h
\"
"
,
"
\"
paddle/fluid/eager/api/generated/fluid_generated/"
"dygraph_forward_api.h
\"
"
,
...
...
@@ -441,6 +442,10 @@ int main(int argc, char* argv[]) {
<<
" PADDLE_THROW(platform::errors::Fatal (
\"
Add functions to "
"core.eager.ops failed!
\"
));
\n
"
<<
" }
\n\n
"
<<
" if (PyModule_AddFunctions(m.ptr(), EagerFinalStateMethods) < 0) {
\n
"
<<
" PADDLE_THROW(platform::errors::Fatal (
\"
Add functions to "
"core.eager.ops failed!
\"
));
\n
"
<<
" }
\n\n
"
<<
"}
\n\n
"
<<
"} // namespace pybind
\n
"
<<
"} // namespace paddle
\n
"
;
...
...
paddle/fluid/pybind/op_function_common.cc
浏览文件 @
62b15566
...
...
@@ -100,17 +100,15 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
bool
PyObject_CheckString
(
PyObject
*
obj
)
{
return
PyUnicode_Check
(
obj
);
}
void
CastPyArg2AttrBoolean
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
bool
CastPyArg2Boolean
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
obj
==
Py_None
)
{
attrs
[
key
]
=
false
;
// To be compatible with QA integration testing. Some
return
false
;
// To be compatible with QA integration testing. Some
// test case pass in None.
}
else
if
(
obj
==
Py_True
)
{
attrs
[
key
]
=
true
;
return
true
;
}
else
if
(
obj
==
Py_False
)
{
attrs
[
key
]
=
false
;
return
false
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -118,14 +116,20 @@ void CastPyArg2AttrBoolean(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
false
;
}
void
CastPyArg2Attr
Int
(
PyObject
*
obj
,
void
CastPyArg2Attr
Boolean
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Boolean
(
obj
,
op_type
,
arg_pos
);
}
int
CastPyArg2Int
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyObject_CheckLongOrToLong
(
&
obj
))
{
attrs
[
key
]
=
(
int
)
PyLong_AsLong
(
obj
);
// NOLINT
return
(
int
)
PyLong_AsLong
(
obj
);
// NOLINT
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -133,14 +137,21 @@ void CastPyArg2AttrInt(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
0
;
}
void
CastPyArg2Attr
Long
(
PyObject
*
obj
,
void
CastPyArg2Attr
Int
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Int
(
obj
,
op_type
,
arg_pos
);
}
int64_t
CastPyArg2Long
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyObject_CheckLongOrToLong
(
&
obj
))
{
attrs
[
key
]
=
(
int64_t
)
PyLong_AsLong
(
obj
);
// NOLINT
return
(
int64_t
)
PyLong_AsLong
(
obj
);
// NOLINT
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -148,14 +159,21 @@ void CastPyArg2AttrLong(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
0
;
}
void
CastPyArg2Attr
Float
(
PyObject
*
obj
,
void
CastPyArg2Attr
Long
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Long
(
obj
,
op_type
,
arg_pos
);
}
float
CastPyArg2Float
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyObject_CheckFloatOrToFloat
(
&
obj
))
{
attrs
[
key
]
=
(
float
)
PyFloat_AsDouble
(
obj
);
// NOLINT
return
(
float
)
PyFloat_AsDouble
(
obj
);
// NOLINT
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -163,17 +181,24 @@ void CastPyArg2AttrFloat(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
0.0
;
}
void
CastPyArg2Attr
String
(
PyObject
*
obj
,
void
CastPyArg2Attr
Float
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Float
(
obj
,
op_type
,
arg_pos
);
}
std
::
string
CastPyArg2String
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyObject_CheckString
(
obj
))
{
Py_ssize_t
size
;
const
char
*
data
;
data
=
PyUnicode_AsUTF8AndSize
(
obj
,
&
size
);
attrs
[
key
]
=
std
::
string
(
data
,
(
size_t
)
size
);
// NOLINT
return
std
::
string
(
data
,
(
size_t
)
size
);
// NOLINT
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -181,16 +206,23 @@ void CastPyArg2AttrString(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
""
;
}
void
CastPyArg2Attr
Booleans
(
PyObject
*
obj
,
void
CastPyArg2Attr
String
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2String
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
bool
>
CastPyArg2Booleans
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
bool
>
value
;
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
bool
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckBool
(
&
item
))
{
...
...
@@ -204,11 +236,9 @@ void CastPyArg2AttrBooleans(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
bool
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckBool
(
&
item
))
{
...
...
@@ -222,7 +252,6 @@ void CastPyArg2AttrBooleans(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -230,16 +259,23 @@ void CastPyArg2AttrBooleans(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
value
;
}
void
CastPyArg2Attr
Int
s
(
PyObject
*
obj
,
void
CastPyArg2Attr
Boolean
s
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Booleans
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
int
>
CastPyArg2Ints
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
int
>
value
;
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
...
@@ -253,11 +289,9 @@ void CastPyArg2AttrInts(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
...
@@ -271,11 +305,9 @@ void CastPyArg2AttrInts(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PySequence_Check
(
obj
))
{
Py_ssize_t
len
=
PySequence_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PySequence_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
...
@@ -289,7 +321,6 @@ void CastPyArg2AttrInts(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -297,16 +328,23 @@ void CastPyArg2AttrInts(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
value
;
}
void
CastPyArg2Attr
Long
s
(
PyObject
*
obj
,
void
CastPyArg2Attr
Int
s
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Ints
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
int64_t
>
CastPyArg2Longs
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
int64_t
>
value
;
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int64_t
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
...
@@ -320,11 +358,9 @@ void CastPyArg2AttrLongs(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int64_t
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
...
@@ -338,11 +374,9 @@ void CastPyArg2AttrLongs(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PySequence_Check
(
obj
))
{
Py_ssize_t
len
=
PySequence_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
int64_t
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PySequence_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
...
@@ -356,7 +390,6 @@ void CastPyArg2AttrLongs(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -364,16 +397,23 @@ void CastPyArg2AttrLongs(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
value
;
}
void
CastPyArg2Attr
Float
s
(
PyObject
*
obj
,
void
CastPyArg2Attr
Long
s
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Longs
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
float
>
CastPyArg2Floats
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
float
>
value
;
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
float
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
...
@@ -387,11 +427,9 @@ void CastPyArg2AttrFloats(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
float
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
...
@@ -405,11 +443,9 @@ void CastPyArg2AttrFloats(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PySequence_Check
(
obj
))
{
Py_ssize_t
len
=
PySequence_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
float
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PySequence_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
...
@@ -423,7 +459,6 @@ void CastPyArg2AttrFloats(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -431,16 +466,24 @@ void CastPyArg2AttrFloats(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
value
;
}
void
CastPyArg2AttrFloat
64
s
(
PyObject
*
obj
,
void
CastPyArg2AttrFloats
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Floats
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
double
>
CastPyArg2Float64s
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
double
>
value
;
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
double
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
...
@@ -454,11 +497,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
double
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
...
@@ -472,11 +513,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PySequence_Check
(
obj
))
{
Py_ssize_t
len
=
PySequence_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
double
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PySequence_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
...
@@ -490,7 +529,6 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -498,16 +536,24 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
value
;
}
void
CastPyArg2Attr
String
s
(
PyObject
*
obj
,
void
CastPyArg2Attr
Float64
s
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Float64s
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
std
::
string
>
CastPyArg2Strings
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
std
::
string
>
value
;
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
std
::
string
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckString
(
item
))
{
...
...
@@ -524,11 +570,9 @@ void CastPyArg2AttrStrings(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
std
::
vector
<
std
::
string
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckString
(
item
))
{
...
...
@@ -545,7 +589,6 @@ void CastPyArg2AttrStrings(PyObject* obj,
i
));
}
}
attrs
[
key
]
=
value
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
...
...
@@ -553,6 +596,15 @@ void CastPyArg2AttrStrings(PyObject* obj,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
value
;
}
void
CastPyArg2AttrStrings
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Strings
(
obj
,
op_type
,
arg_pos
);
}
void
CastPyArg2AttrBlock
(
PyObject
*
obj
,
...
...
paddle/fluid/pybind/op_function_common.h
浏览文件 @
62b15566
...
...
@@ -43,6 +43,30 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj);
bool
PyObject_CheckString
(
PyObject
*
obj
);
bool
CastPyArg2Boolean
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
int
CastPyArg2Int
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
int64_t
CastPyArg2Long
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
float
CastPyArg2Float
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
string
CastPyArg2String
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
bool
>
CastPyArg2Booleans
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
int
>
CastPyArg2Ints
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
int64_t
>
CastPyArg2Longs
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
float
>
CastPyArg2Floats
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
double
>
CastPyArg2Float64s
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
std
::
string
>
CastPyArg2Strings
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
void
CastPyArg2AttrBoolean
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录