Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
62b15566
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
62b15566
编写于
1月 26, 2022
作者:
J
jim19930609
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added python-c code generation for final state Eager Dygraph
上级
ca743508
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
473 addition
and
112 deletion
+473
-112
paddle/fluid/eager/api/generated/eager_generated/backwards/CMakeLists.txt
...er/api/generated/eager_generated/backwards/CMakeLists.txt
+2
-2
paddle/fluid/eager/api/generated/eager_generated/forwards/CMakeLists.txt
...ger/api/generated/eager_generated/forwards/CMakeLists.txt
+2
-2
paddle/fluid/eager/auto_code_generator/CMakeLists.txt
paddle/fluid/eager/auto_code_generator/CMakeLists.txt
+1
-1
paddle/fluid/eager/auto_code_generator/final_state_generator/CMakeLists.txt
.../auto_code_generator/final_state_generator/CMakeLists.txt
+10
-0
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
...er/auto_code_generator/final_state_generator/eager_gen.py
+56
-20
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
...auto_code_generator/final_state_generator/python_c_gen.py
+230
-0
paddle/fluid/eager/utils.cc
paddle/fluid/eager/utils.cc
+6
-2
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+2
-2
paddle/fluid/pybind/eager_op_function_generator.cc
paddle/fluid/pybind/eager_op_function_generator.cc
+5
-0
paddle/fluid/pybind/op_function_common.cc
paddle/fluid/pybind/op_function_common.cc
+135
-83
paddle/fluid/pybind/op_function_common.h
paddle/fluid/pybind/op_function_common.h
+24
-0
未找到文件。
paddle/fluid/eager/api/generated/eager_generated/backwards/CMakeLists.txt
浏览文件 @
62b15566
cc_library
(
scale_node SRCS scale_node.cc DEPS global_utils pten pten_api grad_node_info
)
cc_library
(
scale_node SRCS scale_node.cc DEPS global_utils pten pten_api grad_node_info
)
#
cc_library(final_dygraph_node SRCS nodes.cc DEPS ${eager_deps})
cc_library
(
final_dygraph_node SRCS nodes.cc DEPS
${
eager_deps
}
)
#
add_dependencies(final_dygraph_node eager_final_state_codegen)
add_dependencies
(
final_dygraph_node eager_final_state_codegen
)
paddle/fluid/eager/api/generated/eager_generated/forwards/CMakeLists.txt
浏览文件 @
62b15566
cc_library
(
eager_scale SRCS scale.cc DEPS pten_api pten autograd_meta scale_node
)
cc_library
(
eager_scale SRCS scale.cc DEPS pten_api pten autograd_meta scale_node
)
#
cc_library(final_dygraph_function SRCS dygraph_functions.cc DEPS ${eager_deps})
cc_library
(
final_dygraph_function SRCS dygraph_functions.cc DEPS
${
eager_deps
}
)
#
add_dependencies(final_dygraph_function eager_final_state_codegen)
add_dependencies
(
final_dygraph_function eager_final_state_codegen
)
paddle/fluid/eager/auto_code_generator/CMakeLists.txt
浏览文件 @
62b15566
#
add_subdirectory(final_state_generator)
add_subdirectory
(
final_state_generator
)
set
(
EAGER_GENERETOR_DEPS
${
GLOB_OP_LIB
}
${
GLOB_OPERATOR_DEPS
}
pybind proto_desc executor layer tracer engine imperative_profiler imperative_flag
)
set
(
EAGER_GENERETOR_DEPS
${
GLOB_OP_LIB
}
${
GLOB_OPERATOR_DEPS
}
pybind proto_desc executor layer tracer engine imperative_profiler imperative_flag
)
...
...
paddle/fluid/eager/auto_code_generator/final_state_generator/CMakeLists.txt
浏览文件 @
62b15566
...
@@ -24,3 +24,13 @@ add_custom_target(eager_final_state_codegen
...
@@ -24,3 +24,13 @@ add_custom_target(eager_final_state_codegen
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
tmp_nodes_h_path
}
${
nodes_h_path
}
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
tmp_nodes_h_path
}
${
nodes_h_path
}
VERBATIM
VERBATIM
)
)
set
(
tmp_python_c_output_path
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h"
)
set
(
python_c_output_path
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/pybind/eager_final_state_op_function_impl.h"
)
add_custom_target
(
eager_final_state_python_c_codegen
COMMAND
"
${
PYTHON_EXECUTABLE
}
"
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py"
"--api_yaml_path=
${
api_yaml_path
}
"
"--output_path=
${
tmp_python_c_output_path
}
"
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
tmp_python_c_output_path
}
${
python_c_output_path
}
VERBATIM
)
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
浏览文件 @
62b15566
...
@@ -82,6 +82,14 @@ def RemoveConstAndReference(string):
...
@@ -82,6 +82,14 @@ def RemoveConstAndReference(string):
return
ret
return
ret
def
GetGradNodeName
(
string
):
return
f
"FinalGradNode
{
string
}
"
def
GetForwardFunctionName
(
string
):
return
f
"
{
string
}
_final_state_dygraph_function"
def
GetAutoGradMetaName
(
string
):
def
GetAutoGradMetaName
(
string
):
return
f
"
{
string
}
_autograd_meta"
return
f
"
{
string
}
_autograd_meta"
...
@@ -145,13 +153,13 @@ def ParseYamlArgs(string):
...
@@ -145,13 +153,13 @@ def ParseYamlArgs(string):
def
ParseYamlReturns
(
string
):
def
ParseYamlReturns
(
string
):
# Example: Tensor, Tensor
# Example: Tensor, Tensor
# list = [ [ret_type, orig_position], ...]
# list = [ [
"",
ret_type, orig_position], ...]
returns_list
=
[]
returns_list
=
[]
returns
=
[
x
.
strip
()
for
x
in
string
.
strip
().
split
(
","
)]
returns
=
[
x
.
strip
()
for
x
in
string
.
strip
().
split
(
","
)]
for
i
in
range
(
len
(
returns
)):
for
i
in
range
(
len
(
returns
)):
ret
=
returns
[
i
]
ret
=
returns
[
i
]
returns_list
.
append
([
ret
,
i
])
returns_list
.
append
([
""
,
ret
,
i
])
return
returns_list
return
returns_list
...
@@ -260,8 +268,8 @@ def ForwardsValidationCheck(forward_inputs_list, forward_attrs_list,
...
@@ -260,8 +268,8 @@ def ForwardsValidationCheck(forward_inputs_list, forward_attrs_list,
assert
orig_attr_pos
==
forward_attr_pos
assert
orig_attr_pos
==
forward_attr_pos
for
i
in
range
(
len
(
forward_returns_list
)):
for
i
in
range
(
len
(
forward_returns_list
)):
orig_return_type
=
orig_forward_returns_list
[
i
][
0
]
orig_return_type
=
orig_forward_returns_list
[
i
][
1
]
orig_return_pos
=
orig_forward_returns_list
[
i
][
1
]
orig_return_pos
=
orig_forward_returns_list
[
i
][
2
]
forward_return_type
=
forward_returns_list
[
i
][
1
]
forward_return_type
=
forward_returns_list
[
i
][
1
]
forward_return_pos
=
forward_returns_list
[
i
][
2
]
forward_return_pos
=
forward_returns_list
[
i
][
2
]
...
@@ -452,13 +460,14 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
...
@@ -452,13 +460,14 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
RemoveConstAndReference
(
atype
),
saved_attr_name
,
default_val
)
RemoveConstAndReference
(
atype
),
saved_attr_name
,
default_val
)
# End: SetAttributes & Attribute Members
# End: SetAttributes & Attribute Members
grad_node_name
=
GetGradNodeName
(
fwd_api_name
)
NODE_DECLARATION_TEMPLATE
=
"""
NODE_DECLARATION_TEMPLATE
=
"""
class
GradNode
{} : public egr::GradNodeBase {{
class {} : public egr::GradNodeBase {{
public:
public:
GradNode
{}() : egr::GradNodeBase() {{}}
{}() : egr::GradNodeBase() {{}}
GradNode
{}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
{}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}}
egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}}
~
GradNode
{}() override = default;
~{}() override = default;
virtual std::vector<std::vector<egr::EagerTensor>> operator()(
virtual std::vector<std::vector<egr::EagerTensor>> operator()(
const std::vector<std::vector<egr::EagerTensor>>& grads) override;
const std::vector<std::vector<egr::EagerTensor>>& grads) override;
...
@@ -476,7 +485,7 @@ class GradNode{} : public egr::GradNodeBase {{
...
@@ -476,7 +485,7 @@ class GradNode{} : public egr::GradNodeBase {{
}};
}};
"""
"""
node_declaration_str
=
NODE_DECLARATION_TEMPLATE
.
format
(
node_declaration_str
=
NODE_DECLARATION_TEMPLATE
.
format
(
forward_op_name
,
forward_op_name
,
forward_op_name
,
forward_op
_name
,
grad_node_name
,
grad_node_name
,
grad_node_name
,
grad_node
_name
,
set_tensor_wrapper_methods_str
,
set_attribute_methods_str
,
set_tensor_wrapper_methods_str
,
set_attribute_methods_str
,
tensor_wrapper_members_str
,
attribute_members_str
)
tensor_wrapper_members_str
,
attribute_members_str
)
...
@@ -503,10 +512,15 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
...
@@ -503,10 +512,15 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
grad_api_args
[
grad_api_args
[
grad_api_position
]
=
f
"egr::EagerUtils::SyncToPtenTensors( egr::EagerUtils::RecoverTensorWrapper(&this->
{
tensor_wrapper_name
}
, nullptr) )"
grad_api_position
]
=
f
"egr::EagerUtils::SyncToPtenTensors( egr::EagerUtils::RecoverTensorWrapper(&this->
{
tensor_wrapper_name
}
, nullptr) )"
for
_
,
(
_
,
fwd_position
,
for
_
,
(
ttype
,
fwd_position
,
grad_api_position
)
in
backward_grad_input_map
.
items
():
grad_api_position
)
in
backward_grad_input_map
.
items
():
grad_api_args
[
if
IsPlainTensorType
(
ttype
):
grad_api_position
]
=
f
"egr::EagerUtils::SyncToPtenTensors( grads[
{
fwd_position
}
] )"
grad_api_args
[
grad_api_position
]
=
f
"egr::EagerUtils::SyncToPtenTensors( grads[
{
fwd_position
}
][0] )"
else
:
assert
IsVectorTensorType
(
ttype
)
grad_api_args
[
grad_api_position
]
=
f
"egr::EagerUtils::SyncToPtenTensors( grads[
{
fwd_position
}
] )"
for
name
,
_
,
_
,
grad_api_position
in
backward_attrs_list
:
for
name
,
_
,
_
,
grad_api_position
in
backward_attrs_list
:
saved_attribute_name
=
GetSavedName
(
name
)
saved_attribute_name
=
GetSavedName
(
name
)
...
@@ -531,8 +545,9 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
...
@@ -531,8 +545,9 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
returns_str
+=
f
"returns[
{
fwd_position
}
] = egr::EagerUtils::CreateEagerTensorFromTensor( grad_api_returns[
{
grad_api_position
}
] );
\n
"
returns_str
+=
f
"returns[
{
fwd_position
}
] = egr::EagerUtils::CreateEagerTensorFromTensor( grad_api_returns[
{
grad_api_position
}
] );
\n
"
returns_str
+=
f
"return returns;
\n
"
returns_str
+=
f
"return returns;
\n
"
grad_node_name
=
GetGradNodeName
(
fwd_api_name
)
FUNCTION_TEMPLATE
=
"""
FUNCTION_TEMPLATE
=
"""
std::vector<std::vector<egr::EagerTensor>>
GradNode
{}::operator()(const std::vector<std::vector<egr::EagerTensor>>& grads) {{
std::vector<std::vector<egr::EagerTensor>> {}::operator()(const std::vector<std::vector<egr::EagerTensor>>& grads) {{
// Call grad_api function
// Call grad_api function
auto grad_api_returns = paddle::experimental::{}({});
auto grad_api_returns = paddle::experimental::{}({});
{}
{}
...
@@ -540,7 +555,7 @@ std::vector<std::vector<egr::EagerTensor>> GradNode{}::operator()(const std::vec
...
@@ -540,7 +555,7 @@ std::vector<std::vector<egr::EagerTensor>> GradNode{}::operator()(const std::vec
"""
"""
node_definition_str
=
FUNCTION_TEMPLATE
.
format
(
node_definition_str
=
FUNCTION_TEMPLATE
.
format
(
fwd_api
_name
,
bwd_api_name
,
grad_api_args_str
,
returns_str
)
grad_node
_name
,
bwd_api_name
,
grad_api_args_str
,
returns_str
)
return
node_definition_str
return
node_definition_str
...
@@ -610,7 +625,8 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
...
@@ -610,7 +625,8 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
# Node Construction
# Node Construction
num_bwd_inputs
=
len
(
backward_grad_input_map
.
keys
())
num_bwd_inputs
=
len
(
backward_grad_input_map
.
keys
())
num_bwd_outputs
=
len
(
backward_grad_output_map
.
keys
())
num_bwd_outputs
=
len
(
backward_grad_output_map
.
keys
())
node_construction_str
=
f
" auto grad_node = std::make_shared<GradNode
{
fwd_api_name
}
>(
{
num_bwd_inputs
}
,
{
num_bwd_outputs
}
);"
grad_node_name
=
GetGradNodeName
(
fwd_api_name
)
node_construction_str
=
f
" auto grad_node = std::make_shared<
{
grad_node_name
}
>(
{
num_bwd_inputs
}
,
{
num_bwd_outputs
}
);"
# SetAttributes
# SetAttributes
set_attributes_list
=
[]
set_attributes_list
=
[]
...
@@ -786,7 +802,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
...
@@ -786,7 +802,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
backward_grad_output_map
,
backward_attrs_list
)
backward_grad_output_map
,
backward_attrs_list
)
FORWARD_FUNCTION_TEMPLATE
=
"""
FORWARD_FUNCTION_TEMPLATE
=
"""
{} {}
_dygraph_function
({}) {{
{} {}({}) {{
// Forward API Call
// Forward API Call
{}
{}
...
@@ -799,15 +815,34 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
...
@@ -799,15 +815,34 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
}}
}}
"""
"""
forward_function_name
=
GetForwardFunctionName
(
fwd_api_name
)
forward_function_str
=
FORWARD_FUNCTION_TEMPLATE
.
format
(
forward_function_str
=
FORWARD_FUNCTION_TEMPLATE
.
format
(
returns_type_str
,
fwd_api_name
,
inputs_args_str
,
forward_call_str
,
returns_type_str
,
forward_function_name
,
inputs_args_str
,
returns_str
,
node_creation_str
)
forward_call_str
,
returns_str
,
node_creation_str
)
forward_function_declaration_str
=
f
"
{
returns_type_str
}
{
forward_function_name
}
(
{
inputs_args_str
}
);"
forward_function_declaration_str
=
f
"
{
returns_type_str
}
{
fwd_api_name
}
_dygraph_function(
{
inputs_args_str
}
);"
return
forward_function_str
,
forward_function_declaration_str
return
forward_function_str
,
forward_function_declaration_str
def
FakeMatmulGradAPI
():
fake_matmul_grad_str
=
"""
namespace paddle {
namespace experimental {
std::vector<std::vector<Tensor>> matmul_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
bool transpose_x,
bool transpose_y) {
std::vector<std::vector<Tensor>> ret;
return ret;
}
}
}
"""
return
fake_matmul_grad_str
def
GenerateNodeCCFile
(
filepath
,
node_definition_str
):
def
GenerateNodeCCFile
(
filepath
,
node_definition_str
):
file_contents
=
"""
file_contents
=
"""
#include "glog/logging.h"
#include "glog/logging.h"
...
@@ -819,6 +854,7 @@ def GenerateNodeCCFile(filepath, node_definition_str):
...
@@ -819,6 +854,7 @@ def GenerateNodeCCFile(filepath, node_definition_str):
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
"""
"""
file_contents
+=
FakeMatmulGradAPI
()
file_contents
+=
node_definition_str
file_contents
+=
node_definition_str
with
open
(
filepath
,
'a'
)
as
f
:
with
open
(
filepath
,
'a'
)
as
f
:
f
.
write
(
file_contents
)
f
.
write
(
file_contents
)
...
...
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
0 → 100644
浏览文件 @
62b15566
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
argparse
from
eager_gen
import
ReadFwdFile
,
GetForwardFunctionName
,
ParseYamlForward
,
DetermineForwardPositionMap
atype_to_parsing_function
=
{
"bool"
:
"CastPyArg2Boolean"
,
"int"
:
"CastPyArg2Int"
,
"long"
:
"CastPyArg2Long"
,
"float"
:
"CastPyArg2Float"
,
"string"
:
"CastPyArg2String"
,
"bool[]"
:
"CastPyArg2Booleans"
,
"int[]"
:
"CastPyArg2Ints"
,
"long[]"
:
"CastPyArg2Longs"
,
"float[]"
:
"CastPyArg2Floats"
,
"double[]"
:
"CastPyArg2Float64s"
,
"string[]"
:
"CastPyArg2Strings"
}
atype_to_cxx_type
=
{
"bool"
:
"bool"
,
"int"
:
"int"
,
"long"
:
"long"
,
"float"
:
"float"
,
"string"
:
"std::string"
,
"bool[]"
:
"std::vector<bool>"
,
"int[]"
:
"std::vector<int>"
,
"long[]"
:
"std::vector<long>"
,
"float[]"
:
"std::vector<float>"
,
"double[]"
:
"std::vector<double>"
,
"string[]"
:
"std::vector<std::string>"
}
def
ParseArguments
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Eager Code Generator Args Parser'
)
parser
.
add_argument
(
'--api_yaml_path'
,
type
=
str
)
parser
.
add_argument
(
'--output_path'
,
type
=
str
)
args
=
parser
.
parse_args
()
return
args
def
GetCxxType
(
atype
):
if
atype
not
in
atype_to_cxx_type
.
keys
():
assert
False
return
atype_to_cxx_type
[
atype
]
def
FindParsingFunctionFromAttributeType
(
atype
):
if
atype
not
in
atype_to_parsing_function
.
keys
():
assert
False
return
atype_to_parsing_function
[
atype
]
def
GeneratePythonCFunction
(
fwd_api_name
,
forward_inputs_position_map
,
forward_attrs_list
,
forward_outputs_position_map
):
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
# forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# Get EagerTensor from args
# Get dygraph function call args
num_args
=
len
(
forward_inputs_position_map
.
keys
())
+
len
(
forward_attrs_list
)
num_input_tensors
=
len
(
forward_inputs_position_map
.
keys
())
dygraph_function_call_list
=
[
""
for
i
in
range
(
num_args
)]
get_eager_tensor_str
=
""
for
name
,
(
ttype
,
pos
)
in
forward_inputs_position_map
.
items
():
get_eager_tensor_str
+=
f
" auto&
{
name
}
= GetEagerTensorFromArgs(
\"
{
fwd_api_name
}
\"
,
\"
{
name
}
\"
, args,
{
pos
}
, false);
\n
"
dygraph_function_call_list
[
pos
]
=
f
"
{
name
}
"
parse_attributes_str
=
" paddle::framework::AttributeMap attrs;
\n
"
# Get Attributes
for
name
,
atype
,
_
,
pos
in
forward_attrs_list
:
parsing_function
=
FindParsingFunctionFromAttributeType
(
atype
)
cxx_type
=
GetCxxType
(
atype
)
key
=
f
"
{
name
}
"
parse_attributes_str
+=
f
" PyObject*
{
name
}
_obj = PyTuple_GET_ITEM(args,
{
pos
}
);
\n
"
parse_attributes_str
+=
f
"
{
cxx_type
}
{
name
}
=
{
parsing_function
}
(
{
name
}
_obj,
\"
{
fwd_api_name
}
\"
,
{
pos
}
);
\n
"
dygraph_function_call_list
[
pos
]
=
f
"
{
name
}
"
dygraph_function_call_str
=
","
.
join
(
dygraph_function_call_list
)
PYTHON_C_FUNCTION_TEMPLATE
=
"""
static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObject *kwargs)
{{
PyThreadState *tstate = nullptr;
try
{{
// Get EagerTensors from args
{}
// Parse Attributes
{}
tstate = PyEval_SaveThread();
auto out = {}({});
PyEval_RestoreThread(tstate);
tstate = nullptr;
return ToPyObject(out);
}}
catch(...) {{
if (tstate) {{
PyEval_RestoreThread(tstate);
}}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}}
}}
"""
python_c_function_str
=
PYTHON_C_FUNCTION_TEMPLATE
.
format
(
fwd_api_name
,
get_eager_tensor_str
,
parse_attributes_str
,
GetForwardFunctionName
(
fwd_api_name
),
dygraph_function_call_str
)
python_c_function_reg_str
=
f
"{{
\"
final_state_
{
fwd_api_name
}
\"
, (PyCFunction)(void(*)(void))eager_final_state_api_
{
fwd_api_name
}
, METH_VARARGS | METH_KEYWORDS,
\"
C++ interface function for
{
fwd_api_name
}
in dygraph.
\"
}}"
return
python_c_function_str
,
python_c_function_reg_str
def
GeneratePythonCWrappers
(
python_c_function_str
,
python_c_function_reg_str
):
PYTHON_C_WRAPPER_TEMPLATE
=
"""
#pragma once
#include "pybind11/detail/common.h"
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/pybind/exception.h"
#include <Python.h>
namespace paddle {{
namespace pybind {{
{}
static PyMethodDef EagerFinalStateMethods[] = {{
{}
}};
}} // namespace pybind
}} // namespace paddle
"""
python_c_str
=
PYTHON_C_WRAPPER_TEMPLATE
.
format
(
python_c_function_str
,
python_c_function_reg_str
)
return
python_c_str
def
GeneratePythonCFile
(
filepath
,
python_c_str
):
with
open
(
filepath
,
'a'
)
as
f
:
f
.
write
(
python_c_str
)
if
__name__
==
"__main__"
:
args
=
ParseArguments
()
api_yaml_path
=
args
.
api_yaml_path
fwd_api_list
=
ReadFwdFile
(
api_yaml_path
)
python_c_function_list
=
[]
python_c_function_reg_list
=
[]
for
fwd_api
in
fwd_api_list
:
# We only generate Ops with grad
if
'backward'
not
in
fwd_api
.
keys
():
continue
assert
'api'
in
fwd_api
.
keys
()
assert
'args'
in
fwd_api
.
keys
()
assert
'output'
in
fwd_api
.
keys
()
assert
'backward'
in
fwd_api
.
keys
()
fwd_api_name
=
fwd_api
[
'api'
]
fwd_args_str
=
fwd_api
[
'args'
]
fwd_returns_str
=
fwd_api
[
'output'
]
# Collect Original Forward Inputs/Outputs and then perform validation checks
forward_inputs_list
,
forward_attrs_list
,
forward_returns_list
=
ParseYamlForward
(
fwd_args_str
,
fwd_returns_str
)
print
(
"Parsed Original Forward Inputs List: "
,
forward_inputs_list
)
print
(
"Prased Original Forward Attrs List: "
,
forward_attrs_list
)
print
(
"Parsed Original Forward Returns List: "
,
forward_returns_list
)
forward_inputs_position_map
,
forward_outputs_position_map
=
DetermineForwardPositionMap
(
forward_inputs_list
,
forward_returns_list
)
print
(
"Generated Forward Input Position Map: "
,
forward_inputs_position_map
)
print
(
"Generated Forward Output Position Map: "
,
forward_outputs_position_map
)
python_c_function_str
,
python_c_function_reg_str
=
GeneratePythonCFunction
(
fwd_api_name
,
forward_inputs_position_map
,
forward_attrs_list
,
forward_outputs_position_map
)
python_c_function_list
.
append
(
python_c_function_str
)
python_c_function_reg_list
.
append
(
python_c_function_reg_str
)
print
(
"Generated Python-C Function: "
,
python_c_function_str
)
python_c_functions_str
=
"
\n
"
.
join
(
python_c_function_list
)
python_c_functions_reg_str
=
",
\n
"
.
join
(
python_c_function_reg_list
)
python_c_str
=
GeneratePythonCWrappers
(
python_c_functions_str
,
python_c_functions_reg_str
)
print
(
"Generated Python-C Codes: "
,
python_c_str
)
output_path
=
args
.
output_path
for
path
in
[
output_path
]:
if
os
.
path
.
exists
(
path
):
os
.
remove
(
path
)
GeneratePythonCFile
(
output_path
,
python_c_str
)
paddle/fluid/eager/utils.cc
浏览文件 @
62b15566
...
@@ -288,7 +288,9 @@ void EagerUtils::CheckAndRetainGrad(
...
@@ -288,7 +288,9 @@ void EagerUtils::CheckAndRetainGrad(
paddle
::
experimental
::
Tensor
EagerUtils
::
SyncToPtenTensors
(
paddle
::
experimental
::
Tensor
EagerUtils
::
SyncToPtenTensors
(
const
egr
::
EagerTensor
&
tensor
)
{
const
egr
::
EagerTensor
&
tensor
)
{
const_cast
<
EagerTensor
*>
(
&
tensor
)
->
SyncToTensor
();
if
(
!
tensor
.
initialized
())
{
const_cast
<
EagerTensor
*>
(
&
tensor
)
->
SyncToTensor
();
}
return
*
tensor
.
Tensor
().
get
();
return
*
tensor
.
Tensor
().
get
();
}
}
...
@@ -298,7 +300,9 @@ std::vector<paddle::experimental::Tensor> EagerUtils::SyncToPtenTensors(
...
@@ -298,7 +300,9 @@ std::vector<paddle::experimental::Tensor> EagerUtils::SyncToPtenTensors(
size_t
num
=
tensors
.
size
();
size_t
num
=
tensors
.
size
();
res
.
reserve
(
num
);
res
.
reserve
(
num
);
for
(
size_t
i
=
0
;
i
<
num
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
num
;
i
++
)
{
const_cast
<
EagerTensor
*>
(
&
(
tensors
[
i
]))
->
SyncToTensor
();
if
(
!
tensors
[
i
].
initialized
())
{
const_cast
<
EagerTensor
*>
(
&
(
tensors
[
i
]))
->
SyncToTensor
();
}
res
.
push_back
(
*
tensors
[
i
].
Tensor
().
get
());
res
.
push_back
(
*
tensors
[
i
].
Tensor
().
get
());
}
}
return
res
;
return
res
;
...
...
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
62b15566
...
@@ -151,7 +151,7 @@ if(WITH_PYTHON)
...
@@ -151,7 +151,7 @@ if(WITH_PYTHON)
set
(
tmp_eager_impl_file
${
eager_impl_file
}
.tmp
)
set
(
tmp_eager_impl_file
${
eager_impl_file
}
.tmp
)
set
(
OP_IMPL_DEPS op_function_generator
)
set
(
OP_IMPL_DEPS op_function_generator
)
set
(
EAGER_OP_IMPL_DEPS eager_op_function_generator
)
set
(
EAGER_OP_IMPL_DEPS eager_op_function_generator
eager_final_state_python_c_codegen
)
if
(
WIN32
)
if
(
WIN32
)
if
(
"
${
CMAKE_GENERATOR
}
"
STREQUAL
"Ninja"
)
if
(
"
${
CMAKE_GENERATOR
}
"
STREQUAL
"Ninja"
)
...
@@ -275,7 +275,7 @@ if(WITH_PYTHON)
...
@@ -275,7 +275,7 @@ if(WITH_PYTHON)
if
(
NOT ON_INFER
)
if
(
NOT ON_INFER
)
cc_library
(
paddle_eager
cc_library
(
paddle_eager
SRCS eager.cc eager_functions.cc eager_method.cc eager_properties.cc eager_utils.cc
SRCS eager.cc eager_functions.cc eager_method.cc eager_properties.cc eager_utils.cc
DEPS eager_api autograd_meta backward grad_node_info pten op_function_common dygraph_function dygraph_node accumulation_node global_utils utils python
)
DEPS eager_api autograd_meta backward grad_node_info pten op_function_common
final_dygraph_function final_dygraph_node
dygraph_function dygraph_node accumulation_node global_utils utils python
)
add_dependencies
(
paddle_eager eager_codegen
)
add_dependencies
(
paddle_eager eager_codegen
)
add_dependencies
(
paddle_eager eager_op_function_generator_cmd
)
add_dependencies
(
paddle_eager eager_op_function_generator_cmd
)
list
(
APPEND PYBIND_DEPS paddle_eager
)
list
(
APPEND PYBIND_DEPS paddle_eager
)
...
...
paddle/fluid/pybind/eager_op_function_generator.cc
浏览文件 @
62b15566
...
@@ -393,6 +393,7 @@ int main(int argc, char* argv[]) {
...
@@ -393,6 +393,7 @@ int main(int argc, char* argv[]) {
std
::
vector
<
std
::
string
>
headers
{
std
::
vector
<
std
::
string
>
headers
{
"
\"
pybind11/detail/common.h
\"
"
,
"
\"
pybind11/detail/common.h
\"
"
,
"
\"
paddle/fluid/pybind/eager_final_state_op_function_impl.h
\"
"
,
"
\"
paddle/fluid/pybind/op_function_common.h
\"
"
,
"
\"
paddle/fluid/pybind/op_function_common.h
\"
"
,
"
\"
paddle/fluid/eager/api/generated/fluid_generated/"
"
\"
paddle/fluid/eager/api/generated/fluid_generated/"
"dygraph_forward_api.h
\"
"
,
"dygraph_forward_api.h
\"
"
,
...
@@ -441,6 +442,10 @@ int main(int argc, char* argv[]) {
...
@@ -441,6 +442,10 @@ int main(int argc, char* argv[]) {
<<
" PADDLE_THROW(platform::errors::Fatal (
\"
Add functions to "
<<
" PADDLE_THROW(platform::errors::Fatal (
\"
Add functions to "
"core.eager.ops failed!
\"
));
\n
"
"core.eager.ops failed!
\"
));
\n
"
<<
" }
\n\n
"
<<
" }
\n\n
"
<<
" if (PyModule_AddFunctions(m.ptr(), EagerFinalStateMethods) < 0) {
\n
"
<<
" PADDLE_THROW(platform::errors::Fatal (
\"
Add functions to "
"core.eager.ops failed!
\"
));
\n
"
<<
" }
\n\n
"
<<
"}
\n\n
"
<<
"}
\n\n
"
<<
"} // namespace pybind
\n
"
<<
"} // namespace pybind
\n
"
<<
"} // namespace paddle
\n
"
;
<<
"} // namespace paddle
\n
"
;
...
...
paddle/fluid/pybind/op_function_common.cc
浏览文件 @
62b15566
...
@@ -100,17 +100,15 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
...
@@ -100,17 +100,15 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
bool
PyObject_CheckString
(
PyObject
*
obj
)
{
return
PyUnicode_Check
(
obj
);
}
bool
PyObject_CheckString
(
PyObject
*
obj
)
{
return
PyUnicode_Check
(
obj
);
}
void
CastPyArg2AttrBoolean
(
PyObject
*
obj
,
bool
CastPyArg2Boolean
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
ssize_t
arg_pos
)
{
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
obj
==
Py_None
)
{
if
(
obj
==
Py_None
)
{
attrs
[
key
]
=
false
;
// To be compatible with QA integration testing. Some
return
false
;
// To be compatible with QA integration testing. Some
// test case pass in None.
// test case pass in None.
}
else
if
(
obj
==
Py_True
)
{
}
else
if
(
obj
==
Py_True
)
{
attrs
[
key
]
=
true
;
return
true
;
}
else
if
(
obj
==
Py_False
)
{
}
else
if
(
obj
==
Py_False
)
{
attrs
[
key
]
=
false
;
return
false
;
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"%s(): argument (position %d) must be "
...
@@ -118,62 +116,89 @@ void CastPyArg2AttrBoolean(PyObject* obj,
...
@@ -118,62 +116,89 @@ void CastPyArg2AttrBoolean(PyObject* obj,
op_type
,
arg_pos
+
1
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
return
false
;
}
void
CastPyArg2AttrBoolean
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Boolean
(
obj
,
op_type
,
arg_pos
);
}
int
CastPyArg2Int
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyObject_CheckLongOrToLong
(
&
obj
))
{
return
(
int
)
PyLong_AsLong
(
obj
);
// NOLINT
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"int, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
0
;
}
}
void
CastPyArg2AttrInt
(
PyObject
*
obj
,
void
CastPyArg2AttrInt
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Int
(
obj
,
op_type
,
arg_pos
);
}
int64_t
CastPyArg2Long
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyObject_CheckLongOrToLong
(
&
obj
))
{
if
(
PyObject_CheckLongOrToLong
(
&
obj
))
{
attrs
[
key
]
=
(
in
t
)
PyLong_AsLong
(
obj
);
// NOLINT
return
(
int64_
t
)
PyLong_AsLong
(
obj
);
// NOLINT
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"%s(): argument (position %d) must be "
"
int
, but got %s"
,
"
long
, but got %s"
,
op_type
,
arg_pos
+
1
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
return
0
;
}
}
void
CastPyArg2AttrLong
(
PyObject
*
obj
,
void
CastPyArg2AttrLong
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
ssize_t
arg_pos
)
{
if
(
PyObject_CheckLongOrToLong
(
&
obj
))
{
attrs
[
key
]
=
CastPyArg2Long
(
obj
,
op_type
,
arg_pos
);
attrs
[
key
]
=
(
int64_t
)
PyLong_AsLong
(
obj
);
// NOLINT
}
float
CastPyArg2Float
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyObject_CheckFloatOrToFloat
(
&
obj
))
{
return
(
float
)
PyFloat_AsDouble
(
obj
);
// NOLINT
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"%s(): argument (position %d) must be "
"
long
, but got %s"
,
"
float
, but got %s"
,
op_type
,
arg_pos
+
1
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
return
0.0
;
}
}
void
CastPyArg2AttrFloat
(
PyObject
*
obj
,
void
CastPyArg2AttrFloat
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
ssize_t
arg_pos
)
{
if
(
PyObject_CheckFloatOrToFloat
(
&
obj
))
{
attrs
[
key
]
=
CastPyArg2Float
(
obj
,
op_type
,
arg_pos
);
attrs
[
key
]
=
(
float
)
PyFloat_AsDouble
(
obj
);
// NOLINT
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"float, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
}
void
CastPyArg2AttrString
(
PyObject
*
obj
,
std
::
string
CastPyArg2String
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
ssize_t
arg_pos
)
{
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
PyObject_CheckString
(
obj
))
{
if
(
PyObject_CheckString
(
obj
))
{
Py_ssize_t
size
;
Py_ssize_t
size
;
const
char
*
data
;
const
char
*
data
;
data
=
PyUnicode_AsUTF8AndSize
(
obj
,
&
size
);
data
=
PyUnicode_AsUTF8AndSize
(
obj
,
&
size
);
attrs
[
key
]
=
std
::
string
(
data
,
(
size_t
)
size
);
// NOLINT
return
std
::
string
(
data
,
(
size_t
)
size
);
// NOLINT
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"%s(): argument (position %d) must be "
...
@@ -181,16 +206,23 @@ void CastPyArg2AttrString(PyObject* obj,
...
@@ -181,16 +206,23 @@ void CastPyArg2AttrString(PyObject* obj,
op_type
,
arg_pos
+
1
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
return
""
;
}
}
void
CastPyArg2AttrBooleans
(
PyObject
*
obj
,
void
CastPyArg2AttrString
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2String
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
bool
>
CastPyArg2Booleans
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
bool
>
value
;
if
(
PyList_Check
(
obj
))
{
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
PyObject
*
item
=
nullptr
;
std
::
vector
<
bool
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckBool
(
&
item
))
{
if
(
PyObject_CheckBool
(
&
item
))
{
...
@@ -204,11 +236,9 @@ void CastPyArg2AttrBooleans(PyObject* obj,
...
@@ -204,11 +236,9 @@ void CastPyArg2AttrBooleans(PyObject* obj,
i
));
i
));
}
}
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
PyObject
*
item
=
nullptr
;
std
::
vector
<
bool
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckBool
(
&
item
))
{
if
(
PyObject_CheckBool
(
&
item
))
{
...
@@ -222,7 +252,6 @@ void CastPyArg2AttrBooleans(PyObject* obj,
...
@@ -222,7 +252,6 @@ void CastPyArg2AttrBooleans(PyObject* obj,
i
));
i
));
}
}
}
}
attrs
[
key
]
=
value
;
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"%s(): argument (position %d) must be "
...
@@ -230,16 +259,23 @@ void CastPyArg2AttrBooleans(PyObject* obj,
...
@@ -230,16 +259,23 @@ void CastPyArg2AttrBooleans(PyObject* obj,
op_type
,
arg_pos
+
1
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
return
value
;
}
}
void
CastPyArg2AttrInts
(
PyObject
*
obj
,
void
CastPyArg2AttrBooleans
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Booleans
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
int
>
CastPyArg2Ints
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
int
>
value
;
if
(
PyList_Check
(
obj
))
{
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
PyObject
*
item
=
nullptr
;
std
::
vector
<
int
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
@@ -253,11 +289,9 @@ void CastPyArg2AttrInts(PyObject* obj,
...
@@ -253,11 +289,9 @@ void CastPyArg2AttrInts(PyObject* obj,
i
));
i
));
}
}
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
PyObject
*
item
=
nullptr
;
std
::
vector
<
int
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
@@ -271,11 +305,9 @@ void CastPyArg2AttrInts(PyObject* obj,
...
@@ -271,11 +305,9 @@ void CastPyArg2AttrInts(PyObject* obj,
i
));
i
));
}
}
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PySequence_Check
(
obj
))
{
}
else
if
(
PySequence_Check
(
obj
))
{
Py_ssize_t
len
=
PySequence_Size
(
obj
);
Py_ssize_t
len
=
PySequence_Size
(
obj
);
PyObject
*
item
=
nullptr
;
PyObject
*
item
=
nullptr
;
std
::
vector
<
int
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PySequence_GetItem
(
obj
,
i
);
item
=
PySequence_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
@@ -289,7 +321,6 @@ void CastPyArg2AttrInts(PyObject* obj,
...
@@ -289,7 +321,6 @@ void CastPyArg2AttrInts(PyObject* obj,
i
));
i
));
}
}
}
}
attrs
[
key
]
=
value
;
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"%s(): argument (position %d) must be "
...
@@ -297,16 +328,23 @@ void CastPyArg2AttrInts(PyObject* obj,
...
@@ -297,16 +328,23 @@ void CastPyArg2AttrInts(PyObject* obj,
op_type
,
arg_pos
+
1
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
return
value
;
}
}
void
CastPyArg2AttrLongs
(
PyObject
*
obj
,
void
CastPyArg2AttrInts
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Ints
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
int64_t
>
CastPyArg2Longs
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
int64_t
>
value
;
if
(
PyList_Check
(
obj
))
{
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
PyObject
*
item
=
nullptr
;
std
::
vector
<
int64_t
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
@@ -320,11 +358,9 @@ void CastPyArg2AttrLongs(PyObject* obj,
...
@@ -320,11 +358,9 @@ void CastPyArg2AttrLongs(PyObject* obj,
i
));
i
));
}
}
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
PyObject
*
item
=
nullptr
;
std
::
vector
<
int64_t
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
@@ -338,11 +374,9 @@ void CastPyArg2AttrLongs(PyObject* obj,
...
@@ -338,11 +374,9 @@ void CastPyArg2AttrLongs(PyObject* obj,
i
));
i
));
}
}
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PySequence_Check
(
obj
))
{
}
else
if
(
PySequence_Check
(
obj
))
{
Py_ssize_t
len
=
PySequence_Size
(
obj
);
Py_ssize_t
len
=
PySequence_Size
(
obj
);
PyObject
*
item
=
nullptr
;
PyObject
*
item
=
nullptr
;
std
::
vector
<
int64_t
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PySequence_GetItem
(
obj
,
i
);
item
=
PySequence_GetItem
(
obj
,
i
);
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
if
(
PyObject_CheckLongOrToLong
(
&
item
))
{
...
@@ -356,7 +390,6 @@ void CastPyArg2AttrLongs(PyObject* obj,
...
@@ -356,7 +390,6 @@ void CastPyArg2AttrLongs(PyObject* obj,
i
));
i
));
}
}
}
}
attrs
[
key
]
=
value
;
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"%s(): argument (position %d) must be "
...
@@ -364,16 +397,23 @@ void CastPyArg2AttrLongs(PyObject* obj,
...
@@ -364,16 +397,23 @@ void CastPyArg2AttrLongs(PyObject* obj,
op_type
,
arg_pos
+
1
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
return
value
;
}
}
void
CastPyArg2AttrFloats
(
PyObject
*
obj
,
void
CastPyArg2AttrLongs
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Longs
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
float
>
CastPyArg2Floats
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
float
>
value
;
if
(
PyList_Check
(
obj
))
{
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
PyObject
*
item
=
nullptr
;
std
::
vector
<
float
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
@@ -387,11 +427,9 @@ void CastPyArg2AttrFloats(PyObject* obj,
...
@@ -387,11 +427,9 @@ void CastPyArg2AttrFloats(PyObject* obj,
i
));
i
));
}
}
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
PyObject
*
item
=
nullptr
;
std
::
vector
<
float
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
@@ -405,11 +443,9 @@ void CastPyArg2AttrFloats(PyObject* obj,
...
@@ -405,11 +443,9 @@ void CastPyArg2AttrFloats(PyObject* obj,
i
));
i
));
}
}
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PySequence_Check
(
obj
))
{
}
else
if
(
PySequence_Check
(
obj
))
{
Py_ssize_t
len
=
PySequence_Size
(
obj
);
Py_ssize_t
len
=
PySequence_Size
(
obj
);
PyObject
*
item
=
nullptr
;
PyObject
*
item
=
nullptr
;
std
::
vector
<
float
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PySequence_GetItem
(
obj
,
i
);
item
=
PySequence_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
@@ -423,7 +459,6 @@ void CastPyArg2AttrFloats(PyObject* obj,
...
@@ -423,7 +459,6 @@ void CastPyArg2AttrFloats(PyObject* obj,
i
));
i
));
}
}
}
}
attrs
[
key
]
=
value
;
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"%s(): argument (position %d) must be "
...
@@ -431,16 +466,24 @@ void CastPyArg2AttrFloats(PyObject* obj,
...
@@ -431,16 +466,24 @@ void CastPyArg2AttrFloats(PyObject* obj,
op_type
,
arg_pos
+
1
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
return
value
;
}
}
void
CastPyArg2AttrFloat64s
(
PyObject
*
obj
,
void
CastPyArg2AttrFloats
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Floats
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
double
>
CastPyArg2Float64s
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
double
>
value
;
if
(
PyList_Check
(
obj
))
{
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
PyObject
*
item
=
nullptr
;
std
::
vector
<
double
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
@@ -454,11 +497,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
...
@@ -454,11 +497,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i
));
i
));
}
}
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
PyObject
*
item
=
nullptr
;
std
::
vector
<
double
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
@@ -472,11 +513,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
...
@@ -472,11 +513,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i
));
i
));
}
}
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PySequence_Check
(
obj
))
{
}
else
if
(
PySequence_Check
(
obj
))
{
Py_ssize_t
len
=
PySequence_Size
(
obj
);
Py_ssize_t
len
=
PySequence_Size
(
obj
);
PyObject
*
item
=
nullptr
;
PyObject
*
item
=
nullptr
;
std
::
vector
<
double
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PySequence_GetItem
(
obj
,
i
);
item
=
PySequence_GetItem
(
obj
,
i
);
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
if
(
PyObject_CheckFloatOrToFloat
(
&
item
))
{
...
@@ -490,7 +529,6 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
...
@@ -490,7 +529,6 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i
));
i
));
}
}
}
}
attrs
[
key
]
=
value
;
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"%s(): argument (position %d) must be "
...
@@ -498,16 +536,24 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
...
@@ -498,16 +536,24 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
op_type
,
arg_pos
+
1
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
return
value
;
}
}
void
CastPyArg2AttrStrings
(
PyObject
*
obj
,
void
CastPyArg2AttrFloat64s
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Float64s
(
obj
,
op_type
,
arg_pos
);
}
std
::
vector
<
std
::
string
>
CastPyArg2Strings
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
std
::
vector
<
std
::
string
>
value
;
if
(
PyList_Check
(
obj
))
{
if
(
PyList_Check
(
obj
))
{
Py_ssize_t
len
=
PyList_Size
(
obj
);
Py_ssize_t
len
=
PyList_Size
(
obj
);
PyObject
*
item
=
nullptr
;
PyObject
*
item
=
nullptr
;
std
::
vector
<
std
::
string
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyList_GetItem
(
obj
,
i
);
item
=
PyList_GetItem
(
obj
,
i
);
if
(
PyObject_CheckString
(
item
))
{
if
(
PyObject_CheckString
(
item
))
{
...
@@ -524,11 +570,9 @@ void CastPyArg2AttrStrings(PyObject* obj,
...
@@ -524,11 +570,9 @@ void CastPyArg2AttrStrings(PyObject* obj,
i
));
i
));
}
}
}
}
attrs
[
key
]
=
value
;
}
else
if
(
PyTuple_Check
(
obj
))
{
}
else
if
(
PyTuple_Check
(
obj
))
{
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
Py_ssize_t
len
=
PyTuple_Size
(
obj
);
PyObject
*
item
=
nullptr
;
PyObject
*
item
=
nullptr
;
std
::
vector
<
std
::
string
>
value
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
PyTuple_GetItem
(
obj
,
i
);
item
=
PyTuple_GetItem
(
obj
,
i
);
if
(
PyObject_CheckString
(
item
))
{
if
(
PyObject_CheckString
(
item
))
{
...
@@ -545,7 +589,6 @@ void CastPyArg2AttrStrings(PyObject* obj,
...
@@ -545,7 +589,6 @@ void CastPyArg2AttrStrings(PyObject* obj,
i
));
i
));
}
}
}
}
attrs
[
key
]
=
value
;
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"%s(): argument (position %d) must be "
...
@@ -553,6 +596,15 @@ void CastPyArg2AttrStrings(PyObject* obj,
...
@@ -553,6 +596,15 @@ void CastPyArg2AttrStrings(PyObject* obj,
op_type
,
arg_pos
+
1
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
}
return
value
;
}
void
CastPyArg2AttrStrings
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
attrs
[
key
]
=
CastPyArg2Strings
(
obj
,
op_type
,
arg_pos
);
}
}
void
CastPyArg2AttrBlock
(
PyObject
*
obj
,
void
CastPyArg2AttrBlock
(
PyObject
*
obj
,
...
...
paddle/fluid/pybind/op_function_common.h
浏览文件 @
62b15566
...
@@ -43,6 +43,30 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj);
...
@@ -43,6 +43,30 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj);
bool
PyObject_CheckString
(
PyObject
*
obj
);
bool
PyObject_CheckString
(
PyObject
*
obj
);
bool
CastPyArg2Boolean
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
int
CastPyArg2Int
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
int64_t
CastPyArg2Long
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
float
CastPyArg2Float
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
string
CastPyArg2String
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
bool
>
CastPyArg2Booleans
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
int
>
CastPyArg2Ints
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
int64_t
>
CastPyArg2Longs
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
float
>
CastPyArg2Floats
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
double
>
CastPyArg2Float64s
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
std
::
vector
<
std
::
string
>
CastPyArg2Strings
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
void
CastPyArg2AttrBoolean
(
PyObject
*
obj
,
void
CastPyArg2AttrBoolean
(
PyObject
*
obj
,
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
paddle
::
framework
::
AttributeMap
&
attrs
,
// NOLINT
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
const
std
::
string
&
key
,
const
std
::
string
&
op_type
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录