未验证 提交 f369b2b1 编写于 作者: C Chen Weihang 提交者: GitHub

[PHI decoupling] Move fluid op generator into fluid (#47714)

* move fluid op generator into fluid

* remove parsed op

* resolve sig undef error

* append python interp find logic

* remove dup code
上级 0a051297
...@@ -75,8 +75,7 @@ paddle/fluid/operators/generated_op.cc ...@@ -75,8 +75,7 @@ paddle/fluid/operators/generated_op.cc
paddle/fluid/operators/generated_sparse_op.cc paddle/fluid/operators/generated_sparse_op.cc
paddle/phi/ops/compat/generated_sig.cc paddle/phi/ops/compat/generated_sig.cc
paddle/phi/ops/compat/generated_sparse_sig.cc paddle/phi/ops/compat/generated_sparse_sig.cc
paddle/phi/api/yaml/parsed_apis/ paddle/fluid/operators/generator/parsed_ops/
python/paddle/utils/code_gen/
paddle/fluid/pybind/tmp_eager_op_function_impl.h paddle/fluid/pybind/tmp_eager_op_function_impl.h
paddle/fluid/pybind/eager_op_function_impl.h paddle/fluid/pybind/eager_op_function_impl.h
paddle/fluid/pybind/eager_op_function_impl.h paddle/fluid/pybind/eager_op_function_impl.h
......
...@@ -111,7 +111,7 @@ function(kernel_declare TARGET_LIST) ...@@ -111,7 +111,7 @@ function(kernel_declare TARGET_LIST)
endfunction() endfunction()
function(append_op_util_declare TARGET) function(append_op_util_declare TARGET)
file(READ ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET} target_content) file(READ ${TARGET} target_content)
string( string(
REGEX REGEX
MATCH MATCH
...@@ -134,13 +134,10 @@ function(register_op_utils TARGET_NAME) ...@@ -134,13 +134,10 @@ function(register_op_utils TARGET_NAME)
cmake_parse_arguments(register_op_utils "${options}" "${oneValueArgs}" cmake_parse_arguments(register_op_utils "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN}) "${multiValueArgs}" ${ARGN})
file( file(GLOB SIGNATURES "${PADDLE_SOURCE_DIR}/paddle/phi/ops/compat/*_sig.cc")
GLOB SIGNATURES
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"*_sig.cc")
foreach(target ${SIGNATURES}) foreach(target ${SIGNATURES})
append_op_util_declare(${target}) append_op_util_declare(${target})
list(APPEND utils_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${target}) list(APPEND utils_srcs ${target})
endforeach() endforeach()
cc_library( cc_library(
......
...@@ -503,7 +503,8 @@ if(WITH_XPU) ...@@ -503,7 +503,8 @@ if(WITH_XPU)
phi_utils phi_utils
kernel_factory kernel_factory
infershape_utils infershape_utils
op_utils) op_utils
op_compat_infos)
else() else()
cc_library( cc_library(
operator operator
...@@ -528,7 +529,8 @@ else() ...@@ -528,7 +529,8 @@ else()
phi_utils phi_utils
kernel_factory kernel_factory
infershape_utils infershape_utils
op_utils) op_utils
op_compat_infos)
endif() endif()
cc_test( cc_test(
......
include(operators) include(operators)
add_subdirectory(generator)
# solve "math constants not defined" problems caused by the order of inclusion # solve "math constants not defined" problems caused by the order of inclusion
# of <cmath> and the definition of macro _USE_MATH_DEFINES # of <cmath> and the definition of macro _USE_MATH_DEFINES
add_definitions(-D_USE_MATH_DEFINES) add_definitions(-D_USE_MATH_DEFINES)
......
# phi auto cmake utils
include(phi)
# set yaml file path
set(op_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/ops.yaml)
set(legacy_op_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/legacy_ops.yaml)
set(bw_op_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/backward.yaml)
set(legacy_bw_op_yaml_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/legacy_backward.yaml)
set(sparse_op_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/sparse_ops.yaml)
set(sparse_bw_op_yaml_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/sparse_backward.yaml)
if(NOT PYTHONINTERP_FOUND)
find_package(PythonInterp REQUIRED)
endif()
# install extra dependencies
if(${PYTHON_VERSION_STRING} VERSION_LESS "3.6.2")
execute_process(COMMAND ${PYTHON_EXECUTABLE} -m pip install -U pyyaml
typing-extensions>=4.1.1 jinja2==2.11.3)
else()
execute_process(COMMAND ${PYTHON_EXECUTABLE} -m pip install -U pyyaml jinja2
typing-extensions)
endif()
# parse ops
set(parsed_op_dir
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops)
set(generated_op_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op.cc)
set(generated_sparse_ops_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_sparse_op.cc)
set(generated_argument_mapping_path
${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sig.cc)
set(generated_sparse_argument_mapping_path
${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sparse_sig.cc)
message(
"parse op yamls:
- ${op_yaml_file}
- ${legacy_op_yaml_file}
- ${bw_op_yaml_file}
- ${legacy_bw_op_yaml_file}")
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_op_dir}
COMMAND ${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${op_yaml_file}
--output_path ./parsed_ops/ops.parsed.yaml
COMMAND ${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${legacy_op_yaml_file}
--output_path ./parsed_ops/legacy_ops.parsed.yaml
COMMAND ${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${bw_op_yaml_file}
--output_path ./parsed_ops/backward_ops.parsed.yaml --backward
COMMAND
${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${legacy_bw_op_yaml_file}
--output_path ./parsed_ops/legacy_backward_ops.parsed.yaml --backward
COMMAND ${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${sparse_op_yaml_file}
--output_path ./parsed_ops/sparse_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${sparse_bw_op_yaml_file}
--output_path ./parsed_ops/sparse_backward.parsed.yaml --backward
RESULTS_VARIABLE _results)
foreach(_result in ${_results})
if(${_result})
message(FATAL_ERROR "op yaml parsing failed, exiting.")
endif()
endforeach()
# validation of op yamls
message("validate op yaml:
- ${parsed_op_dir}/ops.parsed.yaml
- ${parsed_op_dir}/backward_ops.parsed.yaml")
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
COMMAND
${PYTHON_EXECUTABLE} cross_validate.py --forward_yaml_paths
./parsed_ops/ops.parsed.yaml ./parsed_ops/legacy_ops.parsed.yaml
--backward_yaml_paths ./parsed_ops/backward_ops.parsed.yaml
./parsed_ops/legacy_backward_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} cross_validate.py --forward_yaml_paths
./parsed_ops/sparse_ops.parsed.yaml --backward_yaml_paths
./parsed_ops/sparse_backward.parsed.yaml
RESULT_VARIABLE _results)
foreach(_result in ${_results})
if(${_result})
message(FATAL_ERROR "ops validation failed, exiting.")
endif()
endforeach()
# code generation for op, op makers, and argument mapping functions
message(
"create or remove auto-geneated operators: ${generated_op_path}.tmp
create or remove auto-geneated argument mappings: ${generated_argument_mapping_path}.tmp"
)
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
COMMAND
${PYTHON_EXECUTABLE} generate_op.py --ops_yaml_path
./parsed_ops/ops.parsed.yaml --backward_yaml_path
./parsed_ops/backward_ops.parsed.yaml --op_version_yaml_path
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_version.yaml
--op_compat_yaml_path ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml
--output_op_path "${generated_op_path}.tmp" --output_arg_map_path
"${generated_argument_mapping_path}.tmp"
COMMAND
${PYTHON_EXECUTABLE} generate_sparse_op.py --ops_yaml_path
./parsed_ops/sparse_ops.parsed.yaml --backward_ops_yaml_path
./parsed_ops/sparse_backward.parsed.yaml --output_op_path
"${generated_sparse_ops_path}.tmp" --output_arg_map_path
"${generated_sparse_argument_mapping_path}.tmp"
RESULT_VARIABLE _results)
foreach(_result in ${_results})
if(${_result})
message(FATAL_ERROR "operator codegen failed, exiting.")
endif()
endforeach()
if(EXISTS "${generated_op_path}.tmp" AND EXISTS "${generated_op_path}")
execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${generated_op_path}.tmp" "${generated_op_path}")
message("copy if different ${generated_op_path}.tmp ${generated_op_path}")
elseif(EXISTS "${generated_op_path}.tmp")
execute_process(COMMAND ${CMAKE_COMMAND} -E copy "${generated_op_path}.tmp"
"${generated_op_path}")
message("copy ${generated_op_path}.tmp ${generated_op_path}")
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f "${generated_op_path}")
message("remove ${generated_op_path}")
endif()
if(EXISTS "${generated_sparse_ops_path}.tmp" AND EXISTS
"${generated_sparse_ops_path}")
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${generated_sparse_ops_path}.tmp" "${generated_sparse_ops_path}")
message(
"copy if different ${generated_sparse_ops_path}.tmp ${generated_sparse_ops_path}"
)
elseif(EXISTS "${generated_sparse_ops_path}.tmp")
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy "${generated_sparse_ops_path}.tmp"
"${generated_sparse_ops_path}")
message("copy ${generated_sparse_ops_path}.tmp ${generated_sparse_ops_path}")
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f
"${generated_sparse_ops_path}")
message("remove ${generated_sparse_ops_path}")
endif()
if(EXISTS "${generated_argument_mapping_path}.tmp"
AND EXISTS "${generated_argument_mapping_path}")
execute_process(
COMMAND
${CMAKE_COMMAND} -E copy_if_different
"${generated_argument_mapping_path}.tmp"
"${generated_argument_mapping_path}")
message(
"copy if different ${generated_argument_mapping_path}.tmp ${generated_argument_mapping_path}"
)
elseif(EXISTS "${generated_argument_mapping_path}.tmp")
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy "${generated_argument_mapping_path}.tmp"
"${generated_argument_mapping_path}")
message(
"copy ${generated_argument_mapping_path}.tmp ${generated_argument_mapping_path}"
)
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f
"${generated_argument_mapping_path}")
message("remove ${generated_argument_mapping_path}")
endif()
if(EXISTS "${generated_sparse_argument_mapping_path}.tmp"
AND EXISTS "${generated_sparse_argument_mapping_path}")
execute_process(
COMMAND
${CMAKE_COMMAND} -E copy_if_different
"${generated_sparse_argument_mapping_path}.tmp"
"${generated_sparse_argument_mapping_path}")
message(
"copy if different ${generated_sparse_argument_mapping_path}.tmp ${generated_sparse_argument_mapping_path}"
)
elseif(EXISTS "${generated_sparse_argument_mapping_path}.tmp")
execute_process(
COMMAND
${CMAKE_COMMAND} -E copy "${generated_sparse_argument_mapping_path}.tmp"
"${generated_sparse_argument_mapping_path}")
message(
"copy ${generated_sparse_argument_mapping_path}.tmp ${generated_sparse_argument_mapping_path}"
)
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f
"${generated_sparse_argument_mapping_path}")
message("remove ${generated_sparse_argument_mapping_path}")
endif()
# op extra info file
set(ops_extra_info_gen_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/ops_extra_info_gen.py)
set(op_compat_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml)
set(ops_extra_info_file
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/ops_extra_info.cc)
# generate ops extra info
execute_process(
COMMAND ${PYTHON_EXECUTABLE} ${ops_extra_info_gen_file} --op_compat_yaml_path
${op_compat_yaml_file} --ops_extra_info_path ${ops_extra_info_file})
message("generate ${ops_extra_info_file}")
set(op_utils_header
${PADDLE_BINARY_DIR}/paddle/phi/ops/compat/signatures.h.tmp
CACHE INTERNAL "op_args_fns.cc file")
set(op_utils_header_final
${PADDLE_BINARY_DIR}/paddle/phi/ops/compat/signatures.h)
file(
WRITE ${op_utils_header}
"// Generated by the paddle/fluid/operators/generator/CMakeLists.txt. DO NOT EDIT!\n\n"
)
file(APPEND ${op_utils_header}
"#include \"paddle/phi/core/compat/op_utils.h\"\n\n")
# Automatically generate the registration code of all arg map functions
# and compile the corresponding target to avoid frequent code conflicts
# when writing to same file
register_op_utils(op_compat_infos DEPS op_utils)
copy_if_different(${op_utils_header} ${op_utils_header_final})
...@@ -20,35 +20,35 @@ import yaml ...@@ -20,35 +20,35 @@ import yaml
from parse_utils import cross_validate, to_named_dict from parse_utils import cross_validate, to_named_dict
def main(forward_api_yaml_paths, backward_api_yaml_paths): def main(forward_op_yaml_paths, backward_op_yaml_paths):
apis = {} ops = {}
for api_yaml_path in chain(forward_api_yaml_paths, backward_api_yaml_paths): for op_yaml_path in chain(forward_op_yaml_paths, backward_op_yaml_paths):
with open(api_yaml_path, "rt", encoding="utf-8") as f: with open(op_yaml_path, "rt", encoding="utf-8") as f:
api_list = yaml.safe_load(f) op_list = yaml.safe_load(f)
if api_list is not None: if op_list is not None:
apis.update(to_named_dict((api_list))) ops.update(to_named_dict((op_list)))
cross_validate(apis) cross_validate(ops)
if __name__ == "__main__": if __name__ == "__main__":
current_dir = Path(__file__).parent / "temp" current_dir = Path(__file__).parent / "temp"
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Parse api yaml into canonical format." description="Parse op yaml into canonical format."
) )
parser.add_argument( parser.add_argument(
'--forward_yaml_paths', '--forward_yaml_paths',
type=str, type=str,
nargs='+', nargs='+',
default=str(current_dir / "api.parsed.yaml"), default=str(current_dir / "op .parsed.yaml"),
help="forward api yaml file.", help="forward op yaml file.",
) )
parser.add_argument( parser.add_argument(
'--backward_yaml_paths', '--backward_yaml_paths',
type=str, type=str,
nargs='+', nargs='+',
default=str(current_dir / "backward_api.parsed.yaml"), default=str(current_dir / "backward_op .parsed.yaml"),
help="backward api yaml file.", help="backward op yaml file.",
) )
args = parser.parse_args() args = parser.parse_args()
......
...@@ -102,12 +102,12 @@ def to_pascal_case(s): ...@@ -102,12 +102,12 @@ def to_pascal_case(s):
def to_input_name(s): def to_input_name(s):
"""find input variable name in api yaml for higher order backward api. """find input variable name in op yaml for higher order backward op .
x -> dx x -> dx
x -> d2x x -> d2x
x -> d3x x -> d3x
NOTE: for first order backward api NOTE: for first order backward op
x -> x_grad x -> x_grad
is more common. is more common.
""" """
...@@ -137,16 +137,14 @@ def cartesian_prod_attrs(attrs): ...@@ -137,16 +137,14 @@ def cartesian_prod_attrs(attrs):
return combinations return combinations
def cartesian_prod_mapping(api): def cartesian_prod_mapping(op):
kernels = api["kernel"]["func"] kernels = op["kernel"]["func"]
inputs = [ inputs = [
x["name"] for x in api["inputs"] if x["name"] in api["kernel"]["param"] x["name"] for x in op["inputs"] if x["name"] in op["kernel"]["param"]
] ]
inputs = [to_opmaker_name_cstr(input) for input in inputs] inputs = [to_opmaker_name_cstr(input) for input in inputs]
attrs = cartesian_prod_attrs(api["attrs"]) attrs = cartesian_prod_attrs(op["attrs"])
outputs = [ outputs = [to_opmaker_name_cstr(output["name"]) for output in op["outputs"]]
to_opmaker_name_cstr(output["name"]) for output in api["outputs"]
]
def vec(items): def vec(items):
return "{" + ', '.join(items) + "}" return "{" + ', '.join(items) + "}"
......
...@@ -26,7 +26,7 @@ from filters import ( ...@@ -26,7 +26,7 @@ from filters import (
to_pascal_case, to_pascal_case,
) )
from tests import ( from tests import (
is_base_api, is_base_op,
is_vec, is_vec,
is_scalar, is_scalar,
is_initializer_list, is_initializer_list,
...@@ -52,7 +52,7 @@ env.filters["to_pascal_case"] = to_pascal_case ...@@ -52,7 +52,7 @@ env.filters["to_pascal_case"] = to_pascal_case
env.filters["to_input_name"] = to_input_name env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.tests["base_api"] = is_base_api env.tests["base_op"] = is_base_op
env.tests["vec"] = is_vec env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list env.tests["initializer_list"] = is_initializer_list
...@@ -60,65 +60,63 @@ env.tests["supports_inplace"] = supports_inplace ...@@ -60,65 +60,63 @@ env.tests["supports_inplace"] = supports_inplace
env.tests["supports_no_need_buffer"] = supports_no_need_buffer env.tests["supports_no_need_buffer"] = supports_no_need_buffer
def restruct_io(api): def restruct_io(op):
api["input_dict"] = to_named_dict(api["inputs"]) op["input_dict"] = to_named_dict(op["inputs"])
api["attr_dict"] = to_named_dict(api["attrs"]) op["attr_dict"] = to_named_dict(op["attrs"])
api["output_dict"] = to_named_dict(api["outputs"]) op["output_dict"] = to_named_dict(op["outputs"])
return api return op
SPARSE_OP_PREFIX = 'sparse_' SPARSE_OP_PREFIX = 'sparse_'
def main( def main(op_yaml_path, backward_yaml_path, output_op_path, output_arg_map_path):
api_yaml_path, backward_yaml_path, output_op_path, output_arg_map_path with open(op_yaml_path, "rt") as f:
): ops = yaml.safe_load(f)
with open(api_yaml_path, "rt") as f: ops = [restruct_io(op) for op in ops]
apis = yaml.safe_load(f) forward_op_dict = to_named_dict(ops)
apis = [restruct_io(api) for api in apis]
forward_api_dict = to_named_dict(apis)
with open(backward_yaml_path, "rt") as f: with open(backward_yaml_path, "rt") as f:
backward_apis = yaml.safe_load(f) backward_ops = yaml.safe_load(f)
backward_apis = [restruct_io(api) for api in backward_apis] backward_ops = [restruct_io(op) for op in backward_ops]
backward_api_dict = to_named_dict(backward_apis) backward_op_dict = to_named_dict(backward_ops)
for api in apis: for op in ops:
api['op_name'] = SPARSE_OP_PREFIX + api['name'] op['op_name'] = SPARSE_OP_PREFIX + op['name']
api['name'] = api['op_name'] op['name'] = op['op_name']
if api["backward"] is not None: if op["backward"] is not None:
api["backward"] = SPARSE_OP_PREFIX + api["backward"] op["backward"] = SPARSE_OP_PREFIX + op["backward"]
for bw_api in backward_apis: for bw_op in backward_ops:
bw_api['op_name'] = SPARSE_OP_PREFIX + bw_api['name'] bw_op['op_name'] = SPARSE_OP_PREFIX + bw_op['name']
bw_api['name'] = bw_api['op_name'] bw_op['name'] = bw_op['op_name']
if 'invoke' in bw_api: if 'invoke' in bw_op:
bw_api['invoke']['args'] = [ bw_op['invoke']['args'] = [
param.strip() for param in bw_api['invoke']['args'].split(',') param.strip() for param in bw_op['invoke']['args'].split(',')
] ]
# prepare for invoke case # prepare for invoke case
process_invoke_op(forward_api_dict, backward_api_dict) process_invoke_op(forward_op_dict, backward_op_dict)
for bw_api in backward_apis: for bw_op in backward_ops:
if 'invoke' in bw_api: if 'invoke' in bw_op:
if bw_api['invoke']['func'] in forward_api_dict: if bw_op['invoke']['func'] in forward_op_dict:
bw_api['invoke']['func'] = ( bw_op['invoke']['func'] = (
SPARSE_OP_PREFIX + bw_api['invoke']['func'] SPARSE_OP_PREFIX + bw_op['invoke']['func']
) )
# fill backward field for an api if another api claims it as forward # fill backward field for an op if another op claims it as forward
for name, backward_api in backward_api_dict.items(): for name, backward_op in backward_op_dict.items():
forward_name = backward_api["forward"]["name"] forward_name = backward_op["forward"]["name"]
if forward_name in backward_api_dict: if forward_name in backward_op_dict:
forward_api = backward_api_dict[forward_name] forward_op = backward_op_dict[forward_name]
if forward_api["backward"] is None: if forward_op["backward"] is None:
forward_api["backward"] = name forward_op["backward"] = name
forward_api["backward"] = SPARSE_OP_PREFIX + forward_api["backward"] forward_op["backward"] = SPARSE_OP_PREFIX + forward_op["backward"]
api_dict = {} op_dict = {}
api_dict.update(forward_api_dict) op_dict.update(forward_op_dict)
api_dict.update(backward_api_dict) op_dict.update(backward_op_dict)
if len(apis) == 0 and len(backward_apis) == 0: if len(ops) == 0 and len(backward_ops) == 0:
if os.path.isfile(output_op_path): if os.path.isfile(output_op_path):
os.remove(output_op_path) os.remove(output_op_path)
if os.path.isfile(output_arg_map_path): if os.path.isfile(output_arg_map_path):
...@@ -128,19 +126,19 @@ def main( ...@@ -128,19 +126,19 @@ def main(
op_template = env.get_template('sparse_op.c.j2') op_template = env.get_template('sparse_op.c.j2')
with open(output_op_path, "wt") as f: with open(output_op_path, "wt") as f:
msg = op_template.render( msg = op_template.render(
apis=apis, backward_apis=backward_apis, api_dict=api_dict ops=ops, backward_ops=backward_ops, op_dict=op_dict
) )
f.write(msg) f.write(msg)
ks_template = env.get_template('sparse_ks.c.j2') ks_template = env.get_template('sparse_ks.c.j2')
with open(output_arg_map_path, 'wt') as f: with open(output_arg_map_path, 'wt') as f:
msg = ks_template.render(apis=apis, backward_apis=backward_apis) msg = ks_template.render(ops=ops, backward_ops=backward_ops)
f.write(msg) f.write(msg)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Generate operator file from api yaml." description="Generate operator file from op yaml."
) )
parser.add_argument( parser.add_argument(
'--ops_yaml_path', type=str, help="parsed sparse ops yaml file." '--ops_yaml_path', type=str, help="parsed sparse ops yaml file."
......
...@@ -16,33 +16,33 @@ import argparse ...@@ -16,33 +16,33 @@ import argparse
import yaml import yaml
from parse_utils import parse_api_entry from parse_utils import parse_op_entry
def main(api_yaml_path, output_path, backward): def main(op_yaml_path, output_path, backward):
with open(api_yaml_path, "rt") as f: with open(op_yaml_path, "rt") as f:
apis = yaml.safe_load(f) ops = yaml.safe_load(f)
if apis is None: if ops is None:
apis = [] ops = []
else: else:
apis = [ ops = [
parse_api_entry(api, "backward_op" if backward else "op") parse_op_entry(op, "backward_op" if backward else "op")
for api in apis for op in ops
] ]
with open(output_path, "wt") as f: with open(output_path, "wt") as f:
yaml.safe_dump(apis, f, default_flow_style=None, sort_keys=False) yaml.safe_dump(ops, f, default_flow_style=None, sort_keys=False)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Parse api yaml into canonical format." description="Parse op yaml into canonical format."
) )
parser.add_argument('--api_yaml_path', type=str, help="api yaml file.") parser.add_argument('--op_yaml_path', type=str, help="op yaml file.")
parser.add_argument( parser.add_argument(
"--output_path", type=str, help="path to save parsed yaml file." "--output_path", type=str, help="path to save parsed yaml file."
) )
parser.add_argument("--backward", action="store_true", default=False) parser.add_argument("--backward", action="store_true", default=False)
args = parser.parse_args() args = parser.parse_args()
main(args.api_yaml_path, args.output_path, args.backward) main(args.op_yaml_path, args.output_path, args.backward)
{% from "operator_utils.c.j2" import name_map, register_name_map, register_base_kernel_name %} {% from "operator_utils.c.j2" import name_map, register_name_map, register_base_kernel_name %}
// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit. // this file is generated by paddle/phi/op/yaml/generator/generate_op.py, do not edit.
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/small_vector.h" #include "paddle/utils/small_vector.h"
namespace phi { namespace phi {
{% for api in apis %} {% for op in ops %}
{% if api is base_api %} {% if op is base_op %}
{{name_map(api)}} {{name_map(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
{% for api in backward_apis %} {% for op in backward_ops %}
{% if api is base_api %} {% if op is base_op %}
{{name_map(api)}} {{name_map(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
} // namespace phi } // namespace phi
{% for api in apis + backward_apis %} {% for op in ops + backward_ops %}
{% if api["name"] != api["op_name"] %} {% if op["name"] != op["op_name"] %}
{{register_base_kernel_name(api)}} {{register_base_kernel_name(op)}}
{% endif %} {% endif %}
{% if api is base_api %} {% if op is base_op %}
{{register_name_map(api)}} {{register_name_map(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
...@@ -18,32 +18,32 @@ namespace operators { ...@@ -18,32 +18,32 @@ namespace operators {
using paddle::framework::GradVarName; using paddle::framework::GradVarName;
{% for api in apis %} {% for op in ops %}
{% if api is base_api %} {% if op is base_op %}
{{op_maker(api)}} {{op_maker(op)}}
{{operator(api)}} {{operator(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
{% for api in backward_apis %} {% for op in backward_ops %}
{% if api is base_api %} {% if op is base_op %}
{{backward_op_maker(api, api_dict[api["forward"]["name"]])}} {{backward_op_maker(op, op_dict[op["forward"]["name"]])}}
{{operator(api)}} {{operator(op)}}
{% else %} {% else %}
{{backward_op_reused_maker(api, api_dict[api["forward"]["name"]], api["invoke"])}} {{backward_op_reused_maker(op, op_dict[op["forward"]["name"]], op["invoke"])}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
{% for api in apis + backward_apis %} {% for op in ops + backward_ops %}
{% if api is base_api %} {% if op is base_op %}
{{register_op_with_components(api)}} {{register_op_with_components(op)}}
{{register_op_version(api)}} {{register_op_version(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
{# ----------------------------- op maker ----------------------------------- #} {# ----------------------------- op maker ----------------------------------- #}
{% macro op_maker(api) %} {% macro op_maker(op) %}
{% set api_name = api["op_name"] %} {% set op_name = op["op_name"] %}
class {{api_name | to_pascal_case}}OpMaker : public framework::OpProtoAndCheckerMaker { class {{op_name | to_pascal_case}}OpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
{% filter indent(4, True) %} {% filter indent(4, True) %}
{% for input in api["inputs"] %} {% for input in op["inputs"] %}
{{add_input(loop.index0, input, api_name)}}; {{add_input(loop.index0, input, op_name)}};
{% endfor %} {% endfor %}
{% for output in api["outputs"] %} {% for output in op["outputs"] %}
{{add_output(loop.index0, output, api_name)}}; {{add_output(loop.index0, output, op_name)}};
{% endfor %} {% endfor %}
{% for attr in api["attrs"] %} {% for attr in op["attrs"] %}
{% if attr["name"] in api["kernel"]["param"] %} {% if attr["name"] in op["kernel"]["param"] %}
{{add_attr(loop.index0, attr, api_name)}}; {{add_attr(loop.index0, attr, op_name)}};
{% endif %} {% endif %}
{% endfor %} {% endfor %}
{% endfilter %} {% endfilter %}
AddComment(R"DOC( AddComment(R"DOC(
TODO: Documentation of {{api_name}} op. TODO: Documentation of {{op_name}} op.
)DOC"); )DOC");
} }
}; };
...@@ -76,7 +76,7 @@ AddAttr<{{typename | to_op_attr_type}}>("{{name}}", "({{typename | to_op_attr_ty ...@@ -76,7 +76,7 @@ AddAttr<{{typename | to_op_attr_type}}>("{{name}}", "({{typename | to_op_attr_ty
{%- endif %} {%- endif %}
{%- endmacro %} {%- endmacro %}
{# process default value for attributes, some attribute has different types and different default values in api & opmaker #} {# process default value for attributes, some attribute has different types and different default values in op & opmaker #}
{% macro process_default_value(attr) %}{# inline #} {% macro process_default_value(attr) %}{# inline #}
{% set default_value = attr["default_value"] %} {% set default_value = attr["default_value"] %}
{% set typename = attr["typename"] %} {% set typename = attr["typename"] %}
...@@ -97,22 +97,22 @@ static_cast<int>(phi::Place({{"phi::" if not default_value is initializer_list}} ...@@ -97,22 +97,22 @@ static_cast<int>(phi::Place({{"phi::" if not default_value is initializer_list}}
{# --------------------------------------- name mapping ---------------------------------------------- #} {# --------------------------------------- name mapping ---------------------------------------------- #}
{% macro name_map(api) %} {% macro name_map(op) %}
KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature {{op["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
{% set kernel_args = api["kernel"]["param"] %} {% set kernel_args = op["kernel"]["param"] %}
{{get_input_list(api["inputs"], kernel_args)}}; {{get_input_list(op["inputs"], kernel_args)}};
paddle::small_vector<const char*> attrs; paddle::small_vector<const char*> attrs;
{% for attr in api["attrs"]%} {% for attr in op["attrs"]%}
{% filter indent(2)%} {% filter indent(2)%}
{{get_an_attr(attr)}}; {{get_an_attr(attr)}};
{% endfilter %} {% endfilter %}
{% endfor %} {% endfor %}
{{get_output_list(api["outputs"], kernel_args)}}; {{get_output_list(op["outputs"], kernel_args)}};
{% if api["kernel"]["func"] | length == 1 %} {% if op["kernel"]["func"] | length == 1 %}
KernelSignature sig("{{api["kernel"]["func"][0]}}", std::move(inputs), std::move(attrs), std::move(outputs)); KernelSignature sig("{{op["kernel"]["func"][0]}}", std::move(inputs), std::move(attrs), std::move(outputs));
return sig; return sig;
{% else %}{# it has kernel for selected rows #} {% else %}{# it has kernel for selected rows #}
const char* kernel_name = ctx.IsSelectedRowsInput({{kernel_args[0] | to_opmaker_name_cstr}}) ? "{{api["kernel"]["func"][1]}}" : "{{api["kernel"]["func"][0]}}"; const char* kernel_name = ctx.IsSelectedRowsInput({{kernel_args[0] | to_opmaker_name_cstr}}) ? "{{op["kernel"]["func"][1]}}" : "{{op["kernel"]["func"][0]}}";
KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs)); KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs));
return sig; return sig;
{%endif%} {%endif%}
...@@ -121,9 +121,9 @@ KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const Argu ...@@ -121,9 +121,9 @@ KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const Argu
/* /*
****************************************************************** ******************************************************************
NOTE: The following codes are for 'get_compat_kernel_signature.py' NOTE: The following codes are for 'get_compat_kernel_signature.py'
All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArgumentMapping: All possible KernelSignatures returned by {{op["name"] | to_pascal_case }}OpArgumentMapping:
{{api | cartesian_prod_mapping}} {{op | cartesian_prod_mapping}}
****************************************************************** ******************************************************************
*/ */
{% endmacro %} {% endmacro %}
...@@ -152,20 +152,20 @@ ctx.IsSparseCsrTensorInput("{{input["name"]}}"){{" && " if not loop.last}} ...@@ -152,20 +152,20 @@ ctx.IsSparseCsrTensorInput("{{input["name"]}}"){{" && " if not loop.last}}
{%- endfor %} {%- endfor %}
{%- endmacro %} {%- endmacro %}
{% macro sparse_op_name_map(api) %} {% macro sparse_op_name_map(op) %}
KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature {{op["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
{% set kernel_args = api["kernel"]["param"] %} {% set kernel_args = op["kernel"]["param"] %}
{{get_input_list(api["inputs"], kernel_args)}}; {{get_input_list(op["inputs"], kernel_args)}};
paddle::small_vector<const char*> attrs; paddle::small_vector<const char*> attrs;
{% for attr in api["attrs"]%} {% for attr in op["attrs"]%}
{% filter indent(2)%} {% filter indent(2)%}
{{get_an_attr(attr)}}; {{get_an_attr(attr)}};
{% endfilter %} {% endfilter %}
{% endfor %} {% endfor %}
{{get_output_list(api["outputs"], kernel_args)}}; {{get_output_list(op["outputs"], kernel_args)}};
const char* kernel_name = "unregistered"; const char* kernel_name = "unregistered";
{{get_kernel_dispatch(api["inputs"], api["kernel"])}} {{get_kernel_dispatch(op["inputs"], op["kernel"])}}
KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs)); KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs));
return sig; return sig;
} }
...@@ -173,19 +173,19 @@ KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const Argu ...@@ -173,19 +173,19 @@ KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const Argu
/* /*
****************************************************************** ******************************************************************
NOTE: The following codes are for 'get_compat_kernel_signature.py' NOTE: The following codes are for 'get_compat_kernel_signature.py'
All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArgumentMapping: All possible KernelSignatures returned by {{op["name"] | to_pascal_case }}OpArgumentMapping:
{{api | cartesian_prod_mapping}} {{op | cartesian_prod_mapping}}
****************************************************************** ******************************************************************
*/ */
{% endmacro %} {% endmacro %}
{% macro register_base_kernel_name(api) %} {% macro register_base_kernel_name(op) %}
PD_REGISTER_BASE_KERNEL_NAME({{api["op_name"]}}, {{api["name"]}}); PD_REGISTER_BASE_KERNEL_NAME({{op["op_name"]}}, {{op["name"]}});
{%- endmacro %} {%- endmacro %}
{% macro register_name_map(api) %} {% macro register_name_map(op) %}
PD_REGISTER_ARG_MAPPING_FN({{api["op_name"]}}, phi::{{api["op_name"] | to_pascal_case}}OpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN({{op["op_name"]}}, phi::{{op["op_name"] | to_pascal_case}}OpArgumentMapping);
{%- endmacro %} {%- endmacro %}
{% macro get_input_list(inputs, kernel_args) %}{# inline #} {% macro get_input_list(inputs, kernel_args) %}{# inline #}
...@@ -228,14 +228,14 @@ paddle::small_vector<const char*> outputs { ...@@ -228,14 +228,14 @@ paddle::small_vector<const char*> outputs {
} }
{%- endmacro %} {%- endmacro %}
{% macro get_expected_kernel(api) %} {% macro get_expected_kernel(op) %}
{% set kernel = api["kernel"] %} {% set kernel = op["kernel"] %}
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
{%if kernel["data_type"] is not none %}{# data type ---------------------------------#} {%if kernel["data_type"] is not none %}{# data type ---------------------------------#}
{% if kernel["data_type"]["candidates"] | length == 1 %} {% if kernel["data_type"]["candidates"] | length == 1 %}
{% set data_type_arg = kernel["data_type"]["candidates"][0] %} {% set data_type_arg = kernel["data_type"]["candidates"][0] %}
{% set inputs = api["inputs"] | map(attribute="name") | list %} {% set inputs = op["inputs"] | map(attribute="name") | list %}
{% if data_type_arg in inputs %} {% if data_type_arg in inputs %}
auto data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_arg | to_opmaker_name}}); auto data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_arg | to_opmaker_name}});
{% else %}{# it is an attribute and probably named dtype#} {% else %}{# it is an attribute and probably named dtype#}
...@@ -254,68 +254,68 @@ framework::OpKernelType GetExpectedKernelType( ...@@ -254,68 +254,68 @@ framework::OpKernelType GetExpectedKernelType(
{% endmacro %} {% endmacro %}
{# --------------------------------------- operator ---------------------------------------------- #} {# --------------------------------------- operator ---------------------------------------------- #}
{% macro operator(api) %} {% macro operator(op) %}
class {{api["op_name"] | to_pascal_case}}Op : public framework::OperatorWithKernel { class {{op["op_name"] | to_pascal_case}}Op : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
{# ----------- get expected kernel type function -------------------------- #} {# ----------- get expected kernel type function -------------------------- #}
{% set kernel = api["kernel"] %} {% set kernel = op["kernel"] %}
{% if kernel["data_type"] is not none %} {% if kernel["data_type"] is not none %}
protected: protected:
{% filter indent(2, True)%} {% filter indent(2, True)%}
{{get_expected_kernel(api)}} {{get_expected_kernel(op)}}
{% endfilter %} {% endfilter %}
{% endif %} {% endif %}
}; };
DECLARE_INFER_SHAPE_FUNCTOR({{api["op_name"]}}, {{api["op_name"] | to_pascal_case}}InferShapeFunctor, DECLARE_INFER_SHAPE_FUNCTOR({{op["op_name"]}}, {{op["op_name"] | to_pascal_case}}InferShapeFunctor,
PD_INFER_META(phi::{{api["infer_meta"]["func"]}})); PD_INFER_META(phi::{{op["infer_meta"]["func"]}}));
{# inplace inferer #} {# inplace inferer #}
{% if api["inplace"] is not none %} {% if op["inplace"] is not none %}
{% set inplace_map %} {% set inplace_map %}
{% for source, target in api["inplace"].items() %} {% for source, target in op["inplace"].items() %}
{{"{"}}{{target | to_opmaker_name}}, {{source | to_opmaker_name}}{{"}"}}{{", " if not loop.last}} {{"{"}}{{target | to_opmaker_name}}, {{source | to_opmaker_name}}{{"}"}}{{", " if not loop.last}}
{%- endfor %} {%- endfor %}
{%- endset %} {%- endset %}
DECLARE_INPLACE_OP_INFERER({{api["op_name"] | to_pascal_case}}InplaceInferer, DECLARE_INPLACE_OP_INFERER({{op["op_name"] | to_pascal_case}}InplaceInferer,
{{inplace_map}}); {{inplace_map}});
{% endif %} {% endif %}
{# no_need_buffer inferer #} {# no_need_buffer inferer #}
{% if api["no_need_buffer"] is not none %} {% if op["no_need_buffer"] is not none %}
DECLARE_NO_NEED_BUFFER_VARS_INFERER({{api["op_name"] | to_pascal_case}}NoNeedBufferVarInferer, DECLARE_NO_NEED_BUFFER_VARS_INFERER({{op["op_name"] | to_pascal_case}}NoNeedBufferVarInferer,
{{api["no_need_buffer"] | map("to_opmaker_name") | join(", ")}}); {{op["no_need_buffer"] | map("to_opmaker_name") | join(", ")}});
{% endif %} {% endif %}
{% endmacro%} {% endmacro%}
{% macro register_op_with_components(api) %} {% macro register_op_with_components(op) %}
{% set name = api["op_name"] %} {% set name = op["op_name"] %}
REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op, REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
{% if not "forward" in api %}{# it is a forward api #} {% if not "forward" in op %}{# it is a forward op #}
ops::{{name | to_pascal_case}}OpMaker, ops::{{name | to_pascal_case}}OpMaker,
{% endif %} {% endif %}
{% if "backward" in api and api["backward"] is not none %}{# backward #} {% if "backward" in op and op["backward"] is not none %}{# backward #}
{% set backward_name = api["backward"] %} {% set backward_name = op["backward"] %}
ops::{{backward_name | to_pascal_case}}OpMaker<paddle::framework::OpDesc>, ops::{{backward_name | to_pascal_case}}OpMaker<paddle::framework::OpDesc>,
ops::{{backward_name | to_pascal_case}}OpMaker<paddle::imperative::OpBase>, ops::{{backward_name | to_pascal_case}}OpMaker<paddle::imperative::OpBase>,
{% else %} {% else %}
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
{% endif %} {% endif %}
{% if api is supports_inplace %}{# inplace#} {% if op is supports_inplace %}{# inplace#}
ops::{{name | to_pascal_case}}InplaceInferer, ops::{{name | to_pascal_case}}InplaceInferer,
{% endif %} {% endif %}
{% if api is supports_no_need_buffer %}{# no_need_buffer #} {% if op is supports_no_need_buffer %}{# no_need_buffer #}
ops::{{name | to_pascal_case}}NoNeedBufferVarInferer, ops::{{name | to_pascal_case}}NoNeedBufferVarInferer,
{% endif %} {% endif %}
ops::{{name | to_pascal_case}}InferShapeFunctor); ops::{{name | to_pascal_case}}InferShapeFunctor);
{% endmacro %} {% endmacro %}
{% macro register_op_version(api) %} {% macro register_op_version(op) %}
{% if "version" in api %} {% if "version" in op %}
{% set name = api["op_name"] %} {% set name = op["op_name"] %}
REGISTER_OP_VERSION({{name}}) REGISTER_OP_VERSION({{name}})
{% for checkpoint in api["version"]%} {% for checkpoint in op["version"]%}
.AddCheckpoint( .AddCheckpoint(
R"ROC({{checkpoint["checkpoint"]}})ROC", R"ROC({{checkpoint["checkpoint"]}})ROC",
paddle::framework::compatible::OpVersionDesc() paddle::framework::compatible::OpVersionDesc()
...@@ -354,14 +354,14 @@ REGISTER_OP_VERSION({{name}}) ...@@ -354,14 +354,14 @@ REGISTER_OP_VERSION({{name}})
{# --------------------------------------- backward op maker ---------------------------------------------- #} {# --------------------------------------- backward op maker ---------------------------------------------- #}
{% macro backward_op_maker(api, forward_api) %} {% macro backward_op_maker(op, forward_op ) %}
{% set name = api["op_name"] %} {% set name = op["op_name"] %}
{% set forward_input_names = api["forward"]["inputs"] | map(attribute="name") | list %} {% set forward_input_names = op["forward"]["inputs"] | map(attribute="name") | list %}
{% set forward_output_names = api["forward"]["outputs"] | map(attribute="name") | list %} {% set forward_output_names = op["forward"]["outputs"] | map(attribute="name") | list %}
{% set forward_attr_names = api["forward"]["attrs"] | map(attribute="name") | list %} {% set forward_attr_names = op["forward"]["attrs"] | map(attribute="name") | list %}
{% set forward_input_orig_names = forward_api["inputs"] | map(attribute="name") | list %} {% set forward_input_orig_names = forward_op["inputs"] | map(attribute="name") | list %}
{% set forward_output_orig_names = forward_api["outputs"] | map(attribute="name") | list %} {% set forward_output_orig_names = forward_op["outputs"] | map(attribute="name") | list %}
{% set forward_attr_orig_names = forward_api["attrs"] | map(attribute="name") | list %} {% set forward_attr_orig_names = forward_op["attrs"] | map(attribute="name") | list %}
template <typename T> template <typename T>
class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> { class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> {
public: public:
...@@ -371,7 +371,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -371,7 +371,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
void Apply(GradOpPtr<T> grad_op) const override { void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("{{name}}"); grad_op->SetType("{{name}}");
{% for input in api["inputs"] %} {% for input in op["inputs"] %}
grad_op->SetInput({{input["name"] | to_opmaker_name}}, this->{{extract_input_from_forward( grad_op->SetInput({{input["name"] | to_opmaker_name}}, this->{{extract_input_from_forward(
input["name"], input["name"],
forward_input_names, forward_input_names,
...@@ -380,7 +380,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -380,7 +380,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
forward_output_orig_names)}}); forward_output_orig_names)}});
{% endfor %} {% endfor %}
{% for output in api["outputs"] %} {% for output in op["outputs"] %}
grad_op->SetOutput({{output["name"] | to_opmaker_name}}, this->{{extract_output_from_forward( grad_op->SetOutput({{output["name"] | to_opmaker_name}}, this->{{extract_output_from_forward(
output["name"], output["name"],
forward_input_names, forward_input_names,
...@@ -390,7 +390,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -390,7 +390,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
{% endfor %} {% endfor %}
grad_op->SetAttrMap(this->Attrs()); grad_op->SetAttrMap(this->Attrs());
{% for attr in api["attrs"] %} {% for attr in op["attrs"] %}
{% set attr_name = attr["name"] %} {% set attr_name = attr["name"] %}
{% if attr_name in forward_attr_names %} {% if attr_name in forward_attr_names %}
{% if attr["typename"] == "IntArray" %} {% if attr["typename"] == "IntArray" %}
......
...@@ -5,20 +5,20 @@ ...@@ -5,20 +5,20 @@
namespace phi { namespace phi {
{% for api in apis %} {% for op in ops %}
{% if api is base_api %} {% if op is base_op %}
{{sparse_op_name_map(api)}} {{sparse_op_name_map(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
{% for api in backward_apis %} {% for op in backward_ops %}
{% if api is base_api %} {% if op is base_op %}
{{sparse_op_name_map(api)}} {{sparse_op_name_map(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
} // namespace phi } // namespace phi
{% for api in apis + backward_apis %} {% for op in ops + backward_ops %}
{% if api is base_api %} {% if op is base_op %}
{{register_name_map(api)}} {{register_name_map(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
...@@ -19,31 +19,31 @@ namespace operators { ...@@ -19,31 +19,31 @@ namespace operators {
using paddle::framework::GradVarName; using paddle::framework::GradVarName;
{% for api in apis %} {% for op in ops %}
{% if api is base_api %} {% if op is base_op %}
{{op_maker(api)}} {{op_maker(op)}}
{{operator(api)}} {{operator(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
{% for api in backward_apis %} {% for op in backward_ops %}
{% if api is base_api %} {% if op is base_op %}
{{backward_op_maker(api, api_dict[api["forward"]["name"]])}} {{backward_op_maker(op, op_dict[op["forward"]["name"]])}}
{{operator(api)}} {{operator(op)}}
{% else %} {% else %}
{{backward_op_reused_maker(api, api_dict[api["forward"]["name"]], api["invoke"])}} {{backward_op_reused_maker(op, op_dict[op["forward"]["name"]], op["invoke"])}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
{% for api in apis + backward_apis %} {% for op in ops + backward_ops %}
{% if api is base_api %} {% if op is base_op %}
{{register_op_with_components(api)}} {{register_op_with_components(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
...@@ -41,20 +41,20 @@ def is_initializer_list(s): ...@@ -41,20 +41,20 @@ def is_initializer_list(s):
return s == "{}" return s == "{}"
def is_base_api(api): def is_base_op(op):
return "kernel" in api and "infer_meta" in api return "kernel" in op and "infer_meta" in op
def supports_selected_rows_kernel(api): def supports_selected_rows_kernel(op):
return is_base_api(api) and len(api["kernel"]["func"]) == 2 return is_base_op(op) and len(op["kernel"]["func"]) == 2
def supports_inplace(api): def supports_inplace(op):
return api['inplace'] is not None return op['inplace'] is not None
def supports_no_need_buffer(api): def supports_no_need_buffer(op):
for input in api["inputs"]: for input in op["inputs"]:
if input["no_need_buffer"]: if input["no_need_buffer"]:
return True return True
return False return False
...@@ -15,8 +15,6 @@ add_subdirectory(backends) ...@@ -15,8 +15,6 @@ add_subdirectory(backends)
add_subdirectory(kernels) add_subdirectory(kernels)
# phi infermeta # phi infermeta
add_subdirectory(infermeta) add_subdirectory(infermeta)
# phi operator definitions
add_subdirectory(ops)
# phi tools # phi tools
add_subdirectory(tools) add_subdirectory(tools)
# phi tests # phi tests
...@@ -36,7 +34,6 @@ set(PHI_DEPS ...@@ -36,7 +34,6 @@ set(PHI_DEPS
arg_map_context arg_map_context
infermeta infermeta
lod_utils lod_utils
op_compat_infos
sparse_csr_tensor sparse_csr_tensor
sparse_coo_tensor sparse_coo_tensor
string_tensor string_tensor
......
...@@ -94,204 +94,10 @@ set(wrapped_infermeta_header_file ...@@ -94,204 +94,10 @@ set(wrapped_infermeta_header_file
set(wrapped_infermeta_source_file set(wrapped_infermeta_source_file
${CMAKE_SOURCE_DIR}/paddle/phi/infermeta/generated.cc) ${CMAKE_SOURCE_DIR}/paddle/phi/infermeta/generated.cc)
# op extra info file
set(ops_extra_info_gen_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/ops_extra_info_gen.py)
set(op_compat_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml)
set(ops_extra_info_file
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/ops_extra_info.cc)
if(NOT PYTHONINTERP_FOUND) if(NOT PYTHONINTERP_FOUND)
find_package(PythonInterp REQUIRED) find_package(PythonInterp REQUIRED)
endif() endif()
# install extra dependencies
if(${PYTHON_VERSION_STRING} VERSION_LESS "3.6.2")
execute_process(COMMAND ${PYTHON_EXECUTABLE} -m pip install -U pyyaml
typing-extensions>=4.1.1 jinja2==2.11.3)
else()
execute_process(COMMAND ${PYTHON_EXECUTABLE} -m pip install -U pyyaml jinja2
typing-extensions)
endif()
# parse apis
set(parsed_api_dir ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/parsed_apis)
set(generated_op_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op.cc)
set(generated_sparse_ops_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_sparse_op.cc)
set(generated_argument_mapping_path
${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sig.cc)
set(generated_sparse_argument_mapping_path
${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sparse_sig.cc)
message(
"parse api yamls:
- ${api_yaml_file}
- ${legacy_api_yaml_file}
- ${bw_api_yaml_file}
- ${legacy_bw_api_yaml_file}")
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml
COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_api_dir}
COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path ./ops.yaml
--output_path ./parsed_apis/ops.parsed.yaml
COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./legacy_ops.yaml --output_path ./parsed_apis/legacy_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path ./backward.yaml
--output_path ./parsed_apis/backward_ops.parsed.yaml --backward
COMMAND
${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./legacy_backward.yaml --output_path
./parsed_apis/legacy_backward_ops.parsed.yaml --backward
COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./sparse_ops.yaml --output_path ./parsed_apis/sparse_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./sparse_backward.yaml --output_path
./parsed_apis/sparse_backward.parsed.yaml --backward RESULTS_VARIABLE
_results)
foreach(_result in ${_results})
if(${_result})
message(FATAL_ERROR "api yaml parsing failed, exiting.")
endif()
endforeach()
# validation of api yamls
message("validate api yaml:
- ${parsed_api_dir}/ops.parsed.yaml
- ${parsed_api_dir}/backward_ops.parsed.yaml")
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml
COMMAND
${PYTHON_EXECUTABLE} generator/cross_validate.py --forward_yaml_paths
./parsed_apis/ops.parsed.yaml ./parsed_apis/legacy_ops.parsed.yaml
--backward_yaml_paths ./parsed_apis/backward_ops.parsed.yaml
./parsed_apis/legacy_backward_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} generator/cross_validate.py --forward_yaml_paths
./parsed_apis/sparse_ops.parsed.yaml --backward_yaml_paths
./parsed_apis/sparse_backward.parsed.yaml
RESULT_VARIABLE _results)
foreach(_result in ${_results})
if(${_result})
message(FATAL_ERROR "ops validation failed, exiting.")
endif()
endforeach()
# code generation for op, op makers, and argument mapping functions
message(
"create or remove auto-geneated operators: ${generated_op_path}.tmp
create or remove auto-geneated argument mappings: ${generated_argument_mapping_path}.tmp"
)
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml
COMMAND
${PYTHON_EXECUTABLE} generator/generate_op.py --ops_yaml_path
./parsed_apis/ops.parsed.yaml --backward_yaml_path
./parsed_apis/backward_ops.parsed.yaml --op_version_yaml_path
op_version.yaml --op_compat_yaml_path op_compat.yaml --output_op_path
"${generated_op_path}.tmp" --output_arg_map_path
"${generated_argument_mapping_path}.tmp"
COMMAND
${PYTHON_EXECUTABLE} generator/generate_sparse_op.py --ops_yaml_path
./parsed_apis/sparse_ops.parsed.yaml --backward_ops_yaml_path
./parsed_apis/sparse_backward.parsed.yaml --output_op_path
"${generated_sparse_ops_path}.tmp" --output_arg_map_path
"${generated_sparse_argument_mapping_path}.tmp"
RESULT_VARIABLE _results)
foreach(_result in ${_results})
if(${_result})
message(FATAL_ERROR "operator codegen failed, exiting.")
endif()
endforeach()
if(EXISTS "${generated_op_path}.tmp" AND EXISTS "${generated_op_path}")
execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${generated_op_path}.tmp" "${generated_op_path}")
message("copy if different ${generated_op_path}.tmp ${generated_op_path}")
elseif(EXISTS "${generated_op_path}.tmp")
execute_process(COMMAND ${CMAKE_COMMAND} -E copy "${generated_op_path}.tmp"
"${generated_op_path}")
message("copy ${generated_op_path}.tmp ${generated_op_path}")
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f "${generated_op_path}")
message("remove ${generated_op_path}")
endif()
if(EXISTS "${generated_sparse_ops_path}.tmp" AND EXISTS
"${generated_sparse_ops_path}")
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${generated_sparse_ops_path}.tmp" "${generated_sparse_ops_path}")
message(
"copy if different ${generated_sparse_ops_path}.tmp ${generated_sparse_ops_path}"
)
elseif(EXISTS "${generated_sparse_ops_path}.tmp")
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy "${generated_sparse_ops_path}.tmp"
"${generated_sparse_ops_path}")
message("copy ${generated_sparse_ops_path}.tmp ${generated_sparse_ops_path}")
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f
"${generated_sparse_ops_path}")
message("remove ${generated_sparse_ops_path}")
endif()
if(EXISTS "${generated_argument_mapping_path}.tmp"
AND EXISTS "${generated_argument_mapping_path}")
execute_process(
COMMAND
${CMAKE_COMMAND} -E copy_if_different
"${generated_argument_mapping_path}.tmp"
"${generated_argument_mapping_path}")
message(
"copy if different ${generated_argument_mapping_path}.tmp ${generated_argument_mapping_path}"
)
elseif(EXISTS "${generated_argument_mapping_path}.tmp")
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy "${generated_argument_mapping_path}.tmp"
"${generated_argument_mapping_path}")
message(
"copy ${generated_argument_mapping_path}.tmp ${generated_argument_mapping_path}"
)
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f
"${generated_argument_mapping_path}")
message("remove ${generated_argument_mapping_path}")
endif()
if(EXISTS "${generated_sparse_argument_mapping_path}.tmp"
AND EXISTS "${generated_sparse_argument_mapping_path}")
execute_process(
COMMAND
${CMAKE_COMMAND} -E copy_if_different
"${generated_sparse_argument_mapping_path}.tmp"
"${generated_sparse_argument_mapping_path}")
message(
"copy if different ${generated_sparse_argument_mapping_path}.tmp ${generated_sparse_argument_mapping_path}"
)
elseif(EXISTS "${generated_sparse_argument_mapping_path}.tmp")
execute_process(
COMMAND
${CMAKE_COMMAND} -E copy "${generated_sparse_argument_mapping_path}.tmp"
"${generated_sparse_argument_mapping_path}")
message(
"copy ${generated_sparse_argument_mapping_path}.tmp ${generated_sparse_argument_mapping_path}"
)
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f
"${generated_sparse_argument_mapping_path}")
message("remove ${generated_sparse_argument_mapping_path}")
endif()
# generate ops extra info
execute_process(
COMMAND ${PYTHON_EXECUTABLE} ${ops_extra_info_gen_file} --op_compat_yaml_path
${op_compat_yaml_file} --ops_extra_info_path ${ops_extra_info_file})
message("generate ${ops_extra_info_file}")
# generate forward api # generate forward api
add_custom_command( add_custom_command(
OUTPUT ${api_header_file} ${api_source_file} OUTPUT ${api_header_file} ${api_source_file}
......
set(op_utils_header
${PADDLE_BINARY_DIR}/paddle/phi/ops/compat/signatures.h.tmp
CACHE INTERNAL "op_args_fns.cc file")
set(op_utils_header_final
${PADDLE_BINARY_DIR}/paddle/phi/ops/compat/signatures.h)
file(
WRITE ${op_utils_header}
"// Generated by the paddle/phi/ops/compat/CMakeLists.txt. DO NOT EDIT!\n\n")
file(APPEND ${op_utils_header}
"#include \"paddle/phi/core/compat/op_utils.h\"\n\n")
# Automatically generate the registration code of all arg map functions
# and compile the corresponding target to avoid frequent code conflicts
# when writing to same file
register_op_utils(op_compat_infos DEPS op_utils)
copy_if_different(${op_utils_header} ${op_utils_header_final})
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册