Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
82cf1fad
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
82cf1fad
编写于
2月 12, 2023
作者:
X
Xiaoxu Chen
提交者:
GitHub
2月 12, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[prim] generate static prim api (#50315)
上级
14e45f6b
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
659 addition
and
200 deletion
+659
-200
paddle/fluid/operators/generator/filters.py
paddle/fluid/operators/generator/filters.py
+5
-1
paddle/fluid/operators/generator/tests.py
paddle/fluid/operators/generator/tests.py
+12
-0
paddle/fluid/prim/api/CMakeLists.txt
paddle/fluid/prim/api/CMakeLists.txt
+5
-0
paddle/fluid/prim/api/api.yaml
paddle/fluid/prim/api/api.yaml
+25
-0
paddle/fluid/prim/api/auto_code_generated/CMakeLists.txt
paddle/fluid/prim/api/auto_code_generated/CMakeLists.txt
+45
-8
paddle/fluid/prim/api/auto_code_generated/eager_gen.py
paddle/fluid/prim/api/auto_code_generated/eager_gen.py
+15
-2
paddle/fluid/prim/api/auto_code_generated/prim_base.py
paddle/fluid/prim/api/auto_code_generated/prim_base.py
+4
-4
paddle/fluid/prim/api/auto_code_generated/static_gen.py
paddle/fluid/prim/api/auto_code_generated/static_gen.py
+157
-0
paddle/fluid/prim/api/auto_code_generated/template/static_prim_api.cc.tpl
...m/api/auto_code_generated/template/static_prim_api.cc.tpl
+39
-0
paddle/fluid/prim/api/auto_code_generated/template/utils.tpl
paddle/fluid/prim/api/auto_code_generated/template/utils.tpl
+175
-0
paddle/fluid/prim/api/generated_prim/CMakeLists.txt
paddle/fluid/prim/api/generated_prim/CMakeLists.txt
+5
-1
paddle/fluid/prim/api/manual_prim/CMakeLists.txt
paddle/fluid/prim/api/manual_prim/CMakeLists.txt
+7
-1
paddle/fluid/prim/api/manual_prim/eager_prim_api.cc
paddle/fluid/prim/api/manual_prim/eager_prim_api.cc
+37
-0
paddle/fluid/prim/api/manual_prim/prim_manual_api.h
paddle/fluid/prim/api/manual_prim/prim_manual_api.h
+17
-4
paddle/fluid/prim/api/manual_prim/static_prim_api.cc
paddle/fluid/prim/api/manual_prim/static_prim_api.cc
+6
-164
paddle/fluid/prim/api/manual_prim/utils/utils.h
paddle/fluid/prim/api/manual_prim/utils/utils.h
+6
-0
paddle/phi/api/yaml/op_compat.yaml
paddle/phi/api/yaml/op_compat.yaml
+99
-15
未找到文件。
paddle/fluid/operators/generator/filters.py
浏览文件 @
82cf1fad
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
import
itertools
import
itertools
import
re
import
re
from
typing
import
Dict
,
List
from
typing
import
Dict
,
List
,
Sequence
from
type_mapping
import
(
from
type_mapping
import
(
attr_types_map
,
attr_types_map
,
...
@@ -80,6 +80,10 @@ def to_sr_output_type(s):
...
@@ -80,6 +80,10 @@ def to_sr_output_type(s):
return
sr_output_types_map
[
s
]
return
sr_output_types_map
[
s
]
def
filter_intermediate
(
items
:
Sequence
):
return
tuple
([
item
for
item
in
items
if
not
item
.
get
(
'intermediate'
)])
# -------------- transform argument names from yaml to opmaker ------------
# -------------- transform argument names from yaml to opmaker ------------
def
to_opmaker_name
(
s
):
def
to_opmaker_name
(
s
):
if
s
.
endswith
(
"_grad"
):
if
s
.
endswith
(
"_grad"
):
...
...
paddle/fluid/operators/generator/tests.py
浏览文件 @
82cf1fad
...
@@ -38,6 +38,14 @@ def is_scalar(s):
...
@@ -38,6 +38,14 @@ def is_scalar(s):
return
re
.
match
(
r
"Scalar(\(\w+\))*"
,
s
)
is
not
None
return
re
.
match
(
r
"Scalar(\(\w+\))*"
,
s
)
is
not
None
def
is_intarray
(
s
):
return
s
==
'IntArray'
def
is_datatype
(
s
):
return
s
==
'DataType'
def
is_initializer_list
(
s
):
def
is_initializer_list
(
s
):
return
s
==
"{}"
return
s
==
"{}"
...
@@ -63,3 +71,7 @@ def supports_no_need_buffer(op):
...
@@ -63,3 +71,7 @@ def supports_no_need_buffer(op):
if
input
[
"no_need_buffer"
]:
if
input
[
"no_need_buffer"
]:
return
True
return
True
return
False
return
False
def
is_tensor_list
(
s
):
return
s
==
'Tensor[]'
paddle/fluid/prim/api/CMakeLists.txt
浏览文件 @
82cf1fad
...
@@ -3,11 +3,16 @@ add_subdirectory(manual_prim)
...
@@ -3,11 +3,16 @@ add_subdirectory(manual_prim)
add_subdirectory
(
generated_prim
)
add_subdirectory
(
generated_prim
)
if
(
NOT
(
NOT WITH_PYTHON AND ON_INFER
))
if
(
NOT
(
NOT WITH_PYTHON AND ON_INFER
))
cc_library
(
eager_prim_api DEPS generated_eager_prim_api manual_eager_prim_api
)
cc_library
(
static_prim_api DEPS generated_static_prim_api
manual_static_prim_api
)
cc_library
(
cc_library
(
prim_api
prim_api
SRCS all.cc
SRCS all.cc
DEPS static_utils static_prim_api eager_prim_api eager_api
)
DEPS static_utils static_prim_api eager_prim_api eager_api
)
else
()
else
()
cc_library
(
static_prim_api DEPS generated_static_prim_api
manual_static_prim_api
)
cc_library
(
cc_library
(
prim_api
prim_api
SRCS all.cc
SRCS all.cc
...
...
paddle/fluid/prim/api/api.yaml
0 → 100644
浏览文件 @
82cf1fad
-
unsqueeze
-
pow
-
exp
-
scale
-
multiply
-
matmul
-
expand
-
divide
-
sum
-
add
-
abs
-
assign
-
concat
-
elementwise_pow
-
floor
-
gather_nd
-
log
-
max
-
maximum
-
minimum
-
prod
-
roll
-
scatter
-
scatter_nd_add
-
tile
paddle/fluid/prim/api/auto_code_generated/CMakeLists.txt
浏览文件 @
82cf1fad
...
@@ -4,36 +4,73 @@ set(api_yaml_path
...
@@ -4,36 +4,73 @@ set(api_yaml_path
set
(
legacy_api_yaml_path
set
(
legacy_api_yaml_path
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml"
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml"
)
)
set
(
api_compat_yaml_path
"
${
PADDLE_SOURCE_DIR
}
/paddle/phi/api/yaml/op_compat.yaml"
)
set
(
api_prim_yaml_path
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/prim/api/api.yaml"
)
set
(
api_version_yaml_path
"
${
PADDLE_SOURCE_DIR
}
/paddle/phi/api/yaml/op_version.yaml"
)
set
(
tmp_eager_prim_api_cc_path
set
(
tmp_eager_prim_api_cc_path
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/prim/api/generated_prim/tmp_eager_prim_api.cc"
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/prim/api/generated_prim/eager_prim_api.cc.tmp"
)
set
(
tmp_static_prim_api_cc_path
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/prim/api/generated_prim/static_prim_api.cc.tmp"
)
)
set
(
tmp_prim_api_h_path
set
(
tmp_prim_api_h_path
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/prim/api/generated_prim/
tmp_prim_generated_api.h
"
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/prim/api/generated_prim/
prim_generated_api.h.tmp
"
)
)
set
(
eager_prim_api_cc_path
set
(
eager_prim_api_cc_path
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/prim/api/generated_prim/eager_prim_api.cc"
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/prim/api/generated_prim/eager_prim_api.cc"
)
)
set
(
static_prim_api_cc_path
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/prim/api/generated_prim/static_prim_api.cc"
)
set
(
prim_api_h_path
set
(
prim_api_h_path
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
)
)
set
(
prim_api_gen_file
set
(
static_prim_api_template_path
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/prim/api/auto_code_generated/prim_gen.py
)
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/prim/api/auto_code_generated/template/static_prim_api.cc.tpl"
)
set
(
eager_prim_api_gen_file
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/prim/api/auto_code_generated/eager_gen.py
)
set
(
static_prim_api_gen_file
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/prim/api/auto_code_generated/static_gen.py
)
message
(
"
prim api Code gen
"
)
message
(
"
Eager prim api code generator
"
)
execute_process
(
execute_process
(
WORKING_DIRECTORY
WORKING_DIRECTORY
${
CMAKE_SOURCE_DIR
}
/paddle/fluid/prim/api/auto_code_generated
${
CMAKE_SOURCE_DIR
}
/paddle/fluid/prim/api/auto_code_generated
COMMAND
COMMAND
${
PYTHON_EXECUTABLE
}
${
prim_api_gen_file
}
--api_yaml_path
${
PYTHON_EXECUTABLE
}
${
eager_
prim_api_gen_file
}
--api_yaml_path
${
legacy_api_yaml_path
}
${
api_yaml_path
}
--prim_api_header_path
${
legacy_api_yaml_path
}
${
api_yaml_path
}
--prim_api_header_path
${
tmp_prim_api_h_path
}
--eager_prim_api_source_path
${
tmp_prim_api_h_path
}
--eager_prim_api_source_path
${
tmp_eager_prim_api_cc_path
}
${
tmp_eager_prim_api_cc_path
}
--api_prim_yaml_path
${
api_prim_yaml_path
}
RESULT_VARIABLE _result
)
RESULT_VARIABLE _result
)
if
(
${
_result
}
)
if
(
${
_result
}
)
message
(
FATAL_ERROR
"
prim api gen
rate failed, exiting."
)
message
(
FATAL_ERROR
"
Eager prim api gene
rate failed, exiting."
)
endif
()
endif
()
execute_process
(
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
execute_process
(
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
tmp_prim_api_h_path
}
${
prim_api_h_path
}
)
${
tmp_prim_api_h_path
}
${
prim_api_h_path
}
)
execute_process
(
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
execute_process
(
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
tmp_eager_prim_api_cc_path
}
${
eager_prim_api_cc_path
}
)
${
tmp_eager_prim_api_cc_path
}
${
eager_prim_api_cc_path
}
)
message
(
"copy tmp_xxx_prim_api to xxx_prim_api"
)
message
(
"copy tmp_xxx_prim_api to xxx_prim_api"
)
message
(
"Static prim api code generator"
)
execute_process
(
WORKING_DIRECTORY
${
CMAKE_SOURCE_DIR
}
/paddle/fluid/prim/api/auto_code_generated
COMMAND
${
PYTHON_EXECUTABLE
}
${
static_prim_api_gen_file
}
--api_phi_yaml_path
${
api_yaml_path
}
--api_phi_legacy_yaml_path
${
legacy_api_yaml_path
}
--api_compat_yaml_path
${
api_compat_yaml_path
}
--api_version_yaml_path
${
api_version_yaml_path
}
--api_prim_yaml_path
${
api_prim_yaml_path
}
--template_path
${
static_prim_api_template_path
}
--output_path
${
tmp_static_prim_api_cc_path
}
RESULT_VARIABLE _result
)
if
(
${
_result
}
)
message
(
FATAL_ERROR
"Static prim api generate failed, exiting."
)
endif
()
execute_process
(
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
tmp_static_prim_api_cc_path
}
${
static_prim_api_cc_path
}
)
message
(
"copy tmp_xxx_prim_api to xxx_prim_api"
)
paddle/fluid/prim/api/auto_code_generated/
prim
_gen.py
→
paddle/fluid/prim/api/auto_code_generated/
eager
_gen.py
浏览文件 @
82cf1fad
...
@@ -55,7 +55,9 @@ using DataType = paddle::experimental::DataType;
...
@@ -55,7 +55,9 @@ using DataType = paddle::experimental::DataType;
)
)
def
generate_api
(
api_yaml_path
,
header_file_path
,
eager_prim_source_file_path
):
def
generate_api
(
api_yaml_path
,
header_file_path
,
eager_prim_source_file_path
,
api_prim_path
):
apis
=
[]
apis
=
[]
for
each_api_yaml
in
api_yaml_path
:
for
each_api_yaml
in
api_yaml_path
:
...
@@ -76,8 +78,11 @@ def generate_api(api_yaml_path, header_file_path, eager_prim_source_file_path):
...
@@ -76,8 +78,11 @@ def generate_api(api_yaml_path, header_file_path, eager_prim_source_file_path):
eager_prim_source_file
.
write
(
eager_source_include
())
eager_prim_source_file
.
write
(
eager_source_include
())
eager_prim_source_file
.
write
(
namespace
[
0
])
eager_prim_source_file
.
write
(
namespace
[
0
])
with
open
(
api_prim_path
,
'rt'
)
as
f
:
api_prims
=
yaml
.
safe_load
(
f
)
for
api
in
apis
:
for
api
in
apis
:
prim_api
=
EagerPrimAPI
(
api
)
prim_api
=
EagerPrimAPI
(
api
,
api_prims
)
if
prim_api
.
is_prim_api
:
if
prim_api
.
is_prim_api
:
header_file
.
write
(
prim_api
.
gene_prim_api_declaration
())
header_file
.
write
(
prim_api
.
gene_prim_api_declaration
())
eager_prim_source_file
.
write
(
prim_api
.
gene_eager_prim_api_code
())
eager_prim_source_file
.
write
(
prim_api
.
gene_eager_prim_api_code
())
...
@@ -112,16 +117,24 @@ def main():
...
@@ -112,16 +117,24 @@ def main():
default
=
'paddle/fluid/prim/api/generated_prim/eager_prim_api.cc'
,
default
=
'paddle/fluid/prim/api/generated_prim/eager_prim_api.cc'
,
)
)
parser
.
add_argument
(
'--api_prim_yaml_path'
,
help
=
'Primitive API list yaml file.'
,
default
=
'paddle/fluid/prim/api/auto_code_generated/api.yaml'
,
)
options
=
parser
.
parse_args
()
options
=
parser
.
parse_args
()
api_yaml_path
=
options
.
api_yaml_path
api_yaml_path
=
options
.
api_yaml_path
prim_api_header_file_path
=
options
.
prim_api_header_path
prim_api_header_file_path
=
options
.
prim_api_header_path
eager_prim_api_source_file_path
=
options
.
eager_prim_api_source_path
eager_prim_api_source_file_path
=
options
.
eager_prim_api_source_path
api_prim_yaml_path
=
options
.
api_prim_yaml_path
generate_api
(
generate_api
(
api_yaml_path
,
api_yaml_path
,
prim_api_header_file_path
,
prim_api_header_file_path
,
eager_prim_api_source_file_path
,
eager_prim_api_source_file_path
,
api_prim_yaml_path
,
)
)
...
...
paddle/fluid/prim/api/auto_code_generated/prim_base.py
浏览文件 @
82cf1fad
...
@@ -39,12 +39,12 @@ inplace_optional_out_type_map = {
...
@@ -39,12 +39,12 @@ inplace_optional_out_type_map = {
class
BaseAPI
:
class
BaseAPI
:
def
__init__
(
self
,
api_item_yaml
):
def
__init__
(
self
,
api_item_yaml
,
prims
=
tuple
()
):
# self.api = api_item_yaml['op']
# self.api = api_item_yaml['op']
self
.
api
=
api_item_yaml
[
'name'
]
self
.
api
=
api_item_yaml
[
'name'
]
self
.
is_prim_api
=
False
self
.
is_prim_api
=
False
if
api_item_yaml
[
'name'
]
in
white_ops_list
:
if
api_item_yaml
[
'name'
]
in
prims
:
self
.
is_prim_api
=
True
self
.
is_prim_api
=
True
#######################################
#######################################
...
@@ -253,8 +253,8 @@ class BaseAPI:
...
@@ -253,8 +253,8 @@ class BaseAPI:
class
EagerPrimAPI
(
BaseAPI
):
class
EagerPrimAPI
(
BaseAPI
):
def
__init__
(
self
,
api_item_yaml
):
def
__init__
(
self
,
api_item_yaml
,
prims
=
tuple
()
):
super
().
__init__
(
api_item_yaml
)
super
().
__init__
(
api_item_yaml
,
prims
)
def
get_api__func_name
(
self
):
def
get_api__func_name
(
self
):
api_func_name
=
self
.
api
api_func_name
=
self
.
api
...
...
paddle/fluid/prim/api/auto_code_generated/static_gen.py
0 → 100644
浏览文件 @
82cf1fad
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
copy
import
pathlib
import
sys
import
jinja2
import
yaml
# fmt: off
# import from paddle/fluid/operators/generator
sys
.
path
.
append
(
str
(
pathlib
.
Path
(
__file__
).
parents
[
3
].
joinpath
(
'operators/generator'
))
)
import
filters
as
op_gen_filters
import
generate_op
as
op_gen_utils
import
parse_utils
as
op_gen_parse_utils
import
tests
as
op_gen_tests
# fmt: on
def
load_yaml
(
path
,
mode
=
"rt"
):
with
open
(
path
,
mode
)
as
f
:
return
yaml
.
safe_load
(
f
)
def
render
(
tpl
,
*
args
,
**
kwargs
):
env
=
jinja2
.
Environment
(
loader
=
jinja2
.
FileSystemLoader
(
pathlib
.
Path
(
tpl
).
parent
),
keep_trailing_newline
=
True
,
trim_blocks
=
True
,
lstrip_blocks
=
True
,
undefined
=
jinja2
.
StrictUndefined
,
extensions
=
[
'jinja2.ext.do'
],
)
env
.
filters
.
update
(
{
'to_paddle_attr_type'
:
op_gen_filters
.
to_paddle_attr_type
,
'to_paddle_input_type'
:
op_gen_filters
.
to_paddle_input_type
,
'to_paddle_output_type'
:
op_gen_filters
.
to_paddle_output_type
,
'to_pascal'
:
op_gen_filters
.
to_pascal_case
,
"trip_intermediate"
:
op_gen_filters
.
filter_intermediate
,
}
)
env
.
tests
.
update
(
{
'scalar'
:
op_gen_tests
.
is_scalar
,
'intarray'
:
op_gen_tests
.
is_intarray
,
'datatype'
:
op_gen_tests
.
is_datatype
,
'tensor_sequence'
:
op_gen_tests
.
is_tensor_list
,
}
)
return
env
.
get_template
(
pathlib
.
Path
(
tpl
).
name
).
render
(
*
args
,
**
kwargs
)
def
filter_prim
(
apis
,
prims
):
return
[
api
for
api
in
apis
if
api
.
get
(
'name'
)
in
prims
]
def
extend_compat
(
apis
,
compats
):
dicts
=
op_gen_parse_utils
.
to_named_dict
(
copy
.
deepcopy
(
apis
))
for
api
in
dicts
.
values
():
op_gen_utils
.
restruct_io
(
api
)
api
[
'op_name'
]
=
api
[
'name'
]
op_gen_utils
.
add_fluid_name
(
api
[
'inputs'
])
op_gen_utils
.
add_fluid_name
(
api
[
'attrs'
])
op_gen_utils
.
add_fluid_name
(
api
[
'outputs'
])
api
[
'backward'
]
=
None
op_gen_utils
.
add_compat_name
(
compats
,
dicts
,
{})
return
tuple
(
dicts
.
values
())
def
extend_version
(
apis
,
versions
):
apis
=
copy
.
deepcopy
(
apis
)
for
api
in
apis
:
for
version
in
versions
:
if
version
.
get
(
'op'
)
==
api
.
get
(
'name'
):
api
[
'version'
]
=
version
[
'version'
]
return
apis
def
generate
(
api_prim_yaml_path
,
api_phi_yaml_path
,
api_phi_legacy_yaml_path
,
api_compat_yaml_path
,
api_version_yaml_path
,
template_path
,
output_op_path
,
):
prims
,
phis
,
legacy_phis
,
compats
,
versions
=
(
load_yaml
(
api_prim_yaml_path
),
load_yaml
(
api_phi_yaml_path
),
load_yaml
(
api_phi_legacy_yaml_path
),
load_yaml
(
api_compat_yaml_path
),
load_yaml
(
api_version_yaml_path
),
)
apis
=
phis
+
legacy_phis
apis
=
filter_prim
(
apis
,
prims
)
apis
=
extend_version
(
apis
,
versions
)
apis
=
extend_compat
(
apis
,
compats
)
if
len
(
apis
)
>
0
:
with
open
(
output_op_path
,
"wt"
)
as
f
:
msg
=
render
(
template_path
,
apis
=
apis
)
f
.
write
(
msg
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Generate Static Primitive API"
)
parser
.
add_argument
(
'--api_prim_yaml_path'
,
type
=
str
,
help
=
"Primitive API yaml file.."
)
parser
.
add_argument
(
'--api_phi_yaml_path'
,
type
=
str
,
help
=
"Parsed ops yaml file."
)
parser
.
add_argument
(
'--api_phi_legacy_yaml_path'
,
type
=
str
,
help
=
"Parsed ops yaml file."
)
parser
.
add_argument
(
'--api_compat_yaml_path'
,
type
=
str
,
help
=
"Ops args compat yaml file."
)
parser
.
add_argument
(
'--api_version_yaml_path'
,
type
=
str
,
help
=
"Ops version yaml file."
)
parser
.
add_argument
(
"--template_path"
,
type
=
str
,
help
=
"JinJa2 template file Path."
)
parser
.
add_argument
(
"--output_path"
,
type
=
str
,
help
=
"Output path."
)
args
=
parser
.
parse_args
()
generate
(
args
.
api_prim_yaml_path
,
args
.
api_phi_yaml_path
,
args
.
api_phi_legacy_yaml_path
,
args
.
api_compat_yaml_path
,
args
.
api_version_yaml_path
,
args
.
template_path
,
args
.
output_path
,
)
paddle/fluid/prim/api/auto_code_generated/template/static_prim_api.cc.tpl
0 → 100644
浏览文件 @
82cf1fad
{
%
from
"utils.tpl"
import
static_prim_api
%
}
// Generated by /paddle/fluid/prim/api/auto_code_generated/static_gen.py.
// DO NOT EDIT!
#include
<string.h>
#include
<memory>
#include
<sstream>
#include
<string>
#include
<unordered_set>
#include
<vector>
#include
<algorithm>
#include
<tuple>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
#include "paddle/fluid/prim/api/manual_prim/utils/utils.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace prim {
{% for api in apis %}
{
{
static_prim_api
(
api
)
}
}
{% endfor %}
} // namespace prim
} // namespace paddle
paddle/fluid/prim/api/auto_code_generated/template/utils.tpl
0 → 100644
浏览文件 @
82cf1fad
{% macro static_prim_api(api) %}
{%- set fluid_name = api.op_name -%}
{%- set phi_name = api.name -%}
{%- set inputs = api.inputs -%}
{%- set outputs = api.outputs|trip_intermediate -%}
{
#
-
ignore
intermediate
output
-
#
}
{%- set attrs = api.attrs -%}
{%- set output_names = [] -%}
{%- for o in outputs -%} {%- do output_names.append(o.name) -%} {%-endfor-%}
{
{
static_prim_api_sig
(
phi_name
,
inputs
,
outputs
,
attrs
)
}
} {
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("{
{
fluid_name
}
}");
{% filter indent(2, True) %}
{% for input in inputs %}
{
{
static_prim_api_input
(
input
)
}
}
{% endfor %}
{% for output in outputs %}
{
{
static_prim_api_output
(
output
)
}
}
{% endfor %}
{% for attr in attrs %}
{
{
static_prim_api_attr
(
attr
)
}
}
{% endfor %}
{% endfilter %}
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
{% if outputs|length > 1 %}
return std::make_tuple{
{
sequence
(
'('
,
')'
,
', '
,
output_names
)
}
};
{% elif outputs|length == 1 %}
return {
{
outputs
[
0
].
name
}
};
{% else %}
{
#
-
render
nothing
-
#
}
{% endif %}
}
{% endmacro %}
{%- macro static_prim_api_sig(name, inputs, outputs, attrs) -%}
template
<>
{
{
static_prim_api_sig_ret
(
outputs
)
}
} {
{
name
}
}
<DescTensor>
({
{
static_prim_api_sig_params
(
inputs
,
attrs
)
}
})
{%- endmacro %}
{%- macro static_prim_api_sig_params(inputs, attrs) -%}
{%- set input_params = [] -%}
{%- for i in inputs -%} {%- do input_params.append(i.typename|to_paddle_input_type(i.optional)~' '~i.name) -%} {%- endfor -%}
{%- set attr_params = [] -%}
{%- for i in attrs -%} {%- do attr_params.append(i.typename|to_paddle_attr_type~' '~i.name) -%} {%- endfor -%}
{
{
sequence
(
''
,
''
,
', '
,
input_params
)
}
}
{%- if attr_params|length > 0 -%} {
{
", "
}
} {%- endif -%}
{
#
-
append
comma
between
inputs
and
attrs
-
#
}
{
{
sequence
(
''
,
''
,
', '
,
attr_params
)
}
}
{%- endmacro -%}
{%- macro static_prim_api_sig_ret(outputs) -%}
{%- set names = [] -%}
{%- for i in outputs -%} {%- do names.append(i.typename|to_paddle_output_type) -%} {%- endfor -%}
{%- if names|length > 1 -%}
std::tuple
<
{
{
sequence
(
''
,
''
,
', '
,
names
)
}
}
>
{%- else -%}
{
{
names
[
0
]
}
}
{%- endif -%}
{%- endmacro -%}
{% macro static_prim_api_input(input) %}
{%- if input.optional -%}
{
{
static_prim_api_input_optional
(
input
)
}
}
{%- else -%}
{
{
static_prim_api_input_without_optional
(
input
)
}
}
{%- endif -%}
{%- endmacro -%}
{%- macro static_prim_api_input_optional(input) -%}
{%- if input.typename=='Tensor[]' -%}
{
#
-
render
the
input
of
type
paddle
::
optional
<
std
::
Vector
<
Tensor
>>
-
#
}
if ({
{
input
.
name
}
}) {
std::vector
<std::string>
{
{
input
.
name
}
}_names;
std::transform({
{
input
.
name
}
}.get().begin(), {
{
input
.
name
}
}.get().end(), {
{
input
.
name
}
}_names.begin(), [](const Tensor
&
t) {
return std::static_pointer_cast
<prim::DescTensor>
(t.impl())->Name();
});
op->SetInput("{
{
input
.
fluid_name
|
to_pascal
}
}", {
{
input
.
name
}
}_names);
}
{%- else -%}
if ({
{
input
.
name
}
}) {
op->SetInput("{
{
input
.
fluid_name
|
to_pascal
}
}",
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>(
{{
input
.
name
}}
->
impl
())->
Name
()
}
);
}
{%- endif -%}
{%- endmacro -%}
{%- macro static_prim_api_input_without_optional(input) -%}
{%- if input.typename is tensor_sequence -%}
{
#
-
render
the
input
of
type
std
::
Vector
<
Tensor
>
-
#
}
std::vector
<std::string>
{
{
input
.
name
}
}_names;
std::transform({
{
input
.
name
}
}.begin(), {
{
input
.
name
}
}.end(), {
{
input
.
name
}
}_names.begin(), [](const Tensor
&
t) {
return std::static_pointer_cast
<prim::DescTensor>
(t.impl())->Name();
});
op->SetInput("{
{
input
.
fluid_name
|
to_pascal
}
}", {
{
input
.
name
}
}_names);
{%- else -%}
op->SetInput("{
{
input
.
fluid_name
|
to_pascal
}
}",
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>(
{{
input
.
name
}}
.
impl
())->
Name
()
}
);
{%- endif -%}
{%- endmacro -%}
{% macro static_prim_api_output(output) %}
{%- if output.optional -%}
{
{
static_prim_api_output_optional
(
output
)
}
}
{%- else -%}
{
{
static_prim_api_output_without_optional
(
output
)
}
}
{%- endif -%}
{%- endmacro -%}
{%- macro static_prim_api_output_without_optional(output) -%}
{%- if output.typename is tensor_sequence -%}
{
#
-
render
the
output
of
type
std
::
Vector
<
Tensor
>
-
#
}
std::vector
<Tensor>
{
{
output
.
name
}
};
std::vector
<std::string>
{
{
output
.
name
}
}_names;
for (auto i=0; i
<
{
{
output
.
size
}
};
i
++)
{
auto
tmp =
empty<DescTensor
>
({}, phi::DataType::FLOAT32, paddle::Place());
{
{
output
.
name
}
}.push_back(tmp);
{
{
output
.
name
}
}_names.push_back(std::static_pointer_cast
<prim::DescTensor>
(tmp.impl())->Name());
}
op->SetOutput("{
{
output
.
fluid_name
|
to_pascal
}
}", {
{
output
.
name
}
}_names);
{%- else -%}
auto {
{
output
.
name
}
} = empty
<DescTensor>
({}, phi::DataType::FLOAT32, paddle::Place());
op->SetOutput("{
{
output
.
fluid_name
|
to_pascal
}
}",
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>(
{{
output
.
name
}}
.
impl
())->
Name
()
}
);
{%- endif -%}
{%- endmacro -%}
{%- macro static_prim_api_output_optional(output) -%}
// TODO(cxxly): Render optional output
{%- endmacro -%}
{% macro static_prim_api_attr(attr) %}
op->SetAttr("{
{
attr
.
fluid_name
}
}", {
{
phi_attr_to_fluid
(
attr
)
}
});
{%- endmacro %}
{%- macro phi_attr_to_fluid(attr) -%}
{%- if attr.typename is intarray -%}
{
{
int_array_to_fluid
(
attr
.
name
,
attr
.
typename
,
attr
.
fluid_name
,
attr
.
data_type
)
}
}
{%- elif attr.typename is scalar -%}
{
{
scalar_to_fluid
(
attr
.
name
,
attr
.
typename
,
attr
.
fluid_name
,
attr
.
data_type
)
}
}
{%- elif attr.typename is datatype -%}
{
{
datatype_to_fluid
(
attr
.
name
,
attr
.
typename
,
attr
.
fluid_name
,
attr
.
data_type
)
}
}
{%- else -%}
{
{
attr
.
name
}
}
{%- endif -%}
{%- endmacro %}
{%- macro int_array_to_fluid(src_name, src_type, dst_name, dst_type) -%}
{%- if dst_type=='std::vector
<int>
' -%}
unsafe_vector_cast
<int64_t
,
int
>
({
{
src_name
}
}.GetData())
{%- else -%}
{
{
src_name
}
}.GetData()
{%- endif -%}
{%- endmacro -%}
{%- macro scalar_to_fluid(src_name, src_type, dst_name, dst_type) -%}
{
{
src_name
}
}.to
<
{
{
dst_type
}
}
>
()
{%- endmacro -%}
{%- macro datatype_to_fluid(src_name, src_type, dst_name, dst_type) -%}
paddle::framework::TransToProtoVarType({
{
src_name
}
})
{%- endmacro -%}
{%- macro sequence(lsymbol, rsymbol, delimiter, items) -%}
{
{
lsymbol
}
}{%- for item in items -%}{
{
item
}
}{
{
delimiter
if
not
loop
.
last
else
""
}
}{%- endfor -%}{
{
rsymbol
}
}
{%- endmacro -%}
paddle/fluid/prim/api/generated_prim/CMakeLists.txt
浏览文件 @
82cf1fad
if
(
NOT
(
NOT WITH_PYTHON AND ON_INFER
))
if
(
NOT
(
NOT WITH_PYTHON AND ON_INFER
))
cc_library
(
cc_library
(
eager_prim_api
generated_
eager_prim_api
SRCS eager_prim_api.cc
SRCS eager_prim_api.cc
DEPS final_dygraph_function eager_utils
)
DEPS final_dygraph_function eager_utils
)
endif
()
endif
()
cc_library
(
generated_static_prim_api
SRCS static_prim_api.cc
DEPS proto_desc static_utils
)
paddle/fluid/prim/api/manual_prim/CMakeLists.txt
浏览文件 @
82cf1fad
add_subdirectory
(
utils
)
add_subdirectory
(
utils
)
if
(
NOT
(
NOT WITH_PYTHON AND ON_INFER
))
cc_library
(
manual_eager_prim_api
SRCS eager_prim_api.cc
DEPS final_dygraph_function eager_utils
)
endif
()
cc_library
(
cc_library
(
static_prim_api
manual_
static_prim_api
SRCS static_prim_api.cc
SRCS static_prim_api.cc
DEPS proto_desc static_utils
)
DEPS proto_desc static_utils
)
paddle/fluid/prim/api/manual_prim/eager_prim_api.cc
0 → 100644
浏览文件 @
82cf1fad
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
namespace
paddle
{
namespace
prim
{
template
<
>
Tensor
reshape
<
Tensor
>
(
const
Tensor
&
x
,
const
IntArray
&
shape
)
{
VLOG
(
4
)
<<
"Eager Prim API reshape_ad_func call"
;
return
::
reshape_ad_func
(
x
,
shape
);
}
template
<
>
Tensor
full
<
Tensor
>
(
const
IntArray
&
shape
,
const
Scalar
&
value
,
DataType
dtype
,
const
Place
&
place
)
{
VLOG
(
4
)
<<
"Eager Prim API full_ad_func call"
;
return
::
full_ad_func
(
shape
,
value
,
dtype
,
place
);
}
}
// namespace prim
}
// namespace paddle
paddle/fluid/prim/api/manual_prim/prim_manual_api.h
浏览文件 @
82cf1fad
...
@@ -14,14 +14,27 @@
...
@@ -14,14 +14,27 @@
#pragma once
#pragma once
#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/optional.h"
#include "paddle/utils/optional.h"
// TODO(jiabin): Make this Header only for handwritten api, instead of include
// prim_generated_api.h
namespace
paddle
{
namespace
paddle
{
namespace
prim
{}
// namespace prim
namespace
prim
{
using
Tensor
=
paddle
::
experimental
::
Tensor
;
using
Scalar
=
paddle
::
experimental
::
Scalar
;
using
IntArray
=
paddle
::
experimental
::
IntArray
;
using
DataType
=
paddle
::
experimental
::
DataType
;
template
<
typename
T
>
Tensor
reshape
(
const
Tensor
&
x
,
const
IntArray
&
shape
);
template
<
typename
T
>
Tensor
full
(
const
IntArray
&
shape
,
const
Scalar
&
value
,
DataType
dtype
=
DataType
::
FLOAT32
,
const
Place
&
place
=
CPUPlace
());
}
// namespace prim
}
// namespace paddle
}
// namespace paddle
paddle/fluid/prim/api/manual_prim/static_prim_api.cc
浏览文件 @
82cf1fad
...
@@ -38,111 +38,18 @@ namespace paddle {
...
@@ -38,111 +38,18 @@ namespace paddle {
namespace
prim
{
namespace
prim
{
template
<
>
template
<
>
Tensor
pow
<
DescTensor
>
(
const
Tensor
&
x
,
const
Scalar
&
y
)
{
Tensor
reshape
<
DescTensor
>
(
const
Tensor
&
x
,
const
IntArray
&
shape
)
{
Tensor
out
=
empty
<
DescTensor
>
({},
phi
::
DataType
::
FLOAT32
,
paddle
::
Place
());
framework
::
BlockDesc
*
block
=
StaticCompositeContext
::
Instance
().
GetBlock
();
framework
::
OpDesc
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"pow"
);
op
->
SetInput
(
"X"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
x
.
impl
())
->
Name
()});
op
->
SetOutput
(
"Out"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
out
.
impl
())
->
Name
()});
op
->
SetAttr
(
"factor"
,
y
.
to
<
float
>
());
op
->
CheckAttrs
();
op
->
InferVarType
(
block
);
op
->
InferShape
(
*
block
);
return
out
;
}
template
<
>
Tensor
scale
<
DescTensor
>
(
const
Tensor
&
x
,
const
Scalar
&
scale
,
float
bias
,
bool
bias_after_scale
)
{
Tensor
out
=
empty
<
DescTensor
>
({},
phi
::
DataType
::
FLOAT32
,
paddle
::
Place
());
framework
::
BlockDesc
*
block
=
StaticCompositeContext
::
Instance
().
GetBlock
();
framework
::
OpDesc
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"scale"
);
op
->
SetInput
(
"X"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
x
.
impl
())
->
Name
()});
op
->
SetOutput
(
"Out"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
out
.
impl
())
->
Name
()});
op
->
SetAttr
(
"scale"
,
scale
.
to
<
float
>
());
op
->
SetAttr
(
"bias"
,
bias
);
op
->
SetAttr
(
"bias_after_scale"
,
bias_after_scale
);
op
->
CheckAttrs
();
op
->
InferVarType
(
block
);
op
->
InferShape
(
*
block
);
return
out
;
}
template
<
>
Tensor
multiply
<
DescTensor
>
(
const
Tensor
&
x
,
const
Tensor
&
y
)
{
// Grad infershape
Tensor
out
=
empty
<
DescTensor
>
({},
phi
::
DataType
::
FLOAT32
,
paddle
::
Place
());
framework
::
BlockDesc
*
block
=
StaticCompositeContext
::
Instance
().
GetBlock
();
framework
::
OpDesc
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"elementwise_mul"
);
op
->
SetInput
(
"X"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
x
.
impl
())
->
Name
()});
op
->
SetInput
(
"Y"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
y
.
impl
())
->
Name
()});
op
->
SetOutput
(
"Out"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
out
.
impl
())
->
Name
()});
op
->
CheckAttrs
();
op
->
InferVarType
(
block
);
op
->
InferShape
(
*
block
);
return
out
;
}
template
<
>
Tensor
unsqueeze
<
DescTensor
>
(
const
Tensor
&
x
,
const
IntArray
&
axis
)
{
Tensor
out
=
empty
<
DescTensor
>
({},
phi
::
DataType
::
FLOAT32
,
paddle
::
Place
());
framework
::
BlockDesc
*
block
=
StaticCompositeContext
::
Instance
().
GetBlock
();
framework
::
OpDesc
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"unsqueeze2"
);
op
->
SetInput
(
"X"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
x
.
impl
())
->
Name
()});
op
->
SetOutput
(
"Out"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
out
.
impl
())
->
Name
()});
std
::
vector
<
int
>
new_shape
(
axis
.
GetData
().
begin
(),
axis
.
GetData
().
end
());
op
->
SetAttr
(
"axes"
,
new_shape
);
op
->
CheckAttrs
();
op
->
InferVarType
(
block
);
op
->
InferShape
(
*
block
);
return
out
;
}
template
<
>
Tensor
expand
<
DescTensor
>
(
const
Tensor
&
x
,
const
IntArray
&
shape
)
{
Tensor
out
=
empty
<
DescTensor
>
({},
phi
::
DataType
::
FLOAT32
,
paddle
::
Place
());
framework
::
BlockDesc
*
block
=
StaticCompositeContext
::
Instance
().
GetBlock
();
framework
::
OpDesc
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"expand_v2"
);
op
->
SetInput
(
"X"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
x
.
impl
())
->
Name
()});
op
->
SetOutput
(
"Out"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
out
.
impl
())
->
Name
()});
std
::
vector
<
int
>
new_shape
(
shape
.
GetData
().
begin
(),
shape
.
GetData
().
end
());
op
->
SetAttr
(
"shape"
,
new_shape
);
op
->
CheckAttrs
();
op
->
InferVarType
(
block
);
return
out
;
}
template
<
>
Tensor
divide
<
DescTensor
>
(
const
Tensor
&
x
,
const
Tensor
&
y
)
{
// Grad infershape
Tensor
out
=
empty
<
DescTensor
>
({},
phi
::
DataType
::
FLOAT32
,
paddle
::
Place
());
framework
::
BlockDesc
*
block
=
StaticCompositeContext
::
Instance
().
GetBlock
();
framework
::
BlockDesc
*
block
=
StaticCompositeContext
::
Instance
().
GetBlock
();
framework
::
OpDesc
*
op
=
block
->
AppendOp
();
framework
::
OpDesc
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"elementwise_div"
);
// TODO(cxxly): Fix test_resnet_prim_cinn error when SetType("reshape2")
op
->
SetType
(
"reshape"
);
op
->
SetInput
(
"X"
,
op
->
SetInput
(
"X"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
x
.
impl
())
->
Name
()});
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
x
.
impl
())
->
Name
()});
op
->
SetInput
(
"Y"
,
// Tensor out = empty<DescTensor>({}, x.dtype(), paddle::Place());
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
y
.
impl
())
->
Name
()}
);
auto
out
=
empty
<
DescTensor
>
({},
x
.
dtype
(),
paddle
::
Place
()
);
op
->
SetOutput
(
op
->
SetOutput
(
"Out"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
out
.
impl
())
->
Name
()});
"Out"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
out
.
impl
())
->
Name
()});
op
->
SetAttr
(
"shape"
,
unsafe_vector_cast
<
int64_t
,
int
>
(
shape
.
GetData
()));
op
->
CheckAttrs
();
op
->
CheckAttrs
();
op
->
InferVarType
(
block
);
op
->
InferVarType
(
block
);
op
->
InferShape
(
*
block
);
op
->
InferShape
(
*
block
);
...
@@ -186,70 +93,5 @@ Tensor full<DescTensor>(const IntArray& shape,
...
@@ -186,70 +93,5 @@ Tensor full<DescTensor>(const IntArray& shape,
return
out
;
return
out
;
}
}
template
<
>
Tensor
sum
<
DescTensor
>
(
const
Tensor
&
x
,
const
IntArray
&
axis
,
DataType
dtype
,
bool
keepdim
)
{
// Grad infershape
Tensor
out
=
empty
<
DescTensor
>
({},
dtype
,
paddle
::
Place
());
framework
::
BlockDesc
*
block
=
StaticCompositeContext
::
Instance
().
GetBlock
();
framework
::
OpDesc
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"reduce_sum"
);
op
->
SetInput
(
"X"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
x
.
impl
())
->
Name
()});
std
::
vector
<
int
>
res
;
for
(
auto
value
:
axis
.
GetData
())
{
res
.
push_back
(
static_cast
<
int
>
(
value
));
}
op
->
SetAttr
(
"dim"
,
res
);
op
->
SetAttr
(
"keep_dim"
,
keepdim
);
op
->
SetAttr
(
"dtype"
,
paddle
::
framework
::
TransToProtoVarType
(
dtype
));
op
->
SetOutput
(
"Out"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
out
.
impl
())
->
Name
()});
op
->
CheckAttrs
();
op
->
InferVarType
(
block
);
op
->
InferShape
(
*
block
);
return
out
;
}
template
<
>
Tensor
reshape
<
DescTensor
>
(
const
Tensor
&
x
,
const
IntArray
&
shape
)
{
// Grad infershape
Tensor
out
=
empty
<
DescTensor
>
({},
x
.
dtype
(),
paddle
::
Place
());
framework
::
BlockDesc
*
block
=
StaticCompositeContext
::
Instance
().
GetBlock
();
framework
::
OpDesc
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"reshape"
);
op
->
SetInput
(
"X"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
x
.
impl
())
->
Name
()});
std
::
vector
<
int
>
res
;
for
(
auto
value
:
shape
.
GetData
())
{
// TODO(jiabin): This cast is not safe for now, find a way to handle this.
res
.
push_back
(
static_cast
<
int
>
(
value
));
}
op
->
SetAttr
(
"shape"
,
res
);
op
->
SetOutput
(
"Out"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
out
.
impl
())
->
Name
()});
op
->
CheckAttrs
();
op
->
InferVarType
(
block
);
op
->
InferShape
(
*
block
);
return
out
;
}
template
<
>
Tensor
exp
<
DescTensor
>
(
const
Tensor
&
x
)
{
Tensor
out
=
empty
<
DescTensor
>
({},
phi
::
DataType
::
FLOAT32
,
paddle
::
Place
());
framework
::
BlockDesc
*
block
=
StaticCompositeContext
::
Instance
().
GetBlock
();
framework
::
OpDesc
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"exp"
);
op
->
SetInput
(
"X"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
x
.
impl
())
->
Name
()});
op
->
SetOutput
(
"Out"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
out
.
impl
())
->
Name
()});
op
->
CheckAttrs
();
op
->
InferVarType
(
block
);
op
->
InferShape
(
*
block
);
return
out
;
}
}
// namespace prim
}
// namespace prim
}
// namespace paddle
}
// namespace paddle
paddle/fluid/prim/api/manual_prim/utils/utils.h
浏览文件 @
82cf1fad
...
@@ -78,5 +78,11 @@ static phi::DDim get_reduce_dims(const phi::DDim& x_dims,
...
@@ -78,5 +78,11 @@ static phi::DDim get_reduce_dims(const phi::DDim& x_dims,
return
get_reduce_dims_from_out
(
out_dims
,
x_dims
);
return
get_reduce_dims_from_out
(
out_dims
,
x_dims
);
}
}
// TODO(cxxly): Check and throws InvalidCastException when overflow.
template
<
typename
SRC_T
,
typename
DST_T
>
static
std
::
vector
<
DST_T
>
unsafe_vector_cast
(
const
std
::
vector
<
SRC_T
>&
src
)
{
std
::
vector
<
DST_T
>
dst
(
src
.
begin
(),
src
.
end
());
return
dst
;
}
}
// namespace prim
}
// namespace prim
}
// namespace paddle
}
// namespace paddle
paddle/phi/api/yaml/op_compat.yaml
浏览文件 @
82cf1fad
...
@@ -231,6 +231,16 @@
...
@@ -231,6 +231,16 @@
-
op
:
concat
-
op
:
concat
backward
:
concat_grad
backward
:
concat_grad
inputs
:
x
:
X
outputs
:
out
:
Out
attrs
:
axis
:
axis
scalar
:
axis
:
data_type
:
int
tensor_name
:
AxisTensor
extra
:
extra
:
attrs
:
[
bool use_mkldnn = false
,
bool use_quantizer = false
,
str mkldnn_data_type = "float32"
]
attrs
:
[
bool use_mkldnn = false
,
bool use_quantizer = false
,
str mkldnn_data_type = "float32"
]
...
@@ -395,6 +405,10 @@
...
@@ -395,6 +405,10 @@
-
op
:
divide (elementwise_div)
-
op
:
divide (elementwise_div)
backward
:
divide_grad (elementwise_div)
backward
:
divide_grad (elementwise_div)
inputs
:
{
x
:
X
,
y
:
Y
}
outputs
:
out
:
Out
extra
:
extra
:
attrs
:
[
bool use_mkldnn = false
,
str x_data_format = ""
,
str y_data_format = ""
,
str mkldnn_data_type = "float32"
,
attrs
:
[
bool use_mkldnn = false
,
str x_data_format = ""
,
str y_data_format = ""
,
str mkldnn_data_type = "float32"
,
bool use_quantizer = false
,
float Scale_x = 1.0f
,
float Scale_y = 1.0f
,
float Scale_out = 1.0f
]
bool use_quantizer = false
,
float Scale_x = 1.0f
,
float Scale_y = 1.0f
,
float Scale_out = 1.0f
]
...
@@ -486,6 +500,17 @@
...
@@ -486,6 +500,17 @@
-
op
:
expand (expand_v2)
-
op
:
expand (expand_v2)
backward
:
expand_grad (expand_v2_grad)
backward
:
expand_grad (expand_v2_grad)
inputs
:
x
:
X
attrs
:
shape
:
shape
outputs
:
out
:
Out
int_array
:
shape
:
data_type
:
int
tensor_name
:
Shape
tensors_name
:
expand_shapes_tensor
extra
:
extra
:
attrs
:
[
bool use_mkldnn = false
,
str mkldnn_data_type = "float32"
]
attrs
:
[
bool use_mkldnn = false
,
str mkldnn_data_type = "float32"
]
...
@@ -898,6 +923,12 @@
...
@@ -898,6 +923,12 @@
-
op
:
matmul (matmul_v2)
-
op
:
matmul (matmul_v2)
backward
:
matmul_grad (matmul_v2_grad)
backward
:
matmul_grad (matmul_v2_grad)
inputs
:
{
x
:
X
,
y
:
Y
}
attrs
:
{
transpose_x
:
trans_x
,
transpose_y
:
trans_y
}
outputs
:
out
:
Out
extra
:
extra
:
attrs
:
[
bool use_mkldnn = false
,
'
int[]
fused_reshape_Out
=
{}'
,
'
int[]
fused_transpose_Out
=
{}'
,
attrs
:
[
bool use_mkldnn = false
,
'
int[]
fused_reshape_Out
=
{}'
,
'
int[]
fused_transpose_Out
=
{}'
,
str mkldnn_data_type = "float32"
,
'
int[]
fused_reshape_X
=
{}'
,
'
int[]
fused_reshape_Y
=
{}'
,
str mkldnn_data_type = "float32"
,
'
int[]
fused_reshape_X
=
{}'
,
'
int[]
fused_reshape_Y
=
{}'
,
...
@@ -915,6 +946,20 @@
...
@@ -915,6 +946,20 @@
outputs
:
outputs
:
out
:
Out
out
:
Out
-
op
:
max (reduce_max)
backward
:
max_grad (reduce_max_grad)
inputs
:
x
:
X
attrs
:
{
axis
:
dim
,
keepdim
:
keep_dim
}
outputs
:
out
:
Out
int_array
:
axis
:
data_type
:
int
extra
:
attrs
:
[
bool use_mkldnn = false
]
-
op
:
maximum (elementwise_max)
-
op
:
maximum (elementwise_max)
backward
:
maximum_grad (elementwise_max_grad)
backward
:
maximum_grad (elementwise_max_grad)
extra
:
extra
:
...
@@ -981,6 +1026,10 @@
...
@@ -981,6 +1026,10 @@
-
op
:
multiply (elementwise_mul)
-
op
:
multiply (elementwise_mul)
backward
:
multiply_grad (elementwise_mul_grad)
backward
:
multiply_grad (elementwise_mul_grad)
inputs
:
{
x
:
X
,
y
:
Y
}
outputs
:
out
:
Out
extra
:
extra
:
attrs
:
[
bool use_mkldnn = false
,
str x_data_format = ""
,
str y_data_format = ""
,
str mkldnn_data_type = "float32"
,
attrs
:
[
bool use_mkldnn = false
,
str x_data_format = ""
,
str y_data_format = ""
,
str mkldnn_data_type = "float32"
,
bool use_quantizer = false
,
float Scale_x = 1.0f
,
float Scale_y = 1.0f
,
float Scale_out = 1.0f
]
bool use_quantizer = false
,
float Scale_x = 1.0f
,
float Scale_y = 1.0f
,
float Scale_out = 1.0f
]
...
@@ -1079,6 +1128,20 @@
...
@@ -1079,6 +1128,20 @@
extra
:
extra
:
attrs
:
[
bool use_mkldnn = false
,
str mkldnn_data_type = "float32"
,
bool is_test = false
]
attrs
:
[
bool use_mkldnn = false
,
str mkldnn_data_type = "float32"
,
bool is_test = false
]
-
op
:
prod (reduce_prod)
backward
:
prod_grad (reduce_prod_grad)
inputs
:
x
:
X
attrs
:
{
dims
:
dim
,
keep_dim
:
keep_dim
}
outputs
:
out
:
Out
int_array
:
axis
:
data_type
:
int
extra
:
attrs
:
[
bool use_mkldnn = false
]
-
op
:
put_along_axis
-
op
:
put_along_axis
backward
:
put_along_axis_grad
backward
:
put_along_axis_grad
inputs
:
inputs
:
...
@@ -1133,11 +1196,6 @@
...
@@ -1133,11 +1196,6 @@
extra
:
extra
:
attrs
:
[
bool use_mkldnn = false
]
attrs
:
[
bool use_mkldnn = false
]
-
op
:
reduce_max
backward
:
reduce_max_grad
extra
:
attrs
:
[
bool use_mkldnn = false
]
-
op
:
reduce_mean
-
op
:
reduce_mean
backward
:
reduce_mean_grad
backward
:
reduce_mean_grad
extra
:
extra
:
...
@@ -1148,16 +1206,6 @@
...
@@ -1148,16 +1206,6 @@
extra
:
extra
:
attrs
:
[
bool use_mkldnn = false
]
attrs
:
[
bool use_mkldnn = false
]
-
op
:
reduce_prod
backward
:
reduce_prod_grad
extra
:
attrs
:
[
bool use_mkldnn = false
]
-
op
:
reduce_sum
backward
:
reduce_sum_grad
extra
:
attrs
:
[
bool use_mkldnn = false
]
-
op
:
relu
-
op
:
relu
backward
:
relu_grad, relu_double_grad (relu_grad_grad)
backward
:
relu_grad, relu_double_grad (relu_grad_grad)
inputs
:
inputs
:
...
@@ -1186,6 +1234,20 @@
...
@@ -1186,6 +1234,20 @@
extra
:
extra
:
attrs
:
[
bool use_mkldnn = false
,
bool use_cudnn = false
]
attrs
:
[
bool use_mkldnn = false
,
bool use_cudnn = false
]
-
op
:
reshape (reshape2)
backward
:
reshape_grad (reshape2_grad)
inputs
:
x
:
X
outputs
:
out
:
Out
int_array
:
shape
:
data_type
:
int
tensor_name
:
Shape
tensors_name
:
ShapeTensor
extra
:
attrs
:
[
bool use_mkldnn = false
,
str mkldnn_data_type = "float32"
,
bool use_quantizer = false
]
-
op
:
roll
-
op
:
roll
backward
:
roll_grad
backward
:
roll_grad
inputs
:
inputs
:
...
@@ -1216,6 +1278,10 @@
...
@@ -1216,6 +1278,10 @@
attrs
:
[
bool use_mkldnn = false
,
bool use_cudnn = false
]
attrs
:
[
bool use_mkldnn = false
,
bool use_cudnn = false
]
-
op
:
scale
-
op
:
scale
inputs
:
x
:
X
outputs
:
out
:
Out
extra
:
extra
:
attrs
:
[
bool use_mkldnn = false
]
attrs
:
[
bool use_mkldnn = false
]
...
@@ -1437,10 +1503,28 @@
...
@@ -1437,10 +1503,28 @@
-
op
:
subtract (elementwise_sub)
-
op
:
subtract (elementwise_sub)
backward
:
subtract_grad (elementwise_sub_grad)
backward
:
subtract_grad (elementwise_sub_grad)
inputs
:
{
x
:
X
,
y
:
Y
}
outputs
:
out
:
Out
extra
:
extra
:
attrs
:
[
bool use_mkldnn = false
,
str x_data_format = ""
,
str y_data_format = ""
,
str mkldnn_data_type = "float32"
,
attrs
:
[
bool use_mkldnn = false
,
str x_data_format = ""
,
str y_data_format = ""
,
str mkldnn_data_type = "float32"
,
bool use_quantizer = false
,
float Scale_x = 1.0f
,
float Scale_y = 1.0f
,
float Scale_out = 1.0f
]
bool use_quantizer = false
,
float Scale_x = 1.0f
,
float Scale_y = 1.0f
,
float Scale_out = 1.0f
]
-
op
:
sum (reduce_sum)
backward
:
(sum_grad) reduce_sum_grad
inputs
:
{
x
:
X
}
attrs
:
{
axis
:
dim
,
keepdim
:
keep_dim
,
dtype
:
out_dtype
}
outputs
:
out
:
Out
int_array
:
axis
:
data_type
:
int
extra
:
attrs
:
[
bool use_mkldnn = false
]
-
op
:
svd
-
op
:
svd
backward
:
svd_grad
backward
:
svd_grad
inputs
:
inputs
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录