未验证 提交 f369b2b1 编写于 作者: C Chen Weihang 提交者: GitHub

[PHI decoupling] Move fluid op generator into fluid (#47714)

* move fluid op generator into fluid

* remove parsed op

* resolve sig undef error

* append python interp find logic

* remove dup code
上级 0a051297
...@@ -75,8 +75,7 @@ paddle/fluid/operators/generated_op.cc ...@@ -75,8 +75,7 @@ paddle/fluid/operators/generated_op.cc
paddle/fluid/operators/generated_sparse_op.cc paddle/fluid/operators/generated_sparse_op.cc
paddle/phi/ops/compat/generated_sig.cc paddle/phi/ops/compat/generated_sig.cc
paddle/phi/ops/compat/generated_sparse_sig.cc paddle/phi/ops/compat/generated_sparse_sig.cc
paddle/phi/api/yaml/parsed_apis/ paddle/fluid/operators/generator/parsed_ops/
python/paddle/utils/code_gen/
paddle/fluid/pybind/tmp_eager_op_function_impl.h paddle/fluid/pybind/tmp_eager_op_function_impl.h
paddle/fluid/pybind/eager_op_function_impl.h paddle/fluid/pybind/eager_op_function_impl.h
paddle/fluid/pybind/eager_op_function_impl.h paddle/fluid/pybind/eager_op_function_impl.h
......
...@@ -111,7 +111,7 @@ function(kernel_declare TARGET_LIST) ...@@ -111,7 +111,7 @@ function(kernel_declare TARGET_LIST)
endfunction() endfunction()
function(append_op_util_declare TARGET) function(append_op_util_declare TARGET)
file(READ ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET} target_content) file(READ ${TARGET} target_content)
string( string(
REGEX REGEX
MATCH MATCH
...@@ -134,13 +134,10 @@ function(register_op_utils TARGET_NAME) ...@@ -134,13 +134,10 @@ function(register_op_utils TARGET_NAME)
cmake_parse_arguments(register_op_utils "${options}" "${oneValueArgs}" cmake_parse_arguments(register_op_utils "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN}) "${multiValueArgs}" ${ARGN})
file( file(GLOB SIGNATURES "${PADDLE_SOURCE_DIR}/paddle/phi/ops/compat/*_sig.cc")
GLOB SIGNATURES
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"*_sig.cc")
foreach(target ${SIGNATURES}) foreach(target ${SIGNATURES})
append_op_util_declare(${target}) append_op_util_declare(${target})
list(APPEND utils_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${target}) list(APPEND utils_srcs ${target})
endforeach() endforeach()
cc_library( cc_library(
......
...@@ -503,7 +503,8 @@ if(WITH_XPU) ...@@ -503,7 +503,8 @@ if(WITH_XPU)
phi_utils phi_utils
kernel_factory kernel_factory
infershape_utils infershape_utils
op_utils) op_utils
op_compat_infos)
else() else()
cc_library( cc_library(
operator operator
...@@ -528,7 +529,8 @@ else() ...@@ -528,7 +529,8 @@ else()
phi_utils phi_utils
kernel_factory kernel_factory
infershape_utils infershape_utils
op_utils) op_utils
op_compat_infos)
endif() endif()
cc_test( cc_test(
......
include(operators) include(operators)
add_subdirectory(generator)
# solve "math constants not defined" problems caused by the order of inclusion # solve "math constants not defined" problems caused by the order of inclusion
# of <cmath> and the definition of macro _USE_MATH_DEFINES # of <cmath> and the definition of macro _USE_MATH_DEFINES
add_definitions(-D_USE_MATH_DEFINES) add_definitions(-D_USE_MATH_DEFINES)
......
# phi auto cmake utils
include(phi)

# set yaml file path
# Source-of-truth op definitions live under paddle/phi/api/yaml; the
# "legacy_*" files hold ops that have not yet been migrated to the new format.
set(op_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/ops.yaml)
set(legacy_op_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/legacy_ops.yaml)
set(bw_op_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/backward.yaml)
set(legacy_bw_op_yaml_file
    ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/legacy_backward.yaml)
set(sparse_op_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/sparse_ops.yaml)
set(sparse_bw_op_yaml_file
    ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/sparse_backward.yaml)

# The generator scripts below need a Python interpreter; locate one if an
# earlier include has not already done so.
if(NOT PYTHONINTERP_FOUND)
  find_package(PythonInterp REQUIRED)
endif()

# install extra dependencies
# NOTE(review): this mutates the user's Python environment at configure time
# (pip install). jinja2 is pinned to 2.11.3 on Python < 3.6.2, presumably for
# compatibility with older typing support — confirm before changing.
if(${PYTHON_VERSION_STRING} VERSION_LESS "3.6.2")
  execute_process(COMMAND ${PYTHON_EXECUTABLE} -m pip install -U pyyaml
                          typing-extensions>=4.1.1 jinja2==2.11.3)
else()
  execute_process(COMMAND ${PYTHON_EXECUTABLE} -m pip install -U pyyaml jinja2
                          typing-extensions)
endif()
# parse ops
#
# Run parse_op.py over every op yaml file, producing canonicalized
# "*.parsed.yaml" files in ${parsed_op_dir} for the later codegen steps.
set(parsed_op_dir
    ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops)
set(generated_op_path
    ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op.cc)
set(generated_sparse_ops_path
    ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_sparse_op.cc)
set(generated_argument_mapping_path
    ${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sig.cc)
set(generated_sparse_argument_mapping_path
    ${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sparse_sig.cc)

message(
  "parse op yamls:
- ${op_yaml_file}
- ${legacy_op_yaml_file}
- ${bw_op_yaml_file}
- ${legacy_bw_op_yaml_file}")
execute_process(
  WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
  COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_op_dir}
  COMMAND ${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${op_yaml_file}
          --output_path ./parsed_ops/ops.parsed.yaml
  COMMAND ${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${legacy_op_yaml_file}
          --output_path ./parsed_ops/legacy_ops.parsed.yaml
  COMMAND ${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${bw_op_yaml_file}
          --output_path ./parsed_ops/backward_ops.parsed.yaml --backward
  COMMAND
    ${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${legacy_bw_op_yaml_file}
    --output_path ./parsed_ops/legacy_backward_ops.parsed.yaml --backward
  COMMAND ${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${sparse_op_yaml_file}
          --output_path ./parsed_ops/sparse_ops.parsed.yaml
  COMMAND
    ${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${sparse_bw_op_yaml_file}
    --output_path ./parsed_ops/sparse_backward.parsed.yaml --backward
  RESULTS_VARIABLE _results)
# RESULTS_VARIABLE collects one exit status per COMMAND above; abort the
# configure step if any parser invocation failed.
# FIX: the original wrote `foreach(_result in ${_results})`, which also
# iterated the literal token "in" as though it were a result value.
foreach(_result IN LISTS _results)
  if(${_result})
    message(FATAL_ERROR "op yaml parsing failed, exiting.")
  endif()
endforeach()
# validation of op yamls
#
# Cross-validate the forward/backward parsed yamls (dense and sparse) so that
# mismatched op definitions fail the configure step early.
message("validate op yaml:
- ${parsed_op_dir}/ops.parsed.yaml
- ${parsed_op_dir}/backward_ops.parsed.yaml")
# FIX: use RESULTS_VARIABLE (plural) — the original used RESULT_VARIABLE,
# which records only the LAST command's exit status and therefore silently
# ignored a failure of the first cross_validate.py invocation.
execute_process(
  WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
  COMMAND
    ${PYTHON_EXECUTABLE} cross_validate.py --forward_yaml_paths
    ./parsed_ops/ops.parsed.yaml ./parsed_ops/legacy_ops.parsed.yaml
    --backward_yaml_paths ./parsed_ops/backward_ops.parsed.yaml
    ./parsed_ops/legacy_backward_ops.parsed.yaml
  COMMAND
    ${PYTHON_EXECUTABLE} cross_validate.py --forward_yaml_paths
    ./parsed_ops/sparse_ops.parsed.yaml --backward_yaml_paths
    ./parsed_ops/sparse_backward.parsed.yaml
  RESULTS_VARIABLE _results)
# Fail fast on any non-zero status (also fixes the `foreach(... in ...)`
# pattern that iterated the literal word "in").
foreach(_result IN LISTS _results)
  if(${_result})
    message(FATAL_ERROR "ops validation failed, exiting.")
  endif()
endforeach()
# code generation for op, op makers, and argument mapping functions
#
# The generators write to "*.tmp" paths; the sync step afterwards copies them
# over the real files only when their content changed, so unchanged generated
# sources do not trigger rebuilds.
message(
  "create or remove auto-generated operators: ${generated_op_path}.tmp
create or remove auto-generated argument mappings: ${generated_argument_mapping_path}.tmp"
)
# FIX: RESULTS_VARIABLE (plural) captures one exit status per COMMAND; the
# original RESULT_VARIABLE kept only the sparse generator's status, masking
# failures of generate_op.py.
execute_process(
  WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
  COMMAND
    ${PYTHON_EXECUTABLE} generate_op.py --ops_yaml_path
    ./parsed_ops/ops.parsed.yaml --backward_yaml_path
    ./parsed_ops/backward_ops.parsed.yaml --op_version_yaml_path
    ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_version.yaml
    --op_compat_yaml_path ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml
    --output_op_path "${generated_op_path}.tmp" --output_arg_map_path
    "${generated_argument_mapping_path}.tmp"
  COMMAND
    ${PYTHON_EXECUTABLE} generate_sparse_op.py --ops_yaml_path
    ./parsed_ops/sparse_ops.parsed.yaml --backward_ops_yaml_path
    ./parsed_ops/sparse_backward.parsed.yaml --output_op_path
    "${generated_sparse_ops_path}.tmp" --output_arg_map_path
    "${generated_sparse_argument_mapping_path}.tmp"
  RESULTS_VARIABLE _results)
# Abort configure on any generator failure (fixed `foreach(... in ...)`,
# which iterated the literal word "in" as a value).
foreach(_result IN LISTS _results)
  if(${_result})
    message(FATAL_ERROR "operator codegen failed, exiting.")
  endif()
endforeach()
# Synchronize a generated source file with its freshly written ".tmp" copy:
#   - both exist:      copy only when contents differ (avoids touching the
#                      file and forcing needless recompiles)
#   - only .tmp exists: first generation, copy it into place
#   - no .tmp:          the generator produced nothing, remove any stale file
# The message() output mirrors the action taken, matching the previous
# hand-expanded stanzas. Factored into a function to replace four identical
# copies of the same if/elseif/else logic.
function(sync_generated_file generated_file)
  if(EXISTS "${generated_file}.tmp" AND EXISTS "${generated_file}")
    execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
                            "${generated_file}.tmp" "${generated_file}")
    message("copy if different ${generated_file}.tmp ${generated_file}")
  elseif(EXISTS "${generated_file}.tmp")
    execute_process(COMMAND ${CMAKE_COMMAND} -E copy "${generated_file}.tmp"
                            "${generated_file}")
    message("copy ${generated_file}.tmp ${generated_file}")
  else()
    execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f "${generated_file}")
    message("remove ${generated_file}")
  endif()
endfunction()

sync_generated_file("${generated_op_path}")
sync_generated_file("${generated_sparse_ops_path}")
sync_generated_file("${generated_argument_mapping_path}")
sync_generated_file("${generated_sparse_argument_mapping_path}")
# op extra info file
# Generate paddle/fluid/operators/ops_extra_info.cc from op_compat.yaml using
# the phi-side generator script.
set(ops_extra_info_gen_file
    ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/ops_extra_info_gen.py)
set(op_compat_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml)
set(ops_extra_info_file
    ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/ops_extra_info.cc)
# generate ops extra info
# NOTE(review): no result checking here — a generator failure leaves a stale
# or missing ops_extra_info.cc; confirm whether that is intentional.
execute_process(
  COMMAND ${PYTHON_EXECUTABLE} ${ops_extra_info_gen_file} --op_compat_yaml_path
          ${op_compat_yaml_file} --ops_extra_info_path ${ops_extra_info_file})
message("generate ${ops_extra_info_file}")

# signatures.h is built as a ".tmp" in the binary dir, then copied over the
# real header only if it changed. The path is cached (INTERNAL) so
# append_op_util_declare() in other directories can append to it.
set(op_utils_header
    ${PADDLE_BINARY_DIR}/paddle/phi/ops/compat/signatures.h.tmp
    CACHE INTERNAL "op_args_fns.cc file")
set(op_utils_header_final
    ${PADDLE_BINARY_DIR}/paddle/phi/ops/compat/signatures.h)
file(
  WRITE ${op_utils_header}
  "// Generated by the paddle/fluid/operators/generator/CMakeLists.txt. DO NOT EDIT!\n\n"
)
file(APPEND ${op_utils_header}
     "#include \"paddle/phi/core/compat/op_utils.h\"\n\n")

# Automatically generate the registration code of all arg map functions
# and compile the corresponding target to avoid frequent code conflicts
# when writing to same file
register_op_utils(op_compat_infos DEPS op_utils)

copy_if_different(${op_utils_header} ${op_utils_header_final})
...@@ -20,35 +20,35 @@ import yaml ...@@ -20,35 +20,35 @@ import yaml
from parse_utils import cross_validate, to_named_dict from parse_utils import cross_validate, to_named_dict
def main(forward_api_yaml_paths, backward_api_yaml_paths): def main(forward_op_yaml_paths, backward_op_yaml_paths):
apis = {} ops = {}
for api_yaml_path in chain(forward_api_yaml_paths, backward_api_yaml_paths): for op_yaml_path in chain(forward_op_yaml_paths, backward_op_yaml_paths):
with open(api_yaml_path, "rt", encoding="utf-8") as f: with open(op_yaml_path, "rt", encoding="utf-8") as f:
api_list = yaml.safe_load(f) op_list = yaml.safe_load(f)
if api_list is not None: if op_list is not None:
apis.update(to_named_dict((api_list))) ops.update(to_named_dict((op_list)))
cross_validate(apis) cross_validate(ops)
if __name__ == "__main__": if __name__ == "__main__":
current_dir = Path(__file__).parent / "temp" current_dir = Path(__file__).parent / "temp"
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Parse api yaml into canonical format." description="Parse op yaml into canonical format."
) )
parser.add_argument( parser.add_argument(
'--forward_yaml_paths', '--forward_yaml_paths',
type=str, type=str,
nargs='+', nargs='+',
default=str(current_dir / "api.parsed.yaml"), default=str(current_dir / "op .parsed.yaml"),
help="forward api yaml file.", help="forward op yaml file.",
) )
parser.add_argument( parser.add_argument(
'--backward_yaml_paths', '--backward_yaml_paths',
type=str, type=str,
nargs='+', nargs='+',
default=str(current_dir / "backward_api.parsed.yaml"), default=str(current_dir / "backward_op .parsed.yaml"),
help="backward api yaml file.", help="backward op yaml file.",
) )
args = parser.parse_args() args = parser.parse_args()
......
...@@ -102,12 +102,12 @@ def to_pascal_case(s): ...@@ -102,12 +102,12 @@ def to_pascal_case(s):
def to_input_name(s): def to_input_name(s):
"""find input variable name in api yaml for higher order backward api. """find input variable name in op yaml for higher order backward op .
x -> dx x -> dx
x -> d2x x -> d2x
x -> d3x x -> d3x
NOTE: for first order backward api NOTE: for first order backward op
x -> x_grad x -> x_grad
is more common. is more common.
""" """
...@@ -137,16 +137,14 @@ def cartesian_prod_attrs(attrs): ...@@ -137,16 +137,14 @@ def cartesian_prod_attrs(attrs):
return combinations return combinations
def cartesian_prod_mapping(api): def cartesian_prod_mapping(op):
kernels = api["kernel"]["func"] kernels = op["kernel"]["func"]
inputs = [ inputs = [
x["name"] for x in api["inputs"] if x["name"] in api["kernel"]["param"] x["name"] for x in op["inputs"] if x["name"] in op["kernel"]["param"]
] ]
inputs = [to_opmaker_name_cstr(input) for input in inputs] inputs = [to_opmaker_name_cstr(input) for input in inputs]
attrs = cartesian_prod_attrs(api["attrs"]) attrs = cartesian_prod_attrs(op["attrs"])
outputs = [ outputs = [to_opmaker_name_cstr(output["name"]) for output in op["outputs"]]
to_opmaker_name_cstr(output["name"]) for output in api["outputs"]
]
def vec(items): def vec(items):
return "{" + ', '.join(items) + "}" return "{" + ', '.join(items) + "}"
......
...@@ -26,7 +26,7 @@ from filters import ( ...@@ -26,7 +26,7 @@ from filters import (
to_pascal_case, to_pascal_case,
) )
from tests import ( from tests import (
is_base_api, is_base_op,
is_vec, is_vec,
is_scalar, is_scalar,
is_initializer_list, is_initializer_list,
...@@ -51,7 +51,7 @@ env.filters["to_pascal_case"] = to_pascal_case ...@@ -51,7 +51,7 @@ env.filters["to_pascal_case"] = to_pascal_case
env.filters["to_input_name"] = to_input_name env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.tests["base_api"] = is_base_api env.tests["base_op"] = is_base_op
env.tests["vec"] = is_vec env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list env.tests["initializer_list"] = is_initializer_list
...@@ -59,126 +59,127 @@ env.tests["supports_inplace"] = supports_inplace ...@@ -59,126 +59,127 @@ env.tests["supports_inplace"] = supports_inplace
env.tests["supports_no_need_buffer"] = supports_no_need_buffer env.tests["supports_no_need_buffer"] = supports_no_need_buffer
def restruct_io(api): def restruct_io(op):
api["input_dict"] = to_named_dict(api["inputs"]) op["input_dict"] = to_named_dict(op["inputs"])
api["attr_dict"] = to_named_dict(api["attrs"]) op["attr_dict"] = to_named_dict(op["attrs"])
api["output_dict"] = to_named_dict(api["outputs"]) op["output_dict"] = to_named_dict(op["outputs"])
return api return op
# replace name of op and params for OpMaker # replace name of op and params for OpMaker
def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict): def replace_compat_name(op_op_map, forward_op_dict, backward_op_dict):
def get_api_and_op_name(api_item): def get_op_and_op_name(op_item):
names = api_item.split('(') names = op_item.split('(')
if len(names) == 1: if len(names) == 1:
return names[0].strip(), names[0].strip() return names[0].strip(), names[0].strip()
else: else:
return names[0].strip(), names[1].split(')')[0].strip() return names[0].strip(), names[1].split(')')[0].strip()
def update_api_attr_name(attrs, attrs_alias_map): def update_op_attr_name(attrs, attrs_alias_map):
for attr_item in attrs: for attr_item in attrs:
if attr_item['name'] in attrs_alias_map: if attr_item['name'] in attrs_alias_map:
attr_item['name'] = attrs_alias_map[attr_item['name']] attr_item['name'] = attrs_alias_map[attr_item['name']]
for api_args in api_op_map: for op_args in op_op_map:
api_name, op_name = get_api_and_op_name(api_args['op']) new_op_name, op_name = get_op_and_op_name(op_args['op'])
if api_name not in forward_api_dict: if new_op_name not in forward_op_dict:
continue continue
forward_api_item = forward_api_dict[api_name] forward_op_item = forward_op_dict[new_op_name]
has_backward = True if forward_api_item['backward'] else False has_backward = True if forward_op_item['backward'] else False
if has_backward: if has_backward:
backward_api_item = backward_api_dict[forward_api_item['backward']] backward_op_item = backward_op_dict[forward_op_item['backward']]
if api_name != op_name: if new_op_name != op_name:
forward_api_item['op_name'] = op_name forward_op_item['op_name'] = op_name
if 'backward' in api_args and has_backward: if 'backward' in op_args and has_backward:
backward_op_list = api_args['backward'].split(',') backward_op_list = op_args['backward'].split(',')
bw_api_name, bw_op_name = get_api_and_op_name(backward_op_list[0]) _, bw_op_name = get_op_and_op_name(backward_op_list[0])
forward_api_item['backward'] = bw_op_name forward_op_item['backward'] = bw_op_name
backward_api_item['op_name'] = bw_op_name backward_op_item['op_name'] = bw_op_name
# for double grad # for double grad
if len(backward_op_list) > 1: if len(backward_op_list) > 1:
double_grad_api_name, double_grad_op_name = get_api_and_op_name( (
backward_op_list[1] new_double_grad_op_name,
) double_grad_op_name,
double_grad_item = backward_api_dict[double_grad_api_name] ) = get_op_and_op_name(backward_op_list[1])
backward_api_item['backward'] = double_grad_op_name double_grad_item = backward_op_dict[new_double_grad_op_name]
backward_op_item['backward'] = double_grad_op_name
double_grad_item['op_name'] = double_grad_op_name double_grad_item['op_name'] = double_grad_op_name
if 'attrs' in api_args: if 'attrs' in op_args:
update_api_attr_name( update_op_attr_name(
double_grad_item['attrs'], api_args['attrs'] double_grad_item['attrs'], op_args['attrs']
) )
update_api_attr_name( update_op_attr_name(
double_grad_item['forward']['attrs'], api_args['attrs'] double_grad_item['forward']['attrs'], op_args['attrs']
) )
# for triple grad # for triple grad
if len(backward_op_list) > 2: if len(backward_op_list) > 2:
( (
triple_grad_api_name, new_triple_grad_op_name,
triple_grad_op_name, triple_grad_op_name,
) = get_api_and_op_name(backward_op_list[2]) ) = get_op_and_op_name(backward_op_list[2])
triple_grad_item = backward_api_dict[triple_grad_api_name] triple_grad_item = backward_op_dict[new_triple_grad_op_name]
double_grad_item['backward'] = triple_grad_op_name double_grad_item['backward'] = triple_grad_op_name
triple_grad_item['op_name'] = triple_grad_op_name triple_grad_item['op_name'] = triple_grad_op_name
if 'attrs' in api_args: if 'attrs' in op_args:
update_api_attr_name( update_op_attr_name(
triple_grad_item['attrs'], api_args['attrs'] triple_grad_item['attrs'], op_args['attrs']
) )
update_api_attr_name( update_op_attr_name(
triple_grad_item['forward']['attrs'], triple_grad_item['forward']['attrs'],
api_args['attrs'], op_args['attrs'],
) )
key_set = ['inputs', 'attrs', 'outputs'] key_set = ['inputs', 'attrs', 'outputs']
args_map = {} args_map = {}
for key in key_set: for key in key_set:
if key in api_args: if key in op_args:
args_map.update(api_args[key]) args_map.update(op_args[key])
for args_item in forward_api_item[key]: for args_item in forward_op_item[key]:
if args_item['name'] in api_args[key]: if args_item['name'] in op_args[key]:
args_item['name'] = api_args[key][args_item['name']] args_item['name'] = op_args[key][args_item['name']]
if has_backward: if has_backward:
for args_item in backward_api_item['forward'][key]: for args_item in backward_op_item['forward'][key]:
if args_item['name'] in api_args[key]: if args_item['name'] in op_args[key]:
args_item['name'] = api_args[key][args_item['name']] args_item['name'] = op_args[key][args_item['name']]
forward_api_item['infer_meta']['param'] = [ forward_op_item['infer_meta']['param'] = [
args_map[param] if param in args_map else param args_map[param] if param in args_map else param
for param in forward_api_item['infer_meta']['param'] for param in forward_op_item['infer_meta']['param']
] ]
forward_api_item['kernel']['param'] = [ forward_op_item['kernel']['param'] = [
args_map[param] if param in args_map else param args_map[param] if param in args_map else param
for param in forward_api_item['kernel']['param'] for param in forward_op_item['kernel']['param']
] ]
if forward_api_item['kernel']['data_type']: if forward_op_item['kernel']['data_type']:
forward_api_item['kernel']['data_type']['candidates'] = [ forward_op_item['kernel']['data_type']['candidates'] = [
args_map[param] if param in args_map else param args_map[param] if param in args_map else param
for param in forward_api_item['kernel']['data_type'][ for param in forward_op_item['kernel']['data_type'][
'candidates' 'candidates'
] ]
] ]
if forward_api_item['kernel']['backend']: if forward_op_item['kernel']['backend']:
forward_api_item['kernel']['backend']['candidates'] = [ forward_op_item['kernel']['backend']['candidates'] = [
args_map[param] if param in args_map else param args_map[param] if param in args_map else param
for param in forward_api_item['kernel']['backend']['candidates'] for param in forward_op_item['kernel']['backend']['candidates']
] ]
if forward_api_item['kernel']['layout']: if forward_op_item['kernel']['layout']:
forward_api_item['kernel']['layout']['candidates'] = [ forward_op_item['kernel']['layout']['candidates'] = [
args_map[param] if param in args_map else param args_map[param] if param in args_map else param
for param in forward_api_item['kernel']['layout']['candidates'] for param in forward_op_item['kernel']['layout']['candidates']
] ]
if forward_api_item['inplace']: if forward_op_item['inplace']:
inplace_map = {} inplace_map = {}
for key, val in forward_api_item['inplace'].items(): for key, val in forward_op_item['inplace'].items():
if key in args_map: if key in args_map:
key = args_map[key] key = args_map[key]
if val in args_map: if val in args_map:
val = args_map[val] val = args_map[val]
inplace_map[key] = val inplace_map[key] = val
forward_api_item['inplace'] = inplace_map forward_op_item['inplace'] = inplace_map
if has_backward: if has_backward:
for args_item in backward_api_item['inputs']: for args_item in backward_op_item['inputs']:
if args_item['name'] in args_map: if args_item['name'] in args_map:
args_item['name'] = args_map[args_item['name']] args_item['name'] = args_map[args_item['name']]
elif ( elif (
...@@ -189,10 +190,10 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict): ...@@ -189,10 +190,10 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
args_map[args_item['name'][:-5]] + '_grad' args_map[args_item['name'][:-5]] + '_grad'
) )
args_item['name'] = args_map[args_item['name']] args_item['name'] = args_map[args_item['name']]
for args_item in backward_api_item['attrs']: for args_item in backward_op_item['attrs']:
if args_item['name'] in args_map: if args_item['name'] in args_map:
args_item['name'] = args_map[args_item['name']] args_item['name'] = args_map[args_item['name']]
for args_item in backward_api_item['outputs']: for args_item in backward_op_item['outputs']:
if ( if (
args_item['name'].endswith('_grad') args_item['name'].endswith('_grad')
and args_item['name'][:-5] in args_map and args_item['name'][:-5] in args_map
...@@ -202,73 +203,73 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict): ...@@ -202,73 +203,73 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
) )
args_item['name'] = args_map[args_item['name']] args_item['name'] = args_map[args_item['name']]
if 'invoke' in backward_api_item: if 'invoke' in backward_op_item:
backward_api_item['invoke']['args'] = [ backward_op_item['invoke']['args'] = [
args_map[param.strip()] args_map[param.strip()]
if param.strip() in args_map if param.strip() in args_map
else param.strip() else param.strip()
for param in backward_api_item['invoke']['args'].split(',') for param in backward_op_item['invoke']['args'].split(',')
] ]
continue continue
backward_api_item['infer_meta']['param'] = [ backward_op_item['infer_meta']['param'] = [
args_map[param] if param in args_map else param args_map[param] if param in args_map else param
for param in backward_api_item['infer_meta']['param'] for param in backward_op_item['infer_meta']['param']
] ]
backward_api_item['kernel']['param'] = [ backward_op_item['kernel']['param'] = [
args_map[param] if param in args_map else param args_map[param] if param in args_map else param
for param in backward_api_item['kernel']['param'] for param in backward_op_item['kernel']['param']
] ]
if backward_api_item['kernel']['data_type']: if backward_op_item['kernel']['data_type']:
backward_api_item['kernel']['data_type']['candidates'] = [ backward_op_item['kernel']['data_type']['candidates'] = [
args_map[param] if param in args_map else param args_map[param] if param in args_map else param
for param in backward_api_item['kernel']['data_type'][ for param in backward_op_item['kernel']['data_type'][
'candidates' 'candidates'
] ]
] ]
if backward_api_item['kernel']['backend']: if backward_op_item['kernel']['backend']:
backward_api_item['kernel']['backend']['candidates'] = [ backward_op_item['kernel']['backend']['candidates'] = [
args_map[param] if param in args_map else param args_map[param] if param in args_map else param
for param in backward_api_item['kernel']['backend'][ for param in backward_op_item['kernel']['backend'][
'candidates' 'candidates'
] ]
] ]
if backward_api_item['kernel']['layout']: if backward_op_item['kernel']['layout']:
backward_api_item['kernel']['layout']['candidates'] = [ backward_op_item['kernel']['layout']['candidates'] = [
args_map[param] if param in args_map else param args_map[param] if param in args_map else param
for param in backward_api_item['kernel']['layout'][ for param in backward_op_item['kernel']['layout'][
'candidates' 'candidates'
] ]
] ]
if backward_api_item['no_need_buffer']: if backward_op_item['no_need_buffer']:
backward_api_item['no_need_buffer'] = [ backward_op_item['no_need_buffer'] = [
args_map[param] if param in args_map else param args_map[param] if param in args_map else param
for param in backward_api_item['no_need_buffer'] for param in backward_op_item['no_need_buffer']
] ]
if backward_api_item['inplace']: if backward_op_item['inplace']:
inplace_map = {} inplace_map = {}
for key, val in backward_api_item['inplace'].items(): for key, val in backward_op_item['inplace'].items():
if key in args_map: if key in args_map:
key = args_map[key] key = args_map[key]
if val in args_map: if val in args_map:
val = args_map[val] val = args_map[val]
inplace_map[key] = val inplace_map[key] = val
backward_api_item['inplace'] = inplace_map backward_op_item['inplace'] = inplace_map
def process_invoke_op(forward_api_dict, backward_api_dict): def process_invoke_op(forward_op_dict, backward_op_dict):
for bw_api in backward_api_dict.values(): for bw_op in backward_op_dict.values():
if 'invoke' in bw_api: if 'invoke' in bw_op:
invoke_op = bw_api['invoke']['func'] invoke_op = bw_op['invoke']['func']
args_list = bw_api['invoke']['args'] args_list = bw_op['invoke']['args']
args_index = 0 args_index = 0
if invoke_op in forward_api_dict: if invoke_op in forward_op_dict:
reuse_op = forward_api_dict[invoke_op] reuse_op = forward_op_dict[invoke_op]
bw_api['invoke']['inputs'] = [] bw_op['invoke']['inputs'] = []
bw_api['invoke']['attrs'] = [] bw_op['invoke']['attrs'] = []
bw_api['invoke']['outputs'] = [] bw_op['invoke']['outputs'] = []
for input_item in reuse_op['inputs']: for input_item in reuse_op['inputs']:
bw_api['invoke']['inputs'].append( bw_op['invoke']['inputs'].append(
{ {
'name': input_item['name'], 'name': input_item['name'],
'value': args_list[args_index], 'value': args_list[args_index],
...@@ -279,20 +280,20 @@ def process_invoke_op(forward_api_dict, backward_api_dict): ...@@ -279,20 +280,20 @@ def process_invoke_op(forward_api_dict, backward_api_dict):
if args_index < len(args_list): if args_index < len(args_list):
attr_value = ( attr_value = (
f"this->GetAttr(\"{args_list[args_index]}\")" f"this->GetAttr(\"{args_list[args_index]}\")"
if args_list[args_index] in bw_api['attr_dict'] if args_list[args_index] in bw_op['attr_dict']
else args_list[args_index] else args_list[args_index]
) )
bw_api['invoke']['attrs'].append( bw_op['invoke']['attrs'].append(
{'name': attr['name'], 'value': attr_value} {'name': attr['name'], 'value': attr_value}
) )
args_index = args_index + 1 args_index = args_index + 1
else: else:
break break
for idx, output_item in enumerate(reuse_op['outputs']): for idx, output_item in enumerate(reuse_op['outputs']):
bw_api['invoke']['outputs'].append( bw_op['invoke']['outputs'].append(
{ {
'name': output_item['name'], 'name': output_item['name'],
'value': bw_api['outputs'][idx]['name'], 'value': bw_op['outputs'][idx]['name'],
} }
) )
...@@ -306,47 +307,47 @@ def main( ...@@ -306,47 +307,47 @@ def main(
output_arg_map_path, output_arg_map_path,
): ):
with open(ops_yaml_path, "rt") as f: with open(ops_yaml_path, "rt") as f:
apis = yaml.safe_load(f) ops = yaml.safe_load(f)
apis = [restruct_io(api) for api in apis] ops = [restruct_io(op) for op in ops]
forward_api_dict = to_named_dict(apis) forward_op_dict = to_named_dict(ops)
with open(backward_yaml_path, "rt") as f: with open(backward_yaml_path, "rt") as f:
backward_apis = yaml.safe_load(f) backward_ops = yaml.safe_load(f)
backward_apis = [restruct_io(api) for api in backward_apis] backward_ops = [restruct_io(op) for op in backward_ops]
backward_api_dict = to_named_dict(backward_apis) backward_op_dict = to_named_dict(backward_ops)
with open(op_version_yaml_path, "rt") as f: with open(op_version_yaml_path, "rt") as f:
api_versions = yaml.safe_load(f) op_versions = yaml.safe_load(f)
# add api version info into api # add op version info into op
for api_version in api_versions: for op_version in op_versions:
forward_api_dict[api_version['op']]['version'] = api_version['version'] forward_op_dict[op_version['op']]['version'] = op_version['version']
with open(op_compat_yaml_path, "rt") as f: with open(op_compat_yaml_path, "rt") as f:
api_op_map = yaml.safe_load(f) op_op_map = yaml.safe_load(f)
for api in apis: for op in ops:
api['op_name'] = api['name'] op['op_name'] = op['name']
for bw_api in backward_apis: for bw_op in backward_ops:
bw_api['op_name'] = bw_api['name'] bw_op['op_name'] = bw_op['name']
replace_compat_name(api_op_map, forward_api_dict, backward_api_dict) replace_compat_name(op_op_map, forward_op_dict, backward_op_dict)
# prepare for invoke case # prepare for invoke case
process_invoke_op(forward_api_dict, backward_api_dict) process_invoke_op(forward_op_dict, backward_op_dict)
# fill backward field for an api if another api claims it as forward # fill backward field for an op if another op claims it as forward
for name, backward_api in backward_api_dict.items(): for name, backward_op in backward_op_dict.items():
forward_name = backward_api["forward"]["name"] forward_name = backward_op["forward"]["name"]
if forward_name in backward_api_dict: if forward_name in backward_op_dict:
forward_api = backward_api_dict[forward_name] forward_op = backward_op_dict[forward_name]
if forward_api["backward"] is None: if forward_op["backward"] is None:
forward_api["backward"] = name forward_op["backward"] = name
api_dict = {} op_dict = {}
api_dict.update(forward_api_dict) op_dict.update(forward_op_dict)
api_dict.update(backward_api_dict) op_dict.update(backward_op_dict)
if len(apis) == 0 and len(backward_apis) == 0: if len(ops) == 0 and len(backward_ops) == 0:
if os.path.isfile(output_op_path): if os.path.isfile(output_op_path):
os.remove(output_op_path) os.remove(output_op_path)
if os.path.isfile(output_arg_map_path): if os.path.isfile(output_arg_map_path):
...@@ -356,19 +357,19 @@ def main( ...@@ -356,19 +357,19 @@ def main(
op_template = env.get_template('op.c.j2') op_template = env.get_template('op.c.j2')
with open(output_op_path, "wt") as f: with open(output_op_path, "wt") as f:
msg = op_template.render( msg = op_template.render(
apis=apis, backward_apis=backward_apis, api_dict=api_dict ops=ops, backward_ops=backward_ops, op_dict=op_dict
) )
f.write(msg) f.write(msg)
ks_template = env.get_template('ks.c.j2') ks_template = env.get_template('ks.c.j2')
with open(output_arg_map_path, 'wt') as f: with open(output_arg_map_path, 'wt') as f:
msg = ks_template.render(apis=apis, backward_apis=backward_apis) msg = ks_template.render(ops=ops, backward_ops=backward_ops)
f.write(msg) f.write(msg)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Generate operator file from api yaml." description="Generate operator file from op yaml."
) )
parser.add_argument( parser.add_argument(
'--ops_yaml_path', type=str, help="parsed ops yaml file." '--ops_yaml_path', type=str, help="parsed ops yaml file."
......
...@@ -26,7 +26,7 @@ from filters import ( ...@@ -26,7 +26,7 @@ from filters import (
to_pascal_case, to_pascal_case,
) )
from tests import ( from tests import (
is_base_api, is_base_op,
is_vec, is_vec,
is_scalar, is_scalar,
is_initializer_list, is_initializer_list,
...@@ -52,7 +52,7 @@ env.filters["to_pascal_case"] = to_pascal_case ...@@ -52,7 +52,7 @@ env.filters["to_pascal_case"] = to_pascal_case
env.filters["to_input_name"] = to_input_name env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.tests["base_api"] = is_base_api env.tests["base_op"] = is_base_op
env.tests["vec"] = is_vec env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list env.tests["initializer_list"] = is_initializer_list
...@@ -60,65 +60,63 @@ env.tests["supports_inplace"] = supports_inplace ...@@ -60,65 +60,63 @@ env.tests["supports_inplace"] = supports_inplace
env.tests["supports_no_need_buffer"] = supports_no_need_buffer env.tests["supports_no_need_buffer"] = supports_no_need_buffer
def restruct_io(api): def restruct_io(op):
api["input_dict"] = to_named_dict(api["inputs"]) op["input_dict"] = to_named_dict(op["inputs"])
api["attr_dict"] = to_named_dict(api["attrs"]) op["attr_dict"] = to_named_dict(op["attrs"])
api["output_dict"] = to_named_dict(api["outputs"]) op["output_dict"] = to_named_dict(op["outputs"])
return api return op
SPARSE_OP_PREFIX = 'sparse_' SPARSE_OP_PREFIX = 'sparse_'
def main( def main(op_yaml_path, backward_yaml_path, output_op_path, output_arg_map_path):
api_yaml_path, backward_yaml_path, output_op_path, output_arg_map_path with open(op_yaml_path, "rt") as f:
): ops = yaml.safe_load(f)
with open(api_yaml_path, "rt") as f: ops = [restruct_io(op) for op in ops]
apis = yaml.safe_load(f) forward_op_dict = to_named_dict(ops)
apis = [restruct_io(api) for api in apis]
forward_api_dict = to_named_dict(apis)
with open(backward_yaml_path, "rt") as f: with open(backward_yaml_path, "rt") as f:
backward_apis = yaml.safe_load(f) backward_ops = yaml.safe_load(f)
backward_apis = [restruct_io(api) for api in backward_apis] backward_ops = [restruct_io(op) for op in backward_ops]
backward_api_dict = to_named_dict(backward_apis) backward_op_dict = to_named_dict(backward_ops)
for api in apis: for op in ops:
api['op_name'] = SPARSE_OP_PREFIX + api['name'] op['op_name'] = SPARSE_OP_PREFIX + op['name']
api['name'] = api['op_name'] op['name'] = op['op_name']
if api["backward"] is not None: if op["backward"] is not None:
api["backward"] = SPARSE_OP_PREFIX + api["backward"] op["backward"] = SPARSE_OP_PREFIX + op["backward"]
for bw_api in backward_apis: for bw_op in backward_ops:
bw_api['op_name'] = SPARSE_OP_PREFIX + bw_api['name'] bw_op['op_name'] = SPARSE_OP_PREFIX + bw_op['name']
bw_api['name'] = bw_api['op_name'] bw_op['name'] = bw_op['op_name']
if 'invoke' in bw_api: if 'invoke' in bw_op:
bw_api['invoke']['args'] = [ bw_op['invoke']['args'] = [
param.strip() for param in bw_api['invoke']['args'].split(',') param.strip() for param in bw_op['invoke']['args'].split(',')
] ]
# prepare for invoke case # prepare for invoke case
process_invoke_op(forward_api_dict, backward_api_dict) process_invoke_op(forward_op_dict, backward_op_dict)
for bw_api in backward_apis: for bw_op in backward_ops:
if 'invoke' in bw_api: if 'invoke' in bw_op:
if bw_api['invoke']['func'] in forward_api_dict: if bw_op['invoke']['func'] in forward_op_dict:
bw_api['invoke']['func'] = ( bw_op['invoke']['func'] = (
SPARSE_OP_PREFIX + bw_api['invoke']['func'] SPARSE_OP_PREFIX + bw_op['invoke']['func']
) )
# fill backward field for an api if another api claims it as forward # fill backward field for an op if another op claims it as forward
for name, backward_api in backward_api_dict.items(): for name, backward_op in backward_op_dict.items():
forward_name = backward_api["forward"]["name"] forward_name = backward_op["forward"]["name"]
if forward_name in backward_api_dict: if forward_name in backward_op_dict:
forward_api = backward_api_dict[forward_name] forward_op = backward_op_dict[forward_name]
if forward_api["backward"] is None: if forward_op["backward"] is None:
forward_api["backward"] = name forward_op["backward"] = name
forward_api["backward"] = SPARSE_OP_PREFIX + forward_api["backward"] forward_op["backward"] = SPARSE_OP_PREFIX + forward_op["backward"]
api_dict = {} op_dict = {}
api_dict.update(forward_api_dict) op_dict.update(forward_op_dict)
api_dict.update(backward_api_dict) op_dict.update(backward_op_dict)
if len(apis) == 0 and len(backward_apis) == 0: if len(ops) == 0 and len(backward_ops) == 0:
if os.path.isfile(output_op_path): if os.path.isfile(output_op_path):
os.remove(output_op_path) os.remove(output_op_path)
if os.path.isfile(output_arg_map_path): if os.path.isfile(output_arg_map_path):
...@@ -128,19 +126,19 @@ def main( ...@@ -128,19 +126,19 @@ def main(
op_template = env.get_template('sparse_op.c.j2') op_template = env.get_template('sparse_op.c.j2')
with open(output_op_path, "wt") as f: with open(output_op_path, "wt") as f:
msg = op_template.render( msg = op_template.render(
apis=apis, backward_apis=backward_apis, api_dict=api_dict ops=ops, backward_ops=backward_ops, op_dict=op_dict
) )
f.write(msg) f.write(msg)
ks_template = env.get_template('sparse_ks.c.j2') ks_template = env.get_template('sparse_ks.c.j2')
with open(output_arg_map_path, 'wt') as f: with open(output_arg_map_path, 'wt') as f:
msg = ks_template.render(apis=apis, backward_apis=backward_apis) msg = ks_template.render(ops=ops, backward_ops=backward_ops)
f.write(msg) f.write(msg)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Generate operator file from api yaml." description="Generate operator file from op yaml."
) )
parser.add_argument( parser.add_argument(
'--ops_yaml_path', type=str, help="parsed sparse ops yaml file." '--ops_yaml_path', type=str, help="parsed sparse ops yaml file."
......
...@@ -16,33 +16,33 @@ import argparse ...@@ -16,33 +16,33 @@ import argparse
import yaml import yaml
from parse_utils import parse_api_entry from parse_utils import parse_op_entry
def main(api_yaml_path, output_path, backward): def main(op_yaml_path, output_path, backward):
with open(api_yaml_path, "rt") as f: with open(op_yaml_path, "rt") as f:
apis = yaml.safe_load(f) ops = yaml.safe_load(f)
if apis is None: if ops is None:
apis = [] ops = []
else: else:
apis = [ ops = [
parse_api_entry(api, "backward_op" if backward else "op") parse_op_entry(op, "backward_op" if backward else "op")
for api in apis for op in ops
] ]
with open(output_path, "wt") as f: with open(output_path, "wt") as f:
yaml.safe_dump(apis, f, default_flow_style=None, sort_keys=False) yaml.safe_dump(ops, f, default_flow_style=None, sort_keys=False)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Parse api yaml into canonical format." description="Parse op yaml into canonical format."
) )
parser.add_argument('--api_yaml_path', type=str, help="api yaml file.") parser.add_argument('--op_yaml_path', type=str, help="op yaml file.")
parser.add_argument( parser.add_argument(
"--output_path", type=str, help="path to save parsed yaml file." "--output_path", type=str, help="path to save parsed yaml file."
) )
parser.add_argument("--backward", action="store_true", default=False) parser.add_argument("--backward", action="store_true", default=False)
args = parser.parse_args() args = parser.parse_args()
main(args.api_yaml_path, args.output_path, args.backward) main(args.op_yaml_path, args.output_path, args.backward)
...@@ -28,7 +28,7 @@ def to_named_dict(items: List[Dict]) -> Dict[str, Dict]: ...@@ -28,7 +28,7 @@ def to_named_dict(items: List[Dict]) -> Dict[str, Dict]:
return named_dict return named_dict
def parse_arg(api_name: str, s: str) -> Dict[str, str]: def parse_arg(op_name: str, s: str) -> Dict[str, str]:
"""parse an argument in following formats: """parse an argument in following formats:
1. typename name 1. typename name
2. typename name = default_value 2. typename name = default_value
...@@ -36,19 +36,19 @@ def parse_arg(api_name: str, s: str) -> Dict[str, str]: ...@@ -36,19 +36,19 @@ def parse_arg(api_name: str, s: str) -> Dict[str, str]:
typename, rest = [item.strip() for item in s.split(" ", 1)] typename, rest = [item.strip() for item in s.split(" ", 1)]
assert ( assert (
len(typename) > 0 len(typename) > 0
), f"The arg typename should not be empty. Please check the args of {api_name} in yaml." ), f"The arg typename should not be empty. Please check the args of {op_name} in yaml."
assert ( assert (
rest.count("=") <= 1 rest.count("=") <= 1
), f"There is more than 1 = in an arg in {api_name}" ), f"There is more than 1 = in an arg in {op_name}"
if rest.count("=") == 1: if rest.count("=") == 1:
name, default_value = [item.strip() for item in rest.split("=", 1)] name, default_value = [item.strip() for item in rest.split("=", 1)]
assert ( assert (
len(name) > 0 len(name) > 0
), f"The arg name should not be empty. Please check the args of {api_name} in yaml." ), f"The arg name should not be empty. Please check the args of {op_name} in yaml."
assert ( assert (
len(default_value) > 0 len(default_value) > 0
), f"The default value should not be empty. Please check the args of {api_name} in yaml." ), f"The default value should not be empty. Please check the args of {op_name} in yaml."
return { return {
"typename": typename, "typename": typename,
"name": name, "name": name,
...@@ -58,17 +58,17 @@ def parse_arg(api_name: str, s: str) -> Dict[str, str]: ...@@ -58,17 +58,17 @@ def parse_arg(api_name: str, s: str) -> Dict[str, str]:
name = rest.strip() name = rest.strip()
assert ( assert (
len(name) > 0 len(name) > 0
), f"The arg name should not be empty. Please check the args of {api_name} in yaml." ), f"The arg name should not be empty. Please check the args of {op_name} in yaml."
return {"typename": typename, "name": name} return {"typename": typename, "name": name}
def parse_input_and_attr( def parse_input_and_attr(
api_name: str, arguments: str op_name: str, arguments: str
) -> Tuple[List, List, Dict, Dict]: ) -> Tuple[List, List, Dict, Dict]:
args_str = arguments.strip() args_str = arguments.strip()
assert args_str.startswith('(') and args_str.endswith(')'), ( assert args_str.startswith('(') and args_str.endswith(')'), (
f"Args declaration should start with '(' and end with ')', " f"Args declaration should start with '(' and end with ')', "
f"please check the args of {api_name} in yaml." f"please check the args of {op_name} in yaml."
) )
args_str = args_str[1:-1] args_str = args_str[1:-1]
args = parse_plain_list(args_str) args = parse_plain_list(args_str)
...@@ -79,13 +79,13 @@ def parse_input_and_attr( ...@@ -79,13 +79,13 @@ def parse_input_and_attr(
met_attr_with_default_value = False met_attr_with_default_value = False
for arg in args: for arg in args:
item = parse_arg(api_name, arg) item = parse_arg(op_name, arg)
typename = item["typename"] typename = item["typename"]
name = item["name"] name = item["name"]
if is_input(typename): if is_input(typename):
assert len(attrs) == 0, ( assert len(attrs) == 0, (
f"The input Tensor should appear before attributes. " f"The input Tensor should appear before attributes. "
f"please check the position of {api_name}:input({name}) " f"please check the position of {op_name}:input({name}) "
f"in yaml." f"in yaml."
) )
inputs.append(item) inputs.append(item)
...@@ -93,16 +93,16 @@ def parse_input_and_attr( ...@@ -93,16 +93,16 @@ def parse_input_and_attr(
if met_attr_with_default_value: if met_attr_with_default_value:
assert ( assert (
"default_value" in item "default_value" in item
), f"{api_name}: Arguments with default value should not precede those without default value" ), f"{op_name}: Arguments with default value should not precede those without default value"
elif "default_value" in item: elif "default_value" in item:
met_attr_with_default_value = True met_attr_with_default_value = True
attrs.append(item) attrs.append(item)
else: else:
raise KeyError(f"{api_name}: Invalid argument type {typename}.") raise KeyError(f"{op_name}: Invalid argument type {typename}.")
return inputs, attrs return inputs, attrs
def parse_output(api_name: str, s: str) -> Dict[str, str]: def parse_output(op_name: str, s: str) -> Dict[str, str]:
"""parse an output, typename or typename(name).""" """parse an output, typename or typename(name)."""
match = re.search( match = re.search(
r"(?P<out_type>[a-zA-Z0-9_[\]]+)\s*(?P<name>\([a-zA-Z0-9_@]+\))?\s*(?P<expr>\{[^\}]+\})?", r"(?P<out_type>[a-zA-Z0-9_[\]]+)\s*(?P<name>\([a-zA-Z0-9_@]+\))?\s*(?P<expr>\{[^\}]+\})?",
...@@ -116,12 +116,12 @@ def parse_output(api_name: str, s: str) -> Dict[str, str]: ...@@ -116,12 +116,12 @@ def parse_output(api_name: str, s: str) -> Dict[str, str]:
size_expr = size_expr[1:-1] if size_expr is not None else None size_expr = size_expr[1:-1] if size_expr is not None else None
assert is_output(typename), ( assert is_output(typename), (
f"Invalid output type: {typename} in api: {api_name}." f"Invalid output type: {typename} in op : {op_name}."
f"Supported types are Tensor and Tensor[]" f"Supported types are Tensor and Tensor[]"
) )
if size_expr is not None: if size_expr is not None:
assert is_vec(typename), ( assert is_vec(typename), (
f"Invalid output size: output {name} in api: {api_name} is " f"Invalid output size: output {name} in op : {op_name} is "
f"not a vector but has size expr" f"not a vector but has size expr"
) )
return {"typename": typename, "name": name, "size": size_expr} return {"typename": typename, "name": name, "size": size_expr}
...@@ -129,11 +129,11 @@ def parse_output(api_name: str, s: str) -> Dict[str, str]: ...@@ -129,11 +129,11 @@ def parse_output(api_name: str, s: str) -> Dict[str, str]:
return {"typename": typename, "name": name} return {"typename": typename, "name": name}
def parse_outputs(api_name: str, outputs: str) -> List[Dict]: def parse_outputs(op_name: str, outputs: str) -> List[Dict]:
outputs = parse_plain_list(outputs, sep=",") outputs = parse_plain_list(outputs, sep=",")
output_items = [] output_items = []
for output in outputs: for output in outputs:
output_items.append(parse_output(api_name, output)) output_items.append(parse_output(op_name, output))
return output_items return output_items
...@@ -157,9 +157,7 @@ def parse_plain_list(s: str, sep=",") -> List[str]: ...@@ -157,9 +157,7 @@ def parse_plain_list(s: str, sep=",") -> List[str]:
return items return items
def parse_kernel( def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]:
api_name: str, kernel_config: Dict[str, Any]
) -> Dict[str, Any]:
# kernel : # kernel :
# func : [], Kernel functions (example: scale, scale_sr) # func : [], Kernel functions (example: scale, scale_sr)
# param : [], Input params of kernel # param : [], Input params of kernel
...@@ -205,14 +203,14 @@ def parse_kernel( ...@@ -205,14 +203,14 @@ def parse_kernel(
'selected_rows', 'selected_rows',
'sparse_coo', 'sparse_coo',
'sparse_csr', 'sparse_csr',
], f"{api_name} : Invalid input tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'." ], f"{op_name} : Invalid input tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
for item in outputs: for item in outputs:
assert item in [ assert item in [
'dense', 'dense',
'selected_rows', 'selected_rows',
'sparse_coo', 'sparse_coo',
'sparse_csr', 'sparse_csr',
], f"{api_name} : Invalid output tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'." ], f"{op_name} : Invalid output tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
return (inputs, outputs) return (inputs, outputs)
...@@ -225,7 +223,7 @@ def parse_kernel( ...@@ -225,7 +223,7 @@ def parse_kernel(
return kernel return kernel
def parse_inplace(api_name: str, inplace_cfg: str) -> Dict[str, str]: def parse_inplace(op_name: str, inplace_cfg: str) -> Dict[str, str]:
inplace_map = {} inplace_map = {}
inplace_cfg = inplace_cfg.lstrip("(").rstrip(")") inplace_cfg = inplace_cfg.lstrip("(").rstrip(")")
pairs = parse_plain_list(inplace_cfg) pairs = parse_plain_list(inplace_cfg)
...@@ -235,7 +233,7 @@ def parse_inplace(api_name: str, inplace_cfg: str) -> Dict[str, str]: ...@@ -235,7 +233,7 @@ def parse_inplace(api_name: str, inplace_cfg: str) -> Dict[str, str]:
return inplace_map return inplace_map
def parse_invoke(api_name: str, invoke_config: str) -> Dict[str, Any]: def parse_invoke(op_name: str, invoke_config: str) -> Dict[str, Any]:
invoke_config = invoke_config.strip() invoke_config = invoke_config.strip()
func, rest = invoke_config.split("(", 1) func, rest = invoke_config.split("(", 1)
func = func.strip() func = func.strip()
...@@ -245,28 +243,28 @@ def parse_invoke(api_name: str, invoke_config: str) -> Dict[str, Any]: ...@@ -245,28 +243,28 @@ def parse_invoke(api_name: str, invoke_config: str) -> Dict[str, Any]:
def extract_type_and_name(records: List[Dict]) -> List[Dict]: def extract_type_and_name(records: List[Dict]) -> List[Dict]:
"""extract type and name from forward call, it is simpler than forward api.""" """extract type and name from forward call, it is simpler than forward op ."""
extracted = [ extracted = [
{"name": item["name"], "typename": item["typename"]} for item in records {"name": item["name"], "typename": item["typename"]} for item in records
] ]
return extracted return extracted
def parse_forward(api_name: str, forward_config: str) -> Dict[str, Any]: def parse_forward(op_name: str, forward_config: str) -> Dict[str, Any]:
# api_name (const Tensor& input, ... , int attr, ...) -> Tensor(out) # op_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)
result = re.search( result = re.search(
r"(?P<op>[a-z][a-z0-9_]+)\s*(?P<args>\([^\)]+\))\s*->\s*(?P<outputs>.+)", r"(?P<op>[a-z][a-z0-9_]+)\s*(?P<args>\([^\)]+\))\s*->\s*(?P<outputs>.+)",
forward_config, forward_config,
) )
api = result.group("op") op = result.group("op")
outputs = parse_outputs(api_name, result.group("outputs")) outputs = parse_outputs(op_name, result.group("outputs"))
outputs = extract_type_and_name(outputs) outputs = extract_type_and_name(outputs)
inputs, attrs = parse_input_and_attr(api_name, result.group("args")) inputs, attrs = parse_input_and_attr(op_name, result.group("args"))
inputs = extract_type_and_name(inputs) inputs = extract_type_and_name(inputs)
attrs = extract_type_and_name(attrs) attrs = extract_type_and_name(attrs)
forward_cfg = { forward_cfg = {
"name": api, "name": op,
"inputs": inputs, "inputs": inputs,
"attrs": attrs, "attrs": attrs,
"outputs": outputs, "outputs": outputs,
...@@ -274,10 +272,10 @@ def parse_forward(api_name: str, forward_config: str) -> Dict[str, Any]: ...@@ -274,10 +272,10 @@ def parse_forward(api_name: str, forward_config: str) -> Dict[str, Any]:
return forward_cfg return forward_cfg
def parse_api_entry(api_entry: Dict[str, Any], name_field="op"): def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
api_name = api_entry[name_field] op_name = op_entry[name_field]
inputs, attrs = parse_input_and_attr(api_name, api_entry["args"]) inputs, attrs = parse_input_and_attr(op_name, op_entry["args"])
outputs = parse_outputs(api_name, api_entry["output"]) outputs = parse_outputs(op_name, op_entry["output"])
# validate default value of DataType and DataLayout # validate default value of DataType and DataLayout
for attr in attrs: for attr in attrs:
...@@ -287,14 +285,14 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"): ...@@ -287,14 +285,14 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
if typename == "DataType": if typename == "DataType":
assert ( assert (
"DataType" in default_value "DataType" in default_value
), f"invalid DataType default value in {api_name}" ), f"invalid DataType default value in {op_name}"
# remove namespace # remove namespace
default_value = default_value[default_value.find("DataType") :] default_value = default_value[default_value.find("DataType") :]
attr["default_value"] = default_value attr["default_value"] = default_value
elif typename == "DataLayout": elif typename == "DataLayout":
assert ( assert (
"DataLayout" in default_value "DataLayout" in default_value
), f"invalid DataLayout default value in {api_name}" ), f"invalid DataLayout default value in {op_name}"
default_value = default_value[ default_value = default_value[
default_value.find("DataLayout") : default_value.find("DataLayout") :
] ]
...@@ -307,12 +305,12 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"): ...@@ -307,12 +305,12 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
# add optional tag for every input # add optional tag for every input
for input in inputs: for input in inputs:
input["optional"] = False input["optional"] = False
if "optional" in api_entry: if "optional" in op_entry:
optional_args = parse_plain_list(api_entry["optional"]) optional_args = parse_plain_list(op_entry["optional"])
for name in optional_args: for name in optional_args:
assert ( assert (
name in input_names name in input_names
), f"{api_name} has an optional input: '{name}' which is not an input." ), f"{op_name} has an optional input: '{name}' which is not an input."
for input in inputs: for input in inputs:
if input["name"] in optional_args: if input["name"] in optional_args:
input["optional"] = True input["optional"] = True
...@@ -320,12 +318,12 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"): ...@@ -320,12 +318,12 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
# add intermediate tag for every output # add intermediate tag for every output
for output in outputs: for output in outputs:
output["intermediate"] = False output["intermediate"] = False
if "intermediate" in api_entry: if "intermediate" in op_entry:
intermediate_outs = parse_plain_list(api_entry["intermediate"]) intermediate_outs = parse_plain_list(op_entry["intermediate"])
for name in intermediate_outs: for name in intermediate_outs:
assert ( assert (
name in output_names name in output_names
), f"{api_name} has an intermediate output: '{name}' which is not an output." ), f"{op_name} has an intermediate output: '{name}' which is not an output."
for output in outputs: for output in outputs:
if output["name"] in intermediate_outs: if output["name"] in intermediate_outs:
output["intermediate"] = True output["intermediate"] = True
...@@ -333,12 +331,12 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"): ...@@ -333,12 +331,12 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
# add no_need_buffer for every input # add no_need_buffer for every input
for input in inputs: for input in inputs:
input["no_need_buffer"] = False input["no_need_buffer"] = False
if "no_need_buffer" in api_entry: if "no_need_buffer" in op_entry:
no_buffer_args = parse_plain_list(api_entry["no_need_buffer"]) no_buffer_args = parse_plain_list(op_entry["no_need_buffer"])
for name in no_buffer_args: for name in no_buffer_args:
assert ( assert (
name in input_names name in input_names
), f"{api_name} has an no buffer input: '{name}' which is not an input." ), f"{op_name} has an no buffer input: '{name}' which is not an input."
for input in inputs: for input in inputs:
if input["name"] in no_buffer_args: if input["name"] in no_buffer_args:
input["no_need_buffer"] = True input["no_need_buffer"] = True
...@@ -347,34 +345,34 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"): ...@@ -347,34 +345,34 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
# TODO(chenfeiyu): data_transform # TODO(chenfeiyu): data_transform
api = { op = {
"name": api_name, "name": op_name,
"inputs": inputs, "inputs": inputs,
"attrs": attrs, "attrs": attrs,
"outputs": outputs, "outputs": outputs,
"no_need_buffer": no_buffer_args, "no_need_buffer": no_buffer_args,
} }
# invokes another api? # invokes another op ?
is_base_api = "invoke" not in api_entry is_base_op = "invoke" not in op_entry
if is_base_api: if is_base_op:
# kernel # kernel
kernel = parse_kernel(api_name, api_entry["kernel"]) kernel = parse_kernel(op_name, op_entry["kernel"])
if kernel["param"] is None: if kernel["param"] is None:
kernel["param"] = input_names + attr_names kernel["param"] = input_names + attr_names
# infer meta # infer meta
infer_meta = parse_infer_meta(api_entry["infer_meta"]) infer_meta = parse_infer_meta(op_entry["infer_meta"])
if infer_meta["param"] is None: if infer_meta["param"] is None:
infer_meta["param"] = copy(kernel["param"]) infer_meta["param"] = copy(kernel["param"])
# inplace # inplace
if "inplace" in api_entry: if "inplace" in op_entry:
inplace_pairs = parse_inplace(api_name, api_entry["inplace"]) inplace_pairs = parse_inplace(op_name, op_entry["inplace"])
else: else:
inplace_pairs = None inplace_pairs = None
api.update( op.update(
{ {
"infer_meta": infer_meta, "infer_meta": infer_meta,
"kernel": kernel, "kernel": kernel,
...@@ -383,47 +381,47 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"): ...@@ -383,47 +381,47 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
) )
else: else:
# invoke # invoke
invoke = parse_invoke(api_name, api_entry["invoke"]) invoke = parse_invoke(op_name, op_entry["invoke"])
api["invoke"] = invoke op["invoke"] = invoke
# backward # backward
if "backward" in api_entry: if "backward" in op_entry:
backward = api_entry["backward"] backward = op_entry["backward"]
else: else:
backward = None backward = None
api["backward"] = backward op["backward"] = backward
# forward for backward_apis # forward for backward_ops
is_backward_api = name_field == "backward_op" is_backward_op = name_field == "backward_op"
if is_backward_api: if is_backward_op:
if "forward" in api_entry: if "forward" in op_entry:
forward = parse_forward(api_name, api_entry["forward"]) forward = parse_forward(op_name, op_entry["forward"])
# validate_fb # validate_fb
validate_backward_inputs( validate_backward_inputs(
api_name, forward["inputs"], forward["outputs"], inputs op_name, forward["inputs"], forward["outputs"], inputs
) )
validate_backward_attrs(api_name, forward["attrs"], attrs) validate_backward_attrs(op_name, forward["attrs"], attrs)
validate_backward_outputs(api_name, forward["inputs"], outputs) validate_backward_outputs(op_name, forward["inputs"], outputs)
else: else:
forward = None forward = None
api["forward"] = forward op["forward"] = forward
return api return op
def validate_backward_attrs(api, forward_attrs, backward_attrs): def validate_backward_attrs(op, forward_attrs, backward_attrs):
if len(forward_attrs) >= len(backward_attrs): if len(forward_attrs) >= len(backward_attrs):
return return
num_exceptional_attrs = len(backward_attrs) - len(forward_attrs) num_exceptional_attrs = len(backward_attrs) - len(forward_attrs)
# this is a not-that-clean trick to allow backward api to has more attrs # this is a not-that-clean trick to allow backward op to has more attrs
# than the forward api, as long as they all have default value # than the forward op , as long as they all have default value
for i in range(-num_exceptional_attrs, 0): for i in range(-num_exceptional_attrs, 0):
assert ( assert (
"default_value" in backward_attrs[i] "default_value" in backward_attrs[i]
), f"{api} has exceptional attr without default value" ), f"{op } has exceptional attr without default value"
def validate_backward_inputs( def validate_backward_inputs(
api, forward_inputs, forward_outputs, backward_inputs op, forward_inputs, forward_outputs, backward_inputs
): ):
foward_input_names = [item["name"] for item in forward_inputs] foward_input_names = [item["name"] for item in forward_inputs]
forward_output_names = [item["name"] for item in forward_outputs] forward_output_names = [item["name"] for item in forward_outputs]
...@@ -431,47 +429,47 @@ def validate_backward_inputs( ...@@ -431,47 +429,47 @@ def validate_backward_inputs(
assert len(backward_input_names) <= len(foward_input_names) + 2 * len( assert len(backward_input_names) <= len(foward_input_names) + 2 * len(
forward_output_names forward_output_names
), f"{api} has too many inputs." ), f"{op } has too many inputs."
def validate_backward_outputs(api, forward_inputs, backward_outputs): def validate_backward_outputs(op, forward_inputs, backward_outputs):
assert len(backward_outputs) <= len( assert len(backward_outputs) <= len(
forward_inputs forward_inputs
), f"{api} has too many outputs" ), f"{op } has too many outputs"
def cross_validate(apis): def cross_validate(ops):
for name, api in apis.items(): for name, op in ops.items():
if "forward" in api: if "forward" in op:
fw_call = api["forward"] fw_call = op["forward"]
fw_name = fw_call["name"] fw_name = fw_call["name"]
if fw_name not in apis: if fw_name not in ops:
print( print(
f"Something Wrong here, this backward api({name})'s forward api({fw_name}) does not exist." f"Something Wrong here, this backward op ({name})'s forward op ({fw_name}) does not exist."
) )
else: else:
fw_api = apis[fw_name] fw_op = ops[fw_name]
if "backward" not in fw_api or fw_api["backward"] is None: if "backward" not in fw_op or fw_op["backward"] is None:
print( print(
f"Something Wrong here, {name}'s forward api({fw_name}) does not claim {name} as its backward." f"Something Wrong here, {name}'s forward op ({fw_name}) does not claim {name} as its backward."
) )
else: else:
assert ( assert (
fw_api["backward"] == name fw_op["backward"] == name
), f"{name}: backward and forward name mismatch" ), f"{name}: backward and forward name mismatch"
assert len(fw_call["inputs"]) <= len( assert len(fw_call["inputs"]) <= len(
fw_api["inputs"] fw_op["inputs"]
), f"{name}: forward call has more inputs than the api" ), f"{name}: forward call has more inputs than the op "
for (input, input_) in zip(fw_call["inputs"], fw_api["inputs"]): for (input, input_) in zip(fw_call["inputs"], fw_op["inputs"]):
assert ( assert (
input["typename"] == input_["typename"] input["typename"] == input_["typename"]
), f"type mismatch in {name} and {fw_name}" ), f"type mismatch in {name} and {fw_name}"
assert len(fw_call["attrs"]) <= len( assert len(fw_call["attrs"]) <= len(
fw_api["attrs"] fw_op["attrs"]
), f"{name}: forward call has more attrs than the api" ), f"{name}: forward call has more attrs than the op "
for (attr, attr_) in zip(fw_call["attrs"], fw_api["attrs"]): for (attr, attr_) in zip(fw_call["attrs"], fw_op["attrs"]):
if attr["typename"] == "Scalar": if attr["typename"] == "Scalar":
# special case for Scalar, fw_call can omit the type # special case for Scalar, fw_call can omit the type
assert re.match( assert re.match(
...@@ -483,10 +481,10 @@ def cross_validate(apis): ...@@ -483,10 +481,10 @@ def cross_validate(apis):
), f"type mismatch in {name} and {fw_name}" ), f"type mismatch in {name} and {fw_name}"
assert len(fw_call["outputs"]) == len( assert len(fw_call["outputs"]) == len(
fw_api["outputs"] fw_op["outputs"]
), f"{name}: forward call has more outputs than the api" ), f"{name}: forward call has more outputs than the op "
for (output, output_) in zip( for (output, output_) in zip(
fw_call["outputs"], fw_api["outputs"] fw_call["outputs"], fw_op["outputs"]
): ):
assert ( assert (
output["typename"] == output_["typename"] output["typename"] == output_["typename"]
......
{% from "operator_utils.c.j2" import name_map, register_name_map, register_base_kernel_name %} {% from "operator_utils.c.j2" import name_map, register_name_map, register_base_kernel_name %}
// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit. // this file is generated by paddle/phi/op/yaml/generator/generate_op.py, do not edit.
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/small_vector.h" #include "paddle/utils/small_vector.h"
namespace phi { namespace phi {
{% for api in apis %} {% for op in ops %}
{% if api is base_api %} {% if op is base_op %}
{{name_map(api)}} {{name_map(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
{% for api in backward_apis %} {% for op in backward_ops %}
{% if api is base_api %} {% if op is base_op %}
{{name_map(api)}} {{name_map(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
} // namespace phi } // namespace phi
{% for api in apis + backward_apis %} {% for op in ops + backward_ops %}
{% if api["name"] != api["op_name"] %} {% if op["name"] != op["op_name"] %}
{{register_base_kernel_name(api)}} {{register_base_kernel_name(op)}}
{% endif %} {% endif %}
{% if api is base_api %} {% if op is base_op %}
{{register_name_map(api)}} {{register_name_map(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
...@@ -18,32 +18,32 @@ namespace operators { ...@@ -18,32 +18,32 @@ namespace operators {
using paddle::framework::GradVarName; using paddle::framework::GradVarName;
{% for api in apis %} {% for op in ops %}
{% if api is base_api %} {% if op is base_op %}
{{op_maker(api)}} {{op_maker(op)}}
{{operator(api)}} {{operator(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
{% for api in backward_apis %} {% for op in backward_ops %}
{% if api is base_api %} {% if op is base_op %}
{{backward_op_maker(api, api_dict[api["forward"]["name"]])}} {{backward_op_maker(op, op_dict[op["forward"]["name"]])}}
{{operator(api)}} {{operator(op)}}
{% else %} {% else %}
{{backward_op_reused_maker(api, api_dict[api["forward"]["name"]], api["invoke"])}} {{backward_op_reused_maker(op, op_dict[op["forward"]["name"]], op["invoke"])}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
{% for api in apis + backward_apis %} {% for op in ops + backward_ops %}
{% if api is base_api %} {% if op is base_op %}
{{register_op_with_components(api)}} {{register_op_with_components(op)}}
{{register_op_version(api)}} {{register_op_version(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
{# ----------------------------- op maker ----------------------------------- #} {# ----------------------------- op maker ----------------------------------- #}
{% macro op_maker(api) %} {% macro op_maker(op) %}
{% set api_name = api["op_name"] %} {% set op_name = op["op_name"] %}
class {{api_name | to_pascal_case}}OpMaker : public framework::OpProtoAndCheckerMaker { class {{op_name | to_pascal_case}}OpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
{% filter indent(4, True) %} {% filter indent(4, True) %}
{% for input in api["inputs"] %} {% for input in op["inputs"] %}
{{add_input(loop.index0, input, api_name)}}; {{add_input(loop.index0, input, op_name)}};
{% endfor %} {% endfor %}
{% for output in api["outputs"] %} {% for output in op["outputs"] %}
{{add_output(loop.index0, output, api_name)}}; {{add_output(loop.index0, output, op_name)}};
{% endfor %} {% endfor %}
{% for attr in api["attrs"] %} {% for attr in op["attrs"] %}
{% if attr["name"] in api["kernel"]["param"] %} {% if attr["name"] in op["kernel"]["param"] %}
{{add_attr(loop.index0, attr, api_name)}}; {{add_attr(loop.index0, attr, op_name)}};
{% endif %} {% endif %}
{% endfor %} {% endfor %}
{% endfilter %} {% endfilter %}
AddComment(R"DOC( AddComment(R"DOC(
TODO: Documentation of {{api_name}} op. TODO: Documentation of {{op_name}} op.
)DOC"); )DOC");
} }
}; };
...@@ -76,7 +76,7 @@ AddAttr<{{typename | to_op_attr_type}}>("{{name}}", "({{typename | to_op_attr_ty ...@@ -76,7 +76,7 @@ AddAttr<{{typename | to_op_attr_type}}>("{{name}}", "({{typename | to_op_attr_ty
{%- endif %} {%- endif %}
{%- endmacro %} {%- endmacro %}
{# process default value for attributes, some attribute has different types and different default values in api & opmaker #} {# process default value for attributes, some attribute has different types and different default values in op & opmaker #}
{% macro process_default_value(attr) %}{# inline #} {% macro process_default_value(attr) %}{# inline #}
{% set default_value = attr["default_value"] %} {% set default_value = attr["default_value"] %}
{% set typename = attr["typename"] %} {% set typename = attr["typename"] %}
...@@ -97,22 +97,22 @@ static_cast<int>(phi::Place({{"phi::" if not default_value is initializer_list}} ...@@ -97,22 +97,22 @@ static_cast<int>(phi::Place({{"phi::" if not default_value is initializer_list}}
{# --------------------------------------- name mapping ---------------------------------------------- #} {# --------------------------------------- name mapping ---------------------------------------------- #}
{% macro name_map(api) %} {% macro name_map(op) %}
KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature {{op["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
{% set kernel_args = api["kernel"]["param"] %} {% set kernel_args = op["kernel"]["param"] %}
{{get_input_list(api["inputs"], kernel_args)}}; {{get_input_list(op["inputs"], kernel_args)}};
paddle::small_vector<const char*> attrs; paddle::small_vector<const char*> attrs;
{% for attr in api["attrs"]%} {% for attr in op["attrs"]%}
{% filter indent(2)%} {% filter indent(2)%}
{{get_an_attr(attr)}}; {{get_an_attr(attr)}};
{% endfilter %} {% endfilter %}
{% endfor %} {% endfor %}
{{get_output_list(api["outputs"], kernel_args)}}; {{get_output_list(op["outputs"], kernel_args)}};
{% if api["kernel"]["func"] | length == 1 %} {% if op["kernel"]["func"] | length == 1 %}
KernelSignature sig("{{api["kernel"]["func"][0]}}", std::move(inputs), std::move(attrs), std::move(outputs)); KernelSignature sig("{{op["kernel"]["func"][0]}}", std::move(inputs), std::move(attrs), std::move(outputs));
return sig; return sig;
{% else %}{# it has kernel for selected rows #} {% else %}{# it has kernel for selected rows #}
const char* kernel_name = ctx.IsSelectedRowsInput({{kernel_args[0] | to_opmaker_name_cstr}}) ? "{{api["kernel"]["func"][1]}}" : "{{api["kernel"]["func"][0]}}"; const char* kernel_name = ctx.IsSelectedRowsInput({{kernel_args[0] | to_opmaker_name_cstr}}) ? "{{op["kernel"]["func"][1]}}" : "{{op["kernel"]["func"][0]}}";
KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs)); KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs));
return sig; return sig;
{%endif%} {%endif%}
...@@ -121,9 +121,9 @@ KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const Argu ...@@ -121,9 +121,9 @@ KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const Argu
/* /*
****************************************************************** ******************************************************************
NOTE: The following codes are for 'get_compat_kernel_signature.py' NOTE: The following codes are for 'get_compat_kernel_signature.py'
All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArgumentMapping: All possible KernelSignatures returned by {{op["name"] | to_pascal_case }}OpArgumentMapping:
{{api | cartesian_prod_mapping}} {{op | cartesian_prod_mapping}}
****************************************************************** ******************************************************************
*/ */
{% endmacro %} {% endmacro %}
...@@ -152,20 +152,20 @@ ctx.IsSparseCsrTensorInput("{{input["name"]}}"){{" && " if not loop.last}} ...@@ -152,20 +152,20 @@ ctx.IsSparseCsrTensorInput("{{input["name"]}}"){{" && " if not loop.last}}
{%- endfor %} {%- endfor %}
{%- endmacro %} {%- endmacro %}
{% macro sparse_op_name_map(api) %} {% macro sparse_op_name_map(op) %}
KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature {{op["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
{% set kernel_args = api["kernel"]["param"] %} {% set kernel_args = op["kernel"]["param"] %}
{{get_input_list(api["inputs"], kernel_args)}}; {{get_input_list(op["inputs"], kernel_args)}};
paddle::small_vector<const char*> attrs; paddle::small_vector<const char*> attrs;
{% for attr in api["attrs"]%} {% for attr in op["attrs"]%}
{% filter indent(2)%} {% filter indent(2)%}
{{get_an_attr(attr)}}; {{get_an_attr(attr)}};
{% endfilter %} {% endfilter %}
{% endfor %} {% endfor %}
{{get_output_list(api["outputs"], kernel_args)}}; {{get_output_list(op["outputs"], kernel_args)}};
const char* kernel_name = "unregistered"; const char* kernel_name = "unregistered";
{{get_kernel_dispatch(api["inputs"], api["kernel"])}} {{get_kernel_dispatch(op["inputs"], op["kernel"])}}
KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs)); KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs));
return sig; return sig;
} }
...@@ -173,19 +173,19 @@ KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const Argu ...@@ -173,19 +173,19 @@ KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const Argu
/* /*
****************************************************************** ******************************************************************
NOTE: The following codes are for 'get_compat_kernel_signature.py' NOTE: The following codes are for 'get_compat_kernel_signature.py'
All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArgumentMapping: All possible KernelSignatures returned by {{op["name"] | to_pascal_case }}OpArgumentMapping:
{{api | cartesian_prod_mapping}} {{op | cartesian_prod_mapping}}
****************************************************************** ******************************************************************
*/ */
{% endmacro %} {% endmacro %}
{% macro register_base_kernel_name(api) %} {% macro register_base_kernel_name(op) %}
PD_REGISTER_BASE_KERNEL_NAME({{api["op_name"]}}, {{api["name"]}}); PD_REGISTER_BASE_KERNEL_NAME({{op["op_name"]}}, {{op["name"]}});
{%- endmacro %} {%- endmacro %}
{% macro register_name_map(api) %} {% macro register_name_map(op) %}
PD_REGISTER_ARG_MAPPING_FN({{api["op_name"]}}, phi::{{api["op_name"] | to_pascal_case}}OpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN({{op["op_name"]}}, phi::{{op["op_name"] | to_pascal_case}}OpArgumentMapping);
{%- endmacro %} {%- endmacro %}
{% macro get_input_list(inputs, kernel_args) %}{# inline #} {% macro get_input_list(inputs, kernel_args) %}{# inline #}
...@@ -228,14 +228,14 @@ paddle::small_vector<const char*> outputs { ...@@ -228,14 +228,14 @@ paddle::small_vector<const char*> outputs {
} }
{%- endmacro %} {%- endmacro %}
{% macro get_expected_kernel(api) %} {% macro get_expected_kernel(op) %}
{% set kernel = api["kernel"] %} {% set kernel = op["kernel"] %}
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
{%if kernel["data_type"] is not none %}{# data type ---------------------------------#} {%if kernel["data_type"] is not none %}{# data type ---------------------------------#}
{% if kernel["data_type"]["candidates"] | length == 1 %} {% if kernel["data_type"]["candidates"] | length == 1 %}
{% set data_type_arg = kernel["data_type"]["candidates"][0] %} {% set data_type_arg = kernel["data_type"]["candidates"][0] %}
{% set inputs = api["inputs"] | map(attribute="name") | list %} {% set inputs = op["inputs"] | map(attribute="name") | list %}
{% if data_type_arg in inputs %} {% if data_type_arg in inputs %}
auto data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_arg | to_opmaker_name}}); auto data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_arg | to_opmaker_name}});
{% else %}{# it is an attribute and probably named dtype#} {% else %}{# it is an attribute and probably named dtype#}
...@@ -254,68 +254,68 @@ framework::OpKernelType GetExpectedKernelType( ...@@ -254,68 +254,68 @@ framework::OpKernelType GetExpectedKernelType(
{% endmacro %} {% endmacro %}
{# --------------------------------------- operator ---------------------------------------------- #} {# --------------------------------------- operator ---------------------------------------------- #}
{% macro operator(api) %} {% macro operator(op) %}
class {{api["op_name"] | to_pascal_case}}Op : public framework::OperatorWithKernel { class {{op["op_name"] | to_pascal_case}}Op : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
{# ----------- get expected kernel type function -------------------------- #} {# ----------- get expected kernel type function -------------------------- #}
{% set kernel = api["kernel"] %} {% set kernel = op["kernel"] %}
{% if kernel["data_type"] is not none %} {% if kernel["data_type"] is not none %}
protected: protected:
{% filter indent(2, True)%} {% filter indent(2, True)%}
{{get_expected_kernel(api)}} {{get_expected_kernel(op)}}
{% endfilter %} {% endfilter %}
{% endif %} {% endif %}
}; };
DECLARE_INFER_SHAPE_FUNCTOR({{api["op_name"]}}, {{api["op_name"] | to_pascal_case}}InferShapeFunctor, DECLARE_INFER_SHAPE_FUNCTOR({{op["op_name"]}}, {{op["op_name"] | to_pascal_case}}InferShapeFunctor,
PD_INFER_META(phi::{{api["infer_meta"]["func"]}})); PD_INFER_META(phi::{{op["infer_meta"]["func"]}}));
{# inplace inferer #} {# inplace inferer #}
{% if api["inplace"] is not none %} {% if op["inplace"] is not none %}
{% set inplace_map %} {% set inplace_map %}
{% for source, target in api["inplace"].items() %} {% for source, target in op["inplace"].items() %}
{{"{"}}{{target | to_opmaker_name}}, {{source | to_opmaker_name}}{{"}"}}{{", " if not loop.last}} {{"{"}}{{target | to_opmaker_name}}, {{source | to_opmaker_name}}{{"}"}}{{", " if not loop.last}}
{%- endfor %} {%- endfor %}
{%- endset %} {%- endset %}
DECLARE_INPLACE_OP_INFERER({{api["op_name"] | to_pascal_case}}InplaceInferer, DECLARE_INPLACE_OP_INFERER({{op["op_name"] | to_pascal_case}}InplaceInferer,
{{inplace_map}}); {{inplace_map}});
{% endif %} {% endif %}
{# no_need_buffer inferer #} {# no_need_buffer inferer #}
{% if api["no_need_buffer"] is not none %} {% if op["no_need_buffer"] is not none %}
DECLARE_NO_NEED_BUFFER_VARS_INFERER({{api["op_name"] | to_pascal_case}}NoNeedBufferVarInferer, DECLARE_NO_NEED_BUFFER_VARS_INFERER({{op["op_name"] | to_pascal_case}}NoNeedBufferVarInferer,
{{api["no_need_buffer"] | map("to_opmaker_name") | join(", ")}}); {{op["no_need_buffer"] | map("to_opmaker_name") | join(", ")}});
{% endif %} {% endif %}
{% endmacro%} {% endmacro%}
{% macro register_op_with_components(api) %} {% macro register_op_with_components(op) %}
{% set name = api["op_name"] %} {% set name = op["op_name"] %}
REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op, REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
{% if not "forward" in api %}{# it is a forward api #} {% if not "forward" in op %}{# it is a forward op #}
ops::{{name | to_pascal_case}}OpMaker, ops::{{name | to_pascal_case}}OpMaker,
{% endif %} {% endif %}
{% if "backward" in api and api["backward"] is not none %}{# backward #} {% if "backward" in op and op["backward"] is not none %}{# backward #}
{% set backward_name = api["backward"] %} {% set backward_name = op["backward"] %}
ops::{{backward_name | to_pascal_case}}OpMaker<paddle::framework::OpDesc>, ops::{{backward_name | to_pascal_case}}OpMaker<paddle::framework::OpDesc>,
ops::{{backward_name | to_pascal_case}}OpMaker<paddle::imperative::OpBase>, ops::{{backward_name | to_pascal_case}}OpMaker<paddle::imperative::OpBase>,
{% else %} {% else %}
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
{% endif %} {% endif %}
{% if api is supports_inplace %}{# inplace#} {% if op is supports_inplace %}{# inplace#}
ops::{{name | to_pascal_case}}InplaceInferer, ops::{{name | to_pascal_case}}InplaceInferer,
{% endif %} {% endif %}
{% if api is supports_no_need_buffer %}{# no_need_buffer #} {% if op is supports_no_need_buffer %}{# no_need_buffer #}
ops::{{name | to_pascal_case}}NoNeedBufferVarInferer, ops::{{name | to_pascal_case}}NoNeedBufferVarInferer,
{% endif %} {% endif %}
ops::{{name | to_pascal_case}}InferShapeFunctor); ops::{{name | to_pascal_case}}InferShapeFunctor);
{% endmacro %} {% endmacro %}
{% macro register_op_version(api) %} {% macro register_op_version(op) %}
{% if "version" in api %} {% if "version" in op %}
{% set name = api["op_name"] %} {% set name = op["op_name"] %}
REGISTER_OP_VERSION({{name}}) REGISTER_OP_VERSION({{name}})
{% for checkpoint in api["version"]%} {% for checkpoint in op["version"]%}
.AddCheckpoint( .AddCheckpoint(
R"ROC({{checkpoint["checkpoint"]}})ROC", R"ROC({{checkpoint["checkpoint"]}})ROC",
paddle::framework::compatible::OpVersionDesc() paddle::framework::compatible::OpVersionDesc()
...@@ -354,14 +354,14 @@ REGISTER_OP_VERSION({{name}}) ...@@ -354,14 +354,14 @@ REGISTER_OP_VERSION({{name}})
{# --------------------------------------- backward op maker ---------------------------------------------- #} {# --------------------------------------- backward op maker ---------------------------------------------- #}
{% macro backward_op_maker(api, forward_api) %} {% macro backward_op_maker(op, forward_op ) %}
{% set name = api["op_name"] %} {% set name = op["op_name"] %}
{% set forward_input_names = api["forward"]["inputs"] | map(attribute="name") | list %} {% set forward_input_names = op["forward"]["inputs"] | map(attribute="name") | list %}
{% set forward_output_names = api["forward"]["outputs"] | map(attribute="name") | list %} {% set forward_output_names = op["forward"]["outputs"] | map(attribute="name") | list %}
{% set forward_attr_names = api["forward"]["attrs"] | map(attribute="name") | list %} {% set forward_attr_names = op["forward"]["attrs"] | map(attribute="name") | list %}
{% set forward_input_orig_names = forward_api["inputs"] | map(attribute="name") | list %} {% set forward_input_orig_names = forward_op["inputs"] | map(attribute="name") | list %}
{% set forward_output_orig_names = forward_api["outputs"] | map(attribute="name") | list %} {% set forward_output_orig_names = forward_op["outputs"] | map(attribute="name") | list %}
{% set forward_attr_orig_names = forward_api["attrs"] | map(attribute="name") | list %} {% set forward_attr_orig_names = forward_op["attrs"] | map(attribute="name") | list %}
template <typename T> template <typename T>
class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> { class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> {
public: public:
...@@ -371,7 +371,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -371,7 +371,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
void Apply(GradOpPtr<T> grad_op) const override { void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("{{name}}"); grad_op->SetType("{{name}}");
{% for input in api["inputs"] %} {% for input in op["inputs"] %}
grad_op->SetInput({{input["name"] | to_opmaker_name}}, this->{{extract_input_from_forward( grad_op->SetInput({{input["name"] | to_opmaker_name}}, this->{{extract_input_from_forward(
input["name"], input["name"],
forward_input_names, forward_input_names,
...@@ -380,7 +380,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -380,7 +380,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
forward_output_orig_names)}}); forward_output_orig_names)}});
{% endfor %} {% endfor %}
{% for output in api["outputs"] %} {% for output in op["outputs"] %}
grad_op->SetOutput({{output["name"] | to_opmaker_name}}, this->{{extract_output_from_forward( grad_op->SetOutput({{output["name"] | to_opmaker_name}}, this->{{extract_output_from_forward(
output["name"], output["name"],
forward_input_names, forward_input_names,
...@@ -390,7 +390,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -390,7 +390,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
{% endfor %} {% endfor %}
grad_op->SetAttrMap(this->Attrs()); grad_op->SetAttrMap(this->Attrs());
{% for attr in api["attrs"] %} {% for attr in op["attrs"] %}
{% set attr_name = attr["name"] %} {% set attr_name = attr["name"] %}
{% if attr_name in forward_attr_names %} {% if attr_name in forward_attr_names %}
{% if attr["typename"] == "IntArray" %} {% if attr["typename"] == "IntArray" %}
......
...@@ -5,20 +5,20 @@ ...@@ -5,20 +5,20 @@
namespace phi { namespace phi {
{% for api in apis %} {% for op in ops %}
{% if api is base_api %} {% if op is base_op %}
{{sparse_op_name_map(api)}} {{sparse_op_name_map(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
{% for api in backward_apis %} {% for op in backward_ops %}
{% if api is base_api %} {% if op is base_op %}
{{sparse_op_name_map(api)}} {{sparse_op_name_map(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
} // namespace phi } // namespace phi
{% for api in apis + backward_apis %} {% for op in ops + backward_ops %}
{% if api is base_api %} {% if op is base_op %}
{{register_name_map(api)}} {{register_name_map(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
...@@ -19,31 +19,31 @@ namespace operators { ...@@ -19,31 +19,31 @@ namespace operators {
using paddle::framework::GradVarName; using paddle::framework::GradVarName;
{% for api in apis %} {% for op in ops %}
{% if api is base_api %} {% if op is base_op %}
{{op_maker(api)}} {{op_maker(op)}}
{{operator(api)}} {{operator(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
{% for api in backward_apis %} {% for op in backward_ops %}
{% if api is base_api %} {% if op is base_op %}
{{backward_op_maker(api, api_dict[api["forward"]["name"]])}} {{backward_op_maker(op, op_dict[op["forward"]["name"]])}}
{{operator(api)}} {{operator(op)}}
{% else %} {% else %}
{{backward_op_reused_maker(api, api_dict[api["forward"]["name"]], api["invoke"])}} {{backward_op_reused_maker(op, op_dict[op["forward"]["name"]], op["invoke"])}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
{% for api in apis + backward_apis %} {% for op in ops + backward_ops %}
{% if api is base_api %} {% if op is base_op %}
{{register_op_with_components(api)}} {{register_op_with_components(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
...@@ -41,20 +41,20 @@ def is_initializer_list(s): ...@@ -41,20 +41,20 @@ def is_initializer_list(s):
return s == "{}" return s == "{}"
def is_base_api(api): def is_base_op(op):
return "kernel" in api and "infer_meta" in api return "kernel" in op and "infer_meta" in op
def supports_selected_rows_kernel(api): def supports_selected_rows_kernel(op):
return is_base_api(api) and len(api["kernel"]["func"]) == 2 return is_base_op(op) and len(op["kernel"]["func"]) == 2
def supports_inplace(api): def supports_inplace(op):
return api['inplace'] is not None return op['inplace'] is not None
def supports_no_need_buffer(api): def supports_no_need_buffer(op):
for input in api["inputs"]: for input in op["inputs"]:
if input["no_need_buffer"]: if input["no_need_buffer"]:
return True return True
return False return False
...@@ -15,8 +15,6 @@ add_subdirectory(backends) ...@@ -15,8 +15,6 @@ add_subdirectory(backends)
add_subdirectory(kernels) add_subdirectory(kernels)
# phi infermeta # phi infermeta
add_subdirectory(infermeta) add_subdirectory(infermeta)
# phi operator definitions
add_subdirectory(ops)
# phi tools # phi tools
add_subdirectory(tools) add_subdirectory(tools)
# phi tests # phi tests
...@@ -36,7 +34,6 @@ set(PHI_DEPS ...@@ -36,7 +34,6 @@ set(PHI_DEPS
arg_map_context arg_map_context
infermeta infermeta
lod_utils lod_utils
op_compat_infos
sparse_csr_tensor sparse_csr_tensor
sparse_coo_tensor sparse_coo_tensor
string_tensor string_tensor
......
...@@ -94,204 +94,10 @@ set(wrapped_infermeta_header_file ...@@ -94,204 +94,10 @@ set(wrapped_infermeta_header_file
set(wrapped_infermeta_source_file set(wrapped_infermeta_source_file
${CMAKE_SOURCE_DIR}/paddle/phi/infermeta/generated.cc) ${CMAKE_SOURCE_DIR}/paddle/phi/infermeta/generated.cc)
# op extra info file
set(ops_extra_info_gen_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/ops_extra_info_gen.py)
set(op_compat_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml)
set(ops_extra_info_file
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/ops_extra_info.cc)
if(NOT PYTHONINTERP_FOUND) if(NOT PYTHONINTERP_FOUND)
find_package(PythonInterp REQUIRED) find_package(PythonInterp REQUIRED)
endif() endif()
# install extra dependencies
if(${PYTHON_VERSION_STRING} VERSION_LESS "3.6.2")
execute_process(COMMAND ${PYTHON_EXECUTABLE} -m pip install -U pyyaml
typing-extensions>=4.1.1 jinja2==2.11.3)
else()
execute_process(COMMAND ${PYTHON_EXECUTABLE} -m pip install -U pyyaml jinja2
typing-extensions)
endif()
# parse apis
set(parsed_api_dir ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/parsed_apis)
set(generated_op_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op.cc)
set(generated_sparse_ops_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_sparse_op.cc)
set(generated_argument_mapping_path
${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sig.cc)
set(generated_sparse_argument_mapping_path
${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sparse_sig.cc)
message(
"parse api yamls:
- ${api_yaml_file}
- ${legacy_api_yaml_file}
- ${bw_api_yaml_file}
- ${legacy_bw_api_yaml_file}")
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml
COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_api_dir}
COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path ./ops.yaml
--output_path ./parsed_apis/ops.parsed.yaml
COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./legacy_ops.yaml --output_path ./parsed_apis/legacy_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path ./backward.yaml
--output_path ./parsed_apis/backward_ops.parsed.yaml --backward
COMMAND
${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./legacy_backward.yaml --output_path
./parsed_apis/legacy_backward_ops.parsed.yaml --backward
COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./sparse_ops.yaml --output_path ./parsed_apis/sparse_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./sparse_backward.yaml --output_path
./parsed_apis/sparse_backward.parsed.yaml --backward RESULTS_VARIABLE
_results)
foreach(_result in ${_results})
if(${_result})
message(FATAL_ERROR "api yaml parsing failed, exiting.")
endif()
endforeach()
# validation of api yamls
message("validate api yaml:
- ${parsed_api_dir}/ops.parsed.yaml
- ${parsed_api_dir}/backward_ops.parsed.yaml")
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml
COMMAND
${PYTHON_EXECUTABLE} generator/cross_validate.py --forward_yaml_paths
./parsed_apis/ops.parsed.yaml ./parsed_apis/legacy_ops.parsed.yaml
--backward_yaml_paths ./parsed_apis/backward_ops.parsed.yaml
./parsed_apis/legacy_backward_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} generator/cross_validate.py --forward_yaml_paths
./parsed_apis/sparse_ops.parsed.yaml --backward_yaml_paths
./parsed_apis/sparse_backward.parsed.yaml
RESULT_VARIABLE _results)
foreach(_result in ${_results})
if(${_result})
message(FATAL_ERROR "ops validation failed, exiting.")
endif()
endforeach()
# code generation for op, op makers, and argument mapping functions
message(
"create or remove auto-geneated operators: ${generated_op_path}.tmp
create or remove auto-geneated argument mappings: ${generated_argument_mapping_path}.tmp"
)
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml
COMMAND
${PYTHON_EXECUTABLE} generator/generate_op.py --ops_yaml_path
./parsed_apis/ops.parsed.yaml --backward_yaml_path
./parsed_apis/backward_ops.parsed.yaml --op_version_yaml_path
op_version.yaml --op_compat_yaml_path op_compat.yaml --output_op_path
"${generated_op_path}.tmp" --output_arg_map_path
"${generated_argument_mapping_path}.tmp"
COMMAND
${PYTHON_EXECUTABLE} generator/generate_sparse_op.py --ops_yaml_path
./parsed_apis/sparse_ops.parsed.yaml --backward_ops_yaml_path
./parsed_apis/sparse_backward.parsed.yaml --output_op_path
"${generated_sparse_ops_path}.tmp" --output_arg_map_path
"${generated_sparse_argument_mapping_path}.tmp"
RESULT_VARIABLE _results)
foreach(_result in ${_results})
if(${_result})
message(FATAL_ERROR "operator codegen failed, exiting.")
endif()
endforeach()
if(EXISTS "${generated_op_path}.tmp" AND EXISTS "${generated_op_path}")
execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${generated_op_path}.tmp" "${generated_op_path}")
message("copy if different ${generated_op_path}.tmp ${generated_op_path}")
elseif(EXISTS "${generated_op_path}.tmp")
execute_process(COMMAND ${CMAKE_COMMAND} -E copy "${generated_op_path}.tmp"
"${generated_op_path}")
message("copy ${generated_op_path}.tmp ${generated_op_path}")
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f "${generated_op_path}")
message("remove ${generated_op_path}")
endif()
if(EXISTS "${generated_sparse_ops_path}.tmp" AND EXISTS
"${generated_sparse_ops_path}")
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${generated_sparse_ops_path}.tmp" "${generated_sparse_ops_path}")
message(
"copy if different ${generated_sparse_ops_path}.tmp ${generated_sparse_ops_path}"
)
elseif(EXISTS "${generated_sparse_ops_path}.tmp")
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy "${generated_sparse_ops_path}.tmp"
"${generated_sparse_ops_path}")
message("copy ${generated_sparse_ops_path}.tmp ${generated_sparse_ops_path}")
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f
"${generated_sparse_ops_path}")
message("remove ${generated_sparse_ops_path}")
endif()
if(EXISTS "${generated_argument_mapping_path}.tmp"
AND EXISTS "${generated_argument_mapping_path}")
execute_process(
COMMAND
${CMAKE_COMMAND} -E copy_if_different
"${generated_argument_mapping_path}.tmp"
"${generated_argument_mapping_path}")
message(
"copy if different ${generated_argument_mapping_path}.tmp ${generated_argument_mapping_path}"
)
elseif(EXISTS "${generated_argument_mapping_path}.tmp")
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy "${generated_argument_mapping_path}.tmp"
"${generated_argument_mapping_path}")
message(
"copy ${generated_argument_mapping_path}.tmp ${generated_argument_mapping_path}"
)
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f
"${generated_argument_mapping_path}")
message("remove ${generated_argument_mapping_path}")
endif()
if(EXISTS "${generated_sparse_argument_mapping_path}.tmp"
AND EXISTS "${generated_sparse_argument_mapping_path}")
execute_process(
COMMAND
${CMAKE_COMMAND} -E copy_if_different
"${generated_sparse_argument_mapping_path}.tmp"
"${generated_sparse_argument_mapping_path}")
message(
"copy if different ${generated_sparse_argument_mapping_path}.tmp ${generated_sparse_argument_mapping_path}"
)
elseif(EXISTS "${generated_sparse_argument_mapping_path}.tmp")
execute_process(
COMMAND
${CMAKE_COMMAND} -E copy "${generated_sparse_argument_mapping_path}.tmp"
"${generated_sparse_argument_mapping_path}")
message(
"copy ${generated_sparse_argument_mapping_path}.tmp ${generated_sparse_argument_mapping_path}"
)
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f
"${generated_sparse_argument_mapping_path}")
message("remove ${generated_sparse_argument_mapping_path}")
endif()
# generate ops extra info
execute_process(
COMMAND ${PYTHON_EXECUTABLE} ${ops_extra_info_gen_file} --op_compat_yaml_path
${op_compat_yaml_file} --ops_extra_info_path ${ops_extra_info_file})
message("generate ${ops_extra_info_file}")
# generate forward api # generate forward api
add_custom_command( add_custom_command(
OUTPUT ${api_header_file} ${api_source_file} OUTPUT ${api_header_file} ${api_source_file}
......
set(op_utils_header
${PADDLE_BINARY_DIR}/paddle/phi/ops/compat/signatures.h.tmp
CACHE INTERNAL "op_args_fns.cc file")
set(op_utils_header_final
${PADDLE_BINARY_DIR}/paddle/phi/ops/compat/signatures.h)
file(
WRITE ${op_utils_header}
"// Generated by the paddle/phi/ops/compat/CMakeLists.txt. DO NOT EDIT!\n\n")
file(APPEND ${op_utils_header}
"#include \"paddle/phi/core/compat/op_utils.h\"\n\n")
# Automatically generate the registration code of all arg map functions
# and compile the corresponding target to avoid frequent code conflicts
# when writing to same file
register_op_utils(op_compat_infos DEPS op_utils)
copy_if_different(${op_utils_header} ${op_utils_header_final})
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册