Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
638aab6e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
638aab6e
编写于
2月 18, 2022
作者:
Z
zyfncg
提交者:
GitHub
2月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Pten] Support inplace and intermediate in C++ API (#39651)
* support inplace and intermediate in yaml * add cmake for dygraph_api
上级
70b9f2ac
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
208 addition
and
68 deletion
+208
-68
.gitignore
.gitignore
+4
-3
paddle/pten/api/lib/CMakeLists.txt
paddle/pten/api/lib/CMakeLists.txt
+10
-1
paddle/pten/api/lib/api_utils.h
paddle/pten/api/lib/api_utils.h
+14
-8
paddle/pten/api/lib/tensor.cc
paddle/pten/api/lib/tensor.cc
+7
-4
paddle/pten/tests/api/test_reshape_api.cc
paddle/pten/tests/api/test_reshape_api.cc
+19
-0
paddle/pten/tests/api/test_scale_api.cc
paddle/pten/tests/api/test_scale_api.cc
+1
-1
python/paddle/utils/code_gen/api.yaml
python/paddle/utils/code_gen/api.yaml
+2
-1
python/paddle/utils/code_gen/api_base.py
python/paddle/utils/code_gen/api_base.py
+59
-20
python/paddle/utils/code_gen/api_gen.py
python/paddle/utils/code_gen/api_gen.py
+68
-10
python/paddle/utils/code_gen/backward_api_gen.py
python/paddle/utils/code_gen/backward_api_gen.py
+21
-3
python/paddle/utils/code_gen/wrapped_infermeta_gen.py
python/paddle/utils/code_gen/wrapped_infermeta_gen.py
+3
-17
未找到文件。
.gitignore
浏览文件 @
638aab6e
...
@@ -2,16 +2,17 @@ paddle/fluid/operators/distributed/send_recv.proto
...
@@ -2,16 +2,17 @@ paddle/fluid/operators/distributed/send_recv.proto
paddle/fluid/API.spec
paddle/fluid/API.spec
paddle/fluid/API_DEV.spec
paddle/fluid/API_DEV.spec
paddle/fluid/API_PR.spec
paddle/fluid/API_PR.spec
paddle/fluid/eager/api/generated/*
paddle/fluid/op_use_default_grad_maker_DEV.spec
paddle/fluid/op_use_default_grad_maker_DEV.spec
paddle/fluid/op_use_default_grad_maker_PR.spec
paddle/fluid/op_use_default_grad_maker_PR.spec
paddle/pten/api/backward/backward_api.h
paddle/pten/api/include/api.h
paddle/pten/api/include/api.h
paddle/pten/api/lib/api.cc
paddle/pten/api/lib/api.cc
paddle/pten/api/
backward/backward_api.h
paddle/pten/api/
lib/dygraph_api.*
paddle/pten/api/lib/backward_api.cc
paddle/pten/api/lib/backward_api.cc
paddle/pten/extension.h
paddle/pten/include/*
paddle/pten/include/*
paddle/pten/infermeta/generated.*
paddle/pten/infermeta/generated.*
paddle/pten/extension.h
paddle/fluid/eager/api/generated/*
*.DS_Store
*.DS_Store
*.vs
*.vs
...
...
paddle/pten/api/lib/CMakeLists.txt
浏览文件 @
638aab6e
...
@@ -17,8 +17,12 @@ set(api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api_gen.py)
...
@@ -17,8 +17,12 @@ set(api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api_gen.py)
set
(
api_yaml_file
${
CMAKE_SOURCE_DIR
}
/python/paddle/utils/code_gen/api.yaml
)
set
(
api_yaml_file
${
CMAKE_SOURCE_DIR
}
/python/paddle/utils/code_gen/api.yaml
)
set
(
api_header_file
${
CMAKE_SOURCE_DIR
}
/paddle/pten/api/include/api.h
)
set
(
api_header_file
${
CMAKE_SOURCE_DIR
}
/paddle/pten/api/include/api.h
)
set
(
api_source_file
${
CMAKE_SOURCE_DIR
}
/paddle/pten/api/lib/api.cc
)
set
(
api_source_file
${
CMAKE_SOURCE_DIR
}
/paddle/pten/api/lib/api.cc
)
set
(
dygraph_api_header_file
${
CMAKE_SOURCE_DIR
}
/paddle/pten/api/lib/dygraph_api.h
)
set
(
dygraph_api_source_file
${
CMAKE_SOURCE_DIR
}
/paddle/pten/api/lib/dygraph_api.cc
)
set
(
api_header_file_tmp
${
api_header_file
}
.tmp
)
set
(
api_header_file_tmp
${
api_header_file
}
.tmp
)
set
(
api_source_file_tmp
${
api_source_file
}
.tmp
)
set
(
api_source_file_tmp
${
api_source_file
}
.tmp
)
set
(
dygraph_api_header_file_tmp
${
dygraph_api_header_file
}
.tmp
)
set
(
dygraph_api_source_file_tmp
${
dygraph_api_source_file
}
.tmp
)
# backward api file
# backward api file
set
(
bw_api_gen_file
${
CMAKE_SOURCE_DIR
}
/python/paddle/utils/code_gen/backward_api_gen.py
)
set
(
bw_api_gen_file
${
CMAKE_SOURCE_DIR
}
/python/paddle/utils/code_gen/backward_api_gen.py
)
...
@@ -40,14 +44,18 @@ endif()
...
@@ -40,14 +44,18 @@ endif()
# generate forward api
# generate forward api
add_custom_command
(
add_custom_command
(
OUTPUT
${
api_header_file
}
${
api_source_file
}
OUTPUT
${
api_header_file
}
${
api_source_file
}
${
dygraph_api_header_file
}
${
dygraph_api_source_file
}
COMMAND
${
PYTHON_EXECUTABLE
}
-m pip install pyyaml
COMMAND
${
PYTHON_EXECUTABLE
}
-m pip install pyyaml
COMMAND
${
PYTHON_EXECUTABLE
}
${
api_gen_file
}
COMMAND
${
PYTHON_EXECUTABLE
}
${
api_gen_file
}
--api_yaml_path
${
api_yaml_file
}
--api_yaml_path
${
api_yaml_file
}
--api_header_path
${
api_header_file_tmp
}
--api_header_path
${
api_header_file_tmp
}
--api_source_path
${
api_source_file_tmp
}
--api_source_path
${
api_source_file_tmp
}
--dygraph_api_header_path
${
dygraph_api_header_file_tmp
}
--dygraph_api_source_path
${
dygraph_api_source_file_tmp
}
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
api_header_file_tmp
}
${
api_header_file
}
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
api_header_file_tmp
}
${
api_header_file
}
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
api_source_file_tmp
}
${
api_source_file
}
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
api_source_file_tmp
}
${
api_source_file
}
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
dygraph_api_header_file_tmp
}
${
dygraph_api_header_file
}
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
dygraph_api_source_file_tmp
}
${
dygraph_api_source_file
}
COMMENT
"copy_if_different
${
api_header_file
}
${
api_source_file
}
"
COMMENT
"copy_if_different
${
api_header_file
}
${
api_source_file
}
"
DEPENDS
${
api_yaml_file
}
${
api_gen_file
}
${
api_gen_base
}
DEPENDS
${
api_yaml_file
}
${
api_gen_file
}
${
api_gen_base
}
VERBATIM
)
VERBATIM
)
...
@@ -86,5 +94,6 @@ cc_library(op_kernel_info SRCS op_kernel_info.cc DEPS pten_tensor_raw)
...
@@ -86,5 +94,6 @@ cc_library(op_kernel_info SRCS op_kernel_info.cc DEPS pten_tensor_raw)
cc_library
(
sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform
)
cc_library
(
sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform
)
cc_library
(
pten_function_api SRCS
${
api_source_file
}
DEPS pten_tensor pten kernel_dispatch pten_data_transform
)
cc_library
(
pten_function_api SRCS
${
api_source_file
}
DEPS pten_tensor pten kernel_dispatch pten_data_transform
)
cc_library
(
pten_dygraph_api SRCS
${
dygraph_api_source_file
}
DEPS pten_tensor pten kernel_dispatch pten_data_transform
)
cc_library
(
pten_bw_function_api SRCS
${
bw_api_source_file
}
DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_data_transform pten_function_api
)
cc_library
(
pten_bw_function_api SRCS
${
bw_api_source_file
}
DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_data_transform pten_function_api
)
cc_library
(
wrapped_infermeta SRCS
${
wrapped_infermeta_source_file
}
DEPS pten
)
cc_library
(
wrapped_infermeta SRCS
${
wrapped_infermeta_source_file
}
DEPS pten
)
paddle/pten/api/lib/api_utils.h
浏览文件 @
638aab6e
...
@@ -72,11 +72,14 @@ inline pten::MetaTensor MakeMetaTensor(const pten::SelectedRows& tensor) {
...
@@ -72,11 +72,14 @@ inline pten::MetaTensor MakeMetaTensor(const pten::SelectedRows& tensor) {
/* ------------------ for output ----------------------- */
/* ------------------ for output ----------------------- */
inline
pten
::
DenseTensor
*
SetKernelOutput
(
Backend
backend
,
Tensor
*
out
)
{
inline
pten
::
DenseTensor
*
SetKernelOutput
(
Backend
backend
,
Tensor
*
out
)
{
auto
dense_tensor
=
std
::
make_shared
<
pten
::
DenseTensor
>
(
if
(
!
out
->
initialized
())
{
pten
::
make_intrusive
<
SharedStorage
>
(
pten
::
TransToPtenPlace
(
backend
)),
auto
dense_tensor
=
std
::
make_shared
<
pten
::
DenseTensor
>
(
pten
::
DenseTensorMeta
());
pten
::
make_intrusive
<
SharedStorage
>
(
pten
::
TransToPtenPlace
(
backend
)),
out
->
set_impl
(
dense_tensor
);
pten
::
DenseTensorMeta
());
return
dense_tensor
.
get
();
out
->
set_impl
(
dense_tensor
);
return
dense_tensor
.
get
();
}
return
static_cast
<
pten
::
DenseTensor
*>
(
out
->
impl
().
get
());
}
}
inline
std
::
vector
<
pten
::
DenseTensor
*>
SetKernelOutput
(
inline
std
::
vector
<
pten
::
DenseTensor
*>
SetKernelOutput
(
...
@@ -96,9 +99,12 @@ inline std::vector<pten::DenseTensor*> SetKernelOutput(
...
@@ -96,9 +99,12 @@ inline std::vector<pten::DenseTensor*> SetKernelOutput(
inline
pten
::
SelectedRows
*
SetSelectedRowsKernelOutput
(
Backend
backend
,
inline
pten
::
SelectedRows
*
SetSelectedRowsKernelOutput
(
Backend
backend
,
Tensor
*
out
)
{
Tensor
*
out
)
{
auto
select_rows
=
std
::
make_shared
<
pten
::
SelectedRows
>
();
if
(
!
out
->
initialized
())
{
out
->
set_impl
(
select_rows
);
auto
select_rows
=
std
::
make_shared
<
pten
::
SelectedRows
>
();
return
select_rows
.
get
();
out
->
set_impl
(
select_rows
);
return
select_rows
.
get
();
}
return
static_cast
<
pten
::
SelectedRows
*>
(
out
->
impl
().
get
());
}
}
}
// namespace experimental
}
// namespace experimental
...
...
paddle/pten/api/lib/tensor.cc
浏览文件 @
638aab6e
...
@@ -249,10 +249,13 @@ Tensor::data<pten::dtype::bfloat16>() const;
...
@@ -249,10 +249,13 @@ Tensor::data<pten::dtype::bfloat16>() const;
template
<
typename
T
>
template
<
typename
T
>
T
*
Tensor
::
data
()
{
T
*
Tensor
::
data
()
{
PADDLE_THROW
(
pten
::
errors
::
Unimplemented
(
if
(
is_dense_tensor
())
{
"It is not currently supported to directly obtain the modifiable data "
return
std
::
dynamic_pointer_cast
<
pten
::
DenseTensor
>
(
impl_
)
->
data
<
T
>
();
"address through the tensor::data<T>() method, please use the "
}
else
if
(
pten
::
SelectedRows
::
classof
(
impl_
.
get
()))
{
"tensor::mutable_data<T>() method."
));
return
std
::
dynamic_pointer_cast
<
pten
::
SelectedRows
>
(
impl_
)
->
mutable_value
()
->
data
<
T
>
();
}
return
nullptr
;
return
nullptr
;
}
}
...
...
paddle/pten/tests/api/test_reshape_api.cc
浏览文件 @
638aab6e
...
@@ -67,6 +67,25 @@ TEST(API, reshape) {
...
@@ -67,6 +67,25 @@ TEST(API, reshape) {
ASSERT_EQ
(
value_equal
,
true
);
ASSERT_EQ
(
value_equal
,
true
);
}
}
TEST
(
API
,
reshape_
)
{
// 1. create tensor
auto
x
=
paddle
::
experimental
::
full
(
{
3
,
2
,
2
,
3
},
1.0
,
experimental
::
DataType
::
FLOAT32
);
// 2. test API
paddle
::
experimental
::
Tensor
out
=
paddle
::
experimental
::
reshape_
(
x
,
{
12
,
3
});
// 3. check result
std
::
vector
<
int64_t
>
expect_shape
=
{
12
,
3
};
ASSERT_EQ
(
out
.
shape
()[
0
],
expect_shape
[
0
]);
ASSERT_EQ
(
out
.
shape
()[
1
],
expect_shape
[
1
]);
ASSERT_EQ
(
out
.
numel
(),
36
);
ASSERT_EQ
(
out
.
is_cpu
(),
true
);
ASSERT_EQ
(
out
.
type
(),
pten
::
DataType
::
FLOAT32
);
ASSERT_EQ
(
out
.
layout
(),
pten
::
DataLayout
::
NCHW
);
ASSERT_EQ
(
out
.
initialized
(),
true
);
ASSERT_EQ
(
out
.
data
<
float
>
(),
x
.
data
<
float
>
());
}
TEST
(
Tensor
,
old_reshape
)
{
TEST
(
Tensor
,
old_reshape
)
{
paddle
::
experimental
::
Tensor
x
(
paddle
::
PlaceType
::
kCPU
);
paddle
::
experimental
::
Tensor
x
(
paddle
::
PlaceType
::
kCPU
);
x
.
reshape
({
3
,
4
});
x
.
reshape
({
3
,
4
});
...
...
paddle/pten/tests/api/test_scale_api.cc
浏览文件 @
638aab6e
...
@@ -62,7 +62,7 @@ TEST(API, scale_sr) {
...
@@ -62,7 +62,7 @@ TEST(API, scale_sr) {
experimental
::
full
({
3
,
4
},
1.0
,
pten
::
DataType
::
FLOAT32
).
impl
());
experimental
::
full
({
3
,
4
},
1.0
,
pten
::
DataType
::
FLOAT32
).
impl
());
*
(
selected_rows
->
mutable_value
())
=
*
dense_tensor
;
*
(
selected_rows
->
mutable_value
())
=
*
dense_tensor
;
experimental
::
Tensor
x
(
selected_rows
);
experimental
::
Tensor
x
(
selected_rows
);
const
auto
out
=
experimental
::
scale
(
x
,
2.0
,
1.0
,
true
);
auto
out
=
experimental
::
scale
(
x
,
2.0
,
1.0
,
true
);
ASSERT_EQ
(
out
.
dims
().
size
(),
2
);
ASSERT_EQ
(
out
.
dims
().
size
(),
2
);
ASSERT_EQ
(
out
.
dims
()[
0
],
3
);
ASSERT_EQ
(
out
.
dims
()[
0
],
3
);
...
...
python/paddle/utils/code_gen/api.yaml
浏览文件 @
638aab6e
...
@@ -142,11 +142,12 @@
...
@@ -142,11 +142,12 @@
-
api
:
reshape
-
api
:
reshape
args
:
(Tensor x, ScalarArray shape)
args
:
(Tensor x, ScalarArray shape)
output
:
Tensor
output
:
Tensor
(out)
infer_meta
:
infer_meta
:
func
:
ReshapeInferMeta
func
:
ReshapeInferMeta
kernel
:
kernel
:
func
:
reshape
func
:
reshape
inplace
:
(x -> out)
-
api
:
scale
-
api
:
scale
args
:
(Tensor x, Scalar scale, float bias, bool bias_after_scale)
args
:
(Tensor x, Scalar scale, float bias, bool bias_after_scale)
...
...
python/paddle/utils/code_gen/api_base.py
浏览文件 @
638aab6e
...
@@ -48,10 +48,14 @@ class BaseAPI(object):
...
@@ -48,10 +48,14 @@ class BaseAPI(object):
self
.
support_selected_rows_kernel
=
False
if
len
(
self
.
kernel
[
self
.
support_selected_rows_kernel
=
False
if
len
(
self
.
kernel
[
'func'
])
==
1
else
True
'func'
])
==
1
else
True
self
.
data_transform
=
self
.
parse_data_transform
(
api_item_yaml
)
self
.
data_transform
=
self
.
parse_data_transform
(
api_item_yaml
)
self
.
inplace_map
=
self
.
parse_inplace
(
api_item_yaml
)
def
get_api_name
(
self
,
api_item_yaml
):
def
get_api_name
(
self
,
api_item_yaml
):
return
api_item_yaml
[
'api'
]
return
api_item_yaml
[
'api'
]
def
get_api_func_name
(
self
):
return
self
.
api
def
parse_args
(
self
,
api_name
,
api_item_yaml
):
def
parse_args
(
self
,
api_name
,
api_item_yaml
):
inputs
,
attrs
,
args_str
=
self
.
parse_input_and_attr
(
inputs
,
attrs
,
args_str
=
self
.
parse_input_and_attr
(
api_name
,
api_item_yaml
[
'args'
])
api_name
,
api_item_yaml
[
'args'
])
...
@@ -225,13 +229,37 @@ class BaseAPI(object):
...
@@ -225,13 +229,37 @@ class BaseAPI(object):
return
data_transform
return
data_transform
def
parse_inplace
(
self
,
api_item_yaml
):
if
'inplace'
in
api_item_yaml
:
inplace_map
=
{}
inplace_list
=
api_item_yaml
[
'inplace'
].
split
(
','
)
for
item
in
inplace_list
:
result
=
re
.
search
(
r
"(?P<in>\w+)\s*->\s(?P<out>\w+)"
,
item
)
in_val
=
result
.
group
(
'in'
)
out_val
=
result
.
group
(
'out'
)
assert
in_val
in
self
.
inputs
[
'names'
],
\
f
"
{
self
.
api
}
: Inplace input error: the input var name('
{
in_val
}
') is not found in the input args of
{
self
.
api
}
."
assert
out_val
in
self
.
outputs
[
'names'
],
\
f
"
{
self
.
api
}
: Inplace output error: the output var name('
{
out_val
}
') is not found in the output args of
{
self
.
api
}
."
inplace_map
[
out_val
]
=
in_val
return
inplace_map
else
:
return
None
# Override by child class
# Override by child class
def
get_return_type
(
self
,
out_type_list
):
def
get_return_type
(
self
,
out_type_list
):
return
None
return
None
def
gene_api_declaration
(
self
):
def
gene_api_declaration
(
self
):
api_declaration
=
f
"""
api_declaration
=
f
"""
PADDLE_API
{
self
.
outputs
[
'return_type'
]
}
{
self
.
api
}
(
{
self
.
args_str
[
'args_declare'
]
}
);
PADDLE_API
{
self
.
outputs
[
'return_type'
]
}
{
self
.
get_api_func_name
()
}
(
{
self
.
args_str
[
'args_declare'
]
}
);
"""
if
self
.
is_base_api
and
self
.
inplace_map
is
not
None
:
api_declaration
=
api_declaration
+
f
"""
PADDLE_API
{
self
.
outputs
[
'return_type'
]
}
{
self
.
get_api_func_name
()
+
'_'
}
(
{
self
.
args_str
[
'args_declare'
]
}
);
"""
"""
return
api_declaration
return
api_declaration
...
@@ -527,14 +555,18 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
...
@@ -527,14 +555,18 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
return
input_tensor_code
,
kernel_args
[:
-
2
],
kernel_signature
return
input_tensor_code
,
kernel_args
[:
-
2
],
kernel_signature
# Override by child class
# Override by child class
def
gene_output
(
self
,
output_type_list
,
set_out_func
,
code_indent
):
def
gene_output
(
self
,
output_type_list
,
set_out_func
,
code_indent
,
inplace_flag
=
False
):
return
None
,
None
,
None
return
None
,
None
,
None
def
gen_dense_tensor_kernel_code
(
self
,
code_indent
):
def
gen_dense_tensor_kernel_code
(
self
,
code_indent
,
inplace_flag
=
False
):
input_tensors
,
kernel_args
,
kernel_signature
=
self
.
get_kernel_args
(
input_tensors
,
kernel_args
,
kernel_signature
=
self
.
get_kernel_args
(
code_indent
)
code_indent
)
outputs_args
,
kernel_output_names
,
output_create
=
self
.
gene_output
(
outputs_args
,
kernel_output_names
,
output_create
=
self
.
gene_output
(
self
.
outputs
[
'types'
],
'SetKernelOutput'
,
code_indent
)
self
.
outputs
[
'types'
],
'SetKernelOutput'
,
code_indent
,
inplace_flag
)
return
f
"""
return
f
"""
{
code_indent
}
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
{
code_indent
}
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
{
code_indent
}
"
{
self
.
kernel
[
'func'
][
0
]
}
", {{kernel_backend, kernel_layout, kernel_data_type}});
{
code_indent
}
"
{
self
.
kernel
[
'func'
][
0
]
}
", {{kernel_backend, kernel_layout, kernel_data_type}});
...
@@ -552,11 +584,12 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
...
@@ -552,11 +584,12 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
{
code_indent
}
return out;"""
{
code_indent
}
return out;"""
def
gen_selected_rows_kernel_code
(
self
,
code_indent
):
def
gen_selected_rows_kernel_code
(
self
,
code_indent
,
inplace_flag
=
False
):
input_tensors
,
kernel_args
,
kernel_signature
=
self
.
get_selected_rows_kernel_args
(
input_tensors
,
kernel_args
,
kernel_signature
=
self
.
get_selected_rows_kernel_args
(
code_indent
)
code_indent
)
outputs_args
,
kernel_output_names
,
output_create
=
self
.
gene_output
(
outputs_args
,
kernel_output_names
,
output_create
=
self
.
gene_output
(
self
.
outputs
[
'types'
],
'SetSelectedRowsKernelOutput'
,
code_indent
)
self
.
outputs
[
'types'
],
'SetSelectedRowsKernelOutput'
,
code_indent
,
inplace_flag
)
return
f
"""
return
f
"""
{
code_indent
}
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
{
code_indent
}
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
{
code_indent
}
"
{
self
.
kernel
[
'func'
][
1
]
}
", {{kernel_backend, kernel_layout, kernel_data_type}});
{
code_indent
}
"
{
self
.
kernel
[
'func'
][
1
]
}
", {{kernel_backend, kernel_layout, kernel_data_type}});
...
@@ -574,32 +607,38 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
...
@@ -574,32 +607,38 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
{
code_indent
}
return out;"""
{
code_indent
}
return out;"""
def
gene_
api_code
(
self
):
def
gene_
base_api_code
(
self
,
inplace_flag
=
False
):
if
self
.
is_base_api
:
api_func_name
=
self
.
get_api_func_name
()
+
(
'_'
if
inplace_flag
else
''
)
api_code
=
f
"""
api_code
=
f
"""
PADDLE_API
{
self
.
outputs
[
'return_type'
]
}
{
self
.
api
}
(
{
self
.
args_str
[
"args_define"
]
}
) {{
PADDLE_API
{
self
.
outputs
[
'return_type'
]
}
{
api_func_name
}
(
{
self
.
args_str
[
"args_define"
]
}
) {{
{
self
.
gene_kernel_select
()
}
{
self
.
gene_kernel_select
()
}
"""
"""
if
self
.
support_selected_rows_kernel
:
if
self
.
support_selected_rows_kernel
:
code_indent
=
' '
code_indent
=
' '
api_code
=
api_code
+
f
"""
return
api_code
+
f
"""
if(kernel_type == KernelType::DENSE_TENSOR_KENREL){{
if(kernel_type == KernelType::DENSE_TENSOR_KENREL){{
{
self
.
gen_dense_tensor_kernel_code
(
code_indent
)
}
{
self
.
gen_dense_tensor_kernel_code
(
code_indent
,
inplace_flag
)
}
}} else {{
}} else {{
{
self
.
gen_selected_rows_kernel_code
(
code_indent
)
}
{
self
.
gen_selected_rows_kernel_code
(
code_indent
,
inplace_flag
)
}
}}
}}
}}
}}
"""
"""
return
api_code
else
:
else
:
code_indent
=
''
code_indent
=
''
return
api_code
+
self
.
gen_dense_tensor_kernel_code
(
return
api_code
+
self
.
gen_dense_tensor_kernel_code
(
code_indent
,
inplace_flag
)
+
"""
code_indent
)
+
"""
}
}
"""
"""
def
gene_api_code
(
self
):
if
self
.
is_base_api
:
api_code
=
self
.
gene_base_api_code
()
if
self
.
inplace_map
is
not
None
:
api_code
=
api_code
+
self
.
gene_base_api_code
(
inplace_flag
=
True
)
return
api_code
else
:
else
:
inveke_func_name
=
self
.
invoke
.
split
(
'('
)[
0
].
strip
()
inveke_func_name
=
self
.
invoke
.
split
(
'('
)[
0
].
strip
()
if
inveke_func_name
in
self
.
attrs
[
'names'
]:
if
inveke_func_name
in
self
.
attrs
[
'names'
]:
...
...
python/paddle/utils/code_gen/api_gen.py
浏览文件 @
638aab6e
...
@@ -15,22 +15,38 @@
...
@@ -15,22 +15,38 @@
import
os
import
os
import
yaml
import
yaml
import
argparse
import
argparse
import
re
from
api_base
import
BaseAPI
from
api_base
import
BaseAPI
class
ForwardAPI
(
BaseAPI
):
class
ForwardAPI
(
BaseAPI
):
prefix_tensor_name
=
'dense_'
def
__init__
(
self
,
api_item_yaml
):
def
__init__
(
self
,
api_item_yaml
):
super
(
ForwardAPI
,
self
).
__init__
(
api_item_yaml
)
super
(
ForwardAPI
,
self
).
__init__
(
api_item_yaml
)
self
.
is_dygraph_api
=
self
.
parse_intermediate
(
api_item_yaml
)
def
get_api_func_name
(
self
):
if
self
.
is_dygraph_api
:
return
self
.
api
+
'_intermediate'
else
:
return
self
.
api
def
parse_intermediate
(
self
,
api_item_yaml
):
if
'intermediate'
in
api_item_yaml
:
return
True
else
:
return
False
def
get_return_type
(
self
,
out_type_list
):
def
get_return_type
(
self
,
out_type_list
):
return
out_type_list
[
0
]
if
len
(
return
out_type_list
[
0
]
if
len
(
out_type_list
)
==
1
else
"std::tuple<"
+
","
.
join
(
out_type_list
)
==
1
else
"std::tuple<"
+
","
.
join
(
out_type_list
)
+
">"
out_type_list
)
+
">"
def
gene_output
(
self
,
output_type_list
,
set_out_func
,
code_indent
):
def
gene_output
(
self
,
output_type_list
,
set_out_func
,
code_indent
,
inplace_flag
=
False
):
kernel_output
=
""
kernel_output
=
""
output_names
=
[]
output_names
=
[]
output_create
=
""
output_create
=
""
...
@@ -38,8 +54,11 @@ class ForwardAPI(BaseAPI):
...
@@ -38,8 +54,11 @@ class ForwardAPI(BaseAPI):
if
len
(
output_type_list
)
==
1
:
if
len
(
output_type_list
)
==
1
:
kernel_output
=
'kernel_out'
kernel_output
=
'kernel_out'
output_names
.
append
(
'kernel_out'
)
output_names
.
append
(
'kernel_out'
)
inplace_assign
=
" = "
+
self
.
inplace_map
[
self
.
outputs
[
'names'
][
0
]]
if
inplace_flag
and
self
.
inplace_map
is
not
None
and
self
.
outputs
[
'names'
][
0
]
in
self
.
inplace_map
else
""
output_create
=
f
"""
output_create
=
f
"""
{
code_indent
}
{
self
.
outputs
[
'return_type'
]
}
out;
{
code_indent
}
{
self
.
outputs
[
'return_type'
]
}
out
{
inplace_assign
}
;
{
code_indent
}
auto kernel_out =
{
set_out_func
}
(kernel_backend, &out);"""
{
code_indent
}
auto kernel_out =
{
set_out_func
}
(kernel_backend, &out);"""
elif
len
(
output_type_list
)
>
1
:
elif
len
(
output_type_list
)
>
1
:
...
@@ -49,6 +68,11 @@ class ForwardAPI(BaseAPI):
...
@@ -49,6 +68,11 @@ class ForwardAPI(BaseAPI):
for
i
in
range
(
len
(
output_type_list
)):
for
i
in
range
(
len
(
output_type_list
)):
kernel_output
=
kernel_output
+
f
'kernel_out_
{
i
}
, '
kernel_output
=
kernel_output
+
f
'kernel_out_
{
i
}
, '
output_names
.
append
(
f
'kernel_out_
{
i
}
'
)
output_names
.
append
(
f
'kernel_out_
{
i
}
'
)
if
inplace_flag
and
self
.
inplace_map
is
not
None
and
self
.
outputs
[
'names'
][
i
]
in
self
.
inplace_map
:
output_create
=
output_create
+
f
"""
{
code_indent
}
std::get<
{
i
}
>(out) =
{
self
.
inplace_map
[
self
.
outputs
[
'names'
][
i
]]
}
;"""
output_create
=
output_create
+
f
"""
output_create
=
output_create
+
f
"""
{
code_indent
}
auto kernel_out_
{
i
}
=
{
set_out_func
}
(kernel_backend, &std::get<
{
i
}
>(out));"""
{
code_indent
}
auto kernel_out_
{
i
}
=
{
set_out_func
}
(kernel_backend, &std::get<
{
i
}
>(out));"""
...
@@ -110,12 +134,15 @@ namespace experimental {
...
@@ -110,12 +134,15 @@ namespace experimental {
"""
)
"""
)
def
generate_api
(
api_yaml_path
,
header_file_path
,
source_file_path
):
def
generate_api
(
api_yaml_path
,
header_file_path
,
source_file_path
,
dygraph_header_file_path
,
dygraph_source_file_path
):
with
open
(
api_yaml_path
,
'r'
)
as
f
:
with
open
(
api_yaml_path
,
'r'
)
as
f
:
apis
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
apis
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
header_file
=
open
(
header_file_path
,
'w'
)
header_file
=
open
(
header_file_path
,
'w'
)
source_file
=
open
(
source_file_path
,
'w'
)
source_file
=
open
(
source_file_path
,
'w'
)
dygraph_header_file
=
open
(
dygraph_header_file_path
,
'w'
)
dygraph_source_file
=
open
(
dygraph_source_file_path
,
'w'
)
namespace
=
api_namespace
()
namespace
=
api_namespace
()
...
@@ -127,20 +154,37 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
...
@@ -127,20 +154,37 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
source_file
.
write
(
source_include
(
include_header_file
))
source_file
.
write
(
source_include
(
include_header_file
))
source_file
.
write
(
namespace
[
0
])
source_file
.
write
(
namespace
[
0
])
dygraph_header_file
.
write
(
"#pragma once
\n
"
)
dygraph_header_file
.
write
(
header_include
())
dygraph_header_file
.
write
(
namespace
[
0
])
dygraph_include_header_file
=
"paddle/pten/api/lib/dygraph_api.h"
dygraph_source_file
.
write
(
source_include
(
dygraph_include_header_file
))
dygraph_source_file
.
write
(
namespace
[
0
])
for
api
in
apis
:
for
api
in
apis
:
api_code
=
ForwardAPI
(
api
)
foward_api
=
ForwardAPI
(
api
)
print
(
api_code
.
gene_api_declaration
())
if
foward_api
.
is_dygraph_api
:
header_file
.
write
(
api_code
.
gene_api_declaration
())
dygraph_header_file
.
write
(
foward_api
.
gene_api_declaration
())
source_file
.
write
(
api_code
.
gene_api_code
())
dygraph_source_file
.
write
(
foward_api
.
gene_api_code
())
else
:
header_file
.
write
(
foward_api
.
gene_api_declaration
())
source_file
.
write
(
foward_api
.
gene_api_code
())
header_file
.
write
(
namespace
[
1
])
header_file
.
write
(
namespace
[
1
])
source_file
.
write
(
namespace
[
1
])
source_file
.
write
(
namespace
[
1
])
dygraph_header_file
.
write
(
namespace
[
1
])
dygraph_source_file
.
write
(
namespace
[
1
])
source_file
.
write
(
api_register
())
source_file
.
write
(
api_register
())
header_file
.
close
()
header_file
.
close
()
source_file
.
close
()
source_file
.
close
()
dygraph_header_file
.
close
()
dygraph_source_file
.
close
()
def
main
():
def
main
():
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
...
@@ -149,6 +193,7 @@ def main():
...
@@ -149,6 +193,7 @@ def main():
'--api_yaml_path'
,
'--api_yaml_path'
,
help
=
'path to api yaml file'
,
help
=
'path to api yaml file'
,
default
=
'python/paddle/utils/code_gen/api.yaml'
)
default
=
'python/paddle/utils/code_gen/api.yaml'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--api_header_path'
,
'--api_header_path'
,
help
=
'output of generated api header code file'
,
help
=
'output of generated api header code file'
,
...
@@ -159,13 +204,26 @@ def main():
...
@@ -159,13 +204,26 @@ def main():
help
=
'output of generated api source code file'
,
help
=
'output of generated api source code file'
,
default
=
'paddle/pten/api/lib/api.cc'
)
default
=
'paddle/pten/api/lib/api.cc'
)
parser
.
add_argument
(
'--dygraph_api_header_path'
,
help
=
'output of generated dygraph api header code file'
,
default
=
'paddle/pten/api/lib/dygraph_api.h'
)
parser
.
add_argument
(
'--dygraph_api_source_path'
,
help
=
'output of generated dygraph api source code file'
,
default
=
'paddle/pten/api/lib/dygraph_api.cc'
)
options
=
parser
.
parse_args
()
options
=
parser
.
parse_args
()
api_yaml_path
=
options
.
api_yaml_path
api_yaml_path
=
options
.
api_yaml_path
header_file_path
=
options
.
api_header_path
header_file_path
=
options
.
api_header_path
source_file_path
=
options
.
api_source_path
source_file_path
=
options
.
api_source_path
dygraph_header_file_path
=
options
.
dygraph_api_header_path
dygraph_source_file_path
=
options
.
dygraph_api_source_path
generate_api
(
api_yaml_path
,
header_file_path
,
source_file_path
)
generate_api
(
api_yaml_path
,
header_file_path
,
source_file_path
,
dygraph_header_file_path
,
dygraph_source_file_path
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
python/paddle/utils/code_gen/backward_api_gen.py
浏览文件 @
638aab6e
...
@@ -69,7 +69,11 @@ class BackwardAPI(BaseAPI):
...
@@ -69,7 +69,11 @@ class BackwardAPI(BaseAPI):
return
out_type_list
[
0
]
if
len
(
return
out_type_list
[
0
]
if
len
(
out_type_list
)
==
1
else
"std::vector<std::vector<Tensor>>"
out_type_list
)
==
1
else
"std::vector<std::vector<Tensor>>"
def
gene_output
(
self
,
output_type_list
,
set_out_func
,
code_indent
):
def
gene_output
(
self
,
output_type_list
,
set_out_func
,
code_indent
,
inplace_flag
=
False
):
kernel_output
=
""
kernel_output
=
""
output_names
=
[]
output_names
=
[]
output_create
=
""
output_create
=
""
...
@@ -77,8 +81,11 @@ class BackwardAPI(BaseAPI):
...
@@ -77,8 +81,11 @@ class BackwardAPI(BaseAPI):
if
len
(
output_type_list
)
==
1
:
if
len
(
output_type_list
)
==
1
:
kernel_output
=
'kernel_out'
kernel_output
=
'kernel_out'
output_names
.
append
(
'kernel_out'
)
output_names
.
append
(
'kernel_out'
)
inplace_assign
=
" = "
+
self
.
inplace_map
[
self
.
outputs
[
'names'
][
0
]]
if
inplace_flag
and
self
.
inplace_map
is
not
None
and
self
.
outputs
[
'names'
][
0
]
in
self
.
inplace_map
else
""
output_create
=
f
"""
output_create
=
f
"""
{
code_indent
}
{
self
.
outputs
[
'return_type'
]
}
out;
{
code_indent
}
{
self
.
outputs
[
'return_type'
]
}
out
{
inplace_assign
}
;
{
code_indent
}
auto kernel_out =
{
set_out_func
}
(kernel_backend, &out);"""
{
code_indent
}
auto kernel_out =
{
set_out_func
}
(kernel_backend, &out);"""
elif
len
(
output_type_list
)
>
1
:
elif
len
(
output_type_list
)
>
1
:
...
@@ -90,11 +97,22 @@ class BackwardAPI(BaseAPI):
...
@@ -90,11 +97,22 @@ class BackwardAPI(BaseAPI):
output_names
.
append
(
f
'kernel_out_
{
i
}
'
)
output_names
.
append
(
f
'kernel_out_
{
i
}
'
)
if
out_type_item
==
'Tensor'
:
if
out_type_item
==
'Tensor'
:
get_out_code
=
f
'&out[
{
i
}
][0]'
get_out_code
=
f
'&out[
{
i
}
][0]'
output_create
=
output_create
+
f
"""
if
inplace_flag
and
self
.
inplace_map
is
not
None
and
self
.
outputs
[
'names'
][
i
]
in
self
.
inplace_map
:
output_create
=
output_create
+
f
"""
{
code_indent
}
out[
{
i
}
].emplace_back(
{
self
.
inplace_map
[
self
.
outputs
[
'names'
][
i
]]
}
);"""
else
:
output_create
=
output_create
+
f
"""
{
code_indent
}
out[
{
i
}
].emplace_back();"""
{
code_indent
}
out[
{
i
}
].emplace_back();"""
else
:
else
:
get_out_code
=
f
'&out[
{
i
}
]'
get_out_code
=
f
'&out[
{
i
}
]'
if
inplace_flag
and
self
.
inplace_map
is
not
None
and
self
.
outputs
[
'names'
][
i
]
in
self
.
inplace_map
:
output_create
=
output_create
+
f
"""
{
code_indent
}
out[
{
i
}
] =
{
self
.
inplace_map
[
self
.
outputs
[
'names'
][
i
]]
}
;"""
output_create
=
output_create
+
f
"""
output_create
=
output_create
+
f
"""
{
code_indent
}
auto kernel_out_
{
i
}
=
{
set_out_func
}
(kernel_backend,
{
get_out_code
}
);"""
{
code_indent
}
auto kernel_out_
{
i
}
=
{
set_out_func
}
(kernel_backend,
{
get_out_code
}
);"""
...
...
python/paddle/utils/code_gen/wrapped_infermeta_gen.py
浏览文件 @
638aab6e
...
@@ -16,7 +16,7 @@ import os
...
@@ -16,7 +16,7 @@ import os
import
yaml
import
yaml
import
argparse
import
argparse
from
api_
base
import
Base
API
from
api_
gen
import
Forward
API
def
get_wrapped_infermeta_name
(
api_name
):
def
get_wrapped_infermeta_name
(
api_name
):
...
@@ -24,7 +24,7 @@ def get_wrapped_infermeta_name(api_name):
...
@@ -24,7 +24,7 @@ def get_wrapped_infermeta_name(api_name):
def
gene_wrapped_infermeta_and_register
(
api
):
def
gene_wrapped_infermeta_and_register
(
api
):
if
api
.
is_base_api
:
if
api
.
is_base_api
and
not
api
.
is_dygraph_api
:
register_code
=
f
"""
register_code
=
f
"""
PT_REGISTER_INFER_META_FN(
{
api
.
kernel
[
'func'
][
0
]
}
, pten::
{
api
.
infer_meta
[
'func'
]
}
);"""
PT_REGISTER_INFER_META_FN(
{
api
.
kernel
[
'func'
][
0
]
}
, pten::
{
api
.
infer_meta
[
'func'
]
}
);"""
...
@@ -76,20 +76,6 @@ PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{get_wrapped_infermeta_
...
@@ -76,20 +76,6 @@ PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{get_wrapped_infermeta_
return
''
,
''
,
''
return
''
,
''
,
''
def
gene_infermeta_register
(
api
):
if
api
.
is_base_api
:
if
api
.
infer_meta
[
'param'
]
is
None
:
return
f
"""
PT_REGISTER_INFER_META_FN(
{
api
.
kernel
[
'func'
][
0
]
}
, pten::
{
api
.
infer_meta
[
'func'
]
}
);"""
else
:
return
f
"""
PT_REGISTER_INFER_META_FN(
{
api
.
kernel
[
'func'
][
0
]
}
, pten::
{
get_wrapped_infermeta_name
(
api
.
kernel
[
'func'
][
0
])
}
);"""
else
:
return
''
def
header_include
():
def
header_include
():
return
"""
return
"""
#include "paddle/pten/core/meta_tensor.h"
#include "paddle/pten/core/meta_tensor.h"
...
@@ -138,7 +124,7 @@ def generate_wrapped_infermeta_and_register(api_yaml_path, header_file_path,
...
@@ -138,7 +124,7 @@ def generate_wrapped_infermeta_and_register(api_yaml_path, header_file_path,
infermeta_register_code
=
''
infermeta_register_code
=
''
for
api
in
apis
:
for
api
in
apis
:
api_item
=
Base
API
(
api
)
api_item
=
Forward
API
(
api
)
declare_code
,
defind_code
,
register_code
=
gene_wrapped_infermeta_and_register
(
declare_code
,
defind_code
,
register_code
=
gene_wrapped_infermeta_and_register
(
api_item
)
api_item
)
header_file
.
write
(
declare_code
)
header_file
.
write
(
declare_code
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录