未验证 提交 f369b2b1 编写于 作者: C Chen Weihang 提交者: GitHub

[PHI decoupling] Move fluid op generator into fluid (#47714)

* move fluid op generator into fluid

* remove parsed op

* resolve sig undef error

* append python interp find logic

* remove dup code
上级 0a051297
......@@ -75,8 +75,7 @@ paddle/fluid/operators/generated_op.cc
paddle/fluid/operators/generated_sparse_op.cc
paddle/phi/ops/compat/generated_sig.cc
paddle/phi/ops/compat/generated_sparse_sig.cc
paddle/phi/api/yaml/parsed_apis/
python/paddle/utils/code_gen/
paddle/fluid/operators/generator/parsed_ops/
paddle/fluid/pybind/tmp_eager_op_function_impl.h
paddle/fluid/pybind/eager_op_function_impl.h
paddle/fluid/pybind/eager_op_function_impl.h
......
......@@ -111,7 +111,7 @@ function(kernel_declare TARGET_LIST)
endfunction()
function(append_op_util_declare TARGET)
file(READ ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET} target_content)
file(READ ${TARGET} target_content)
string(
REGEX
MATCH
......@@ -134,13 +134,10 @@ function(register_op_utils TARGET_NAME)
cmake_parse_arguments(register_op_utils "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
file(
GLOB SIGNATURES
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"*_sig.cc")
file(GLOB SIGNATURES "${PADDLE_SOURCE_DIR}/paddle/phi/ops/compat/*_sig.cc")
foreach(target ${SIGNATURES})
append_op_util_declare(${target})
list(APPEND utils_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${target})
list(APPEND utils_srcs ${target})
endforeach()
cc_library(
......
......@@ -503,7 +503,8 @@ if(WITH_XPU)
phi_utils
kernel_factory
infershape_utils
op_utils)
op_utils
op_compat_infos)
else()
cc_library(
operator
......@@ -528,7 +529,8 @@ else()
phi_utils
kernel_factory
infershape_utils
op_utils)
op_utils
op_compat_infos)
endif()
cc_test(
......
include(operators)
add_subdirectory(generator)
# solve "math constants not defined" problems caused by the order of inclusion
# of <cmath> and the definition of macro _USE_MATH_DEFINES
add_definitions(-D_USE_MATH_DEFINES)
......
# phi auto cmake utils
include(phi)
# set yaml file path
set(op_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/ops.yaml)
set(legacy_op_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/legacy_ops.yaml)
set(bw_op_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/backward.yaml)
set(legacy_bw_op_yaml_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/legacy_backward.yaml)
set(sparse_op_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/sparse_ops.yaml)
set(sparse_bw_op_yaml_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/sparse_backward.yaml)
if(NOT PYTHONINTERP_FOUND)
find_package(PythonInterp REQUIRED)
endif()
# install extra dependencies
if(${PYTHON_VERSION_STRING} VERSION_LESS "3.6.2")
execute_process(COMMAND ${PYTHON_EXECUTABLE} -m pip install -U pyyaml
typing-extensions>=4.1.1 jinja2==2.11.3)
else()
execute_process(COMMAND ${PYTHON_EXECUTABLE} -m pip install -U pyyaml jinja2
typing-extensions)
endif()
# parse ops
set(parsed_op_dir
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops)
set(generated_op_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op.cc)
set(generated_sparse_ops_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_sparse_op.cc)
set(generated_argument_mapping_path
${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sig.cc)
set(generated_sparse_argument_mapping_path
${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sparse_sig.cc)
message(
"parse op yamls:
- ${op_yaml_file}
- ${legacy_op_yaml_file}
- ${bw_op_yaml_file}
- ${legacy_bw_op_yaml_file}")
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_op_dir}
COMMAND ${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${op_yaml_file}
--output_path ./parsed_ops/ops.parsed.yaml
COMMAND ${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${legacy_op_yaml_file}
--output_path ./parsed_ops/legacy_ops.parsed.yaml
COMMAND ${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${bw_op_yaml_file}
--output_path ./parsed_ops/backward_ops.parsed.yaml --backward
COMMAND
${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${legacy_bw_op_yaml_file}
--output_path ./parsed_ops/legacy_backward_ops.parsed.yaml --backward
COMMAND ${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${sparse_op_yaml_file}
--output_path ./parsed_ops/sparse_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${sparse_bw_op_yaml_file}
--output_path ./parsed_ops/sparse_backward.parsed.yaml --backward
RESULTS_VARIABLE _results)
foreach(_result in ${_results})
if(${_result})
message(FATAL_ERROR "op yaml parsing failed, exiting.")
endif()
endforeach()
# validation of op yamls
message("validate op yaml:
- ${parsed_op_dir}/ops.parsed.yaml
- ${parsed_op_dir}/backward_ops.parsed.yaml")
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
COMMAND
${PYTHON_EXECUTABLE} cross_validate.py --forward_yaml_paths
./parsed_ops/ops.parsed.yaml ./parsed_ops/legacy_ops.parsed.yaml
--backward_yaml_paths ./parsed_ops/backward_ops.parsed.yaml
./parsed_ops/legacy_backward_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} cross_validate.py --forward_yaml_paths
./parsed_ops/sparse_ops.parsed.yaml --backward_yaml_paths
./parsed_ops/sparse_backward.parsed.yaml
RESULT_VARIABLE _results)
foreach(_result in ${_results})
if(${_result})
message(FATAL_ERROR "ops validation failed, exiting.")
endif()
endforeach()
# code generation for op, op makers, and argument mapping functions
message(
"create or remove auto-geneated operators: ${generated_op_path}.tmp
create or remove auto-geneated argument mappings: ${generated_argument_mapping_path}.tmp"
)
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
COMMAND
${PYTHON_EXECUTABLE} generate_op.py --ops_yaml_path
./parsed_ops/ops.parsed.yaml --backward_yaml_path
./parsed_ops/backward_ops.parsed.yaml --op_version_yaml_path
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_version.yaml
--op_compat_yaml_path ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml
--output_op_path "${generated_op_path}.tmp" --output_arg_map_path
"${generated_argument_mapping_path}.tmp"
COMMAND
${PYTHON_EXECUTABLE} generate_sparse_op.py --ops_yaml_path
./parsed_ops/sparse_ops.parsed.yaml --backward_ops_yaml_path
./parsed_ops/sparse_backward.parsed.yaml --output_op_path
"${generated_sparse_ops_path}.tmp" --output_arg_map_path
"${generated_sparse_argument_mapping_path}.tmp"
RESULT_VARIABLE _results)
foreach(_result in ${_results})
if(${_result})
message(FATAL_ERROR "operator codegen failed, exiting.")
endif()
endforeach()
if(EXISTS "${generated_op_path}.tmp" AND EXISTS "${generated_op_path}")
execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${generated_op_path}.tmp" "${generated_op_path}")
message("copy if different ${generated_op_path}.tmp ${generated_op_path}")
elseif(EXISTS "${generated_op_path}.tmp")
execute_process(COMMAND ${CMAKE_COMMAND} -E copy "${generated_op_path}.tmp"
"${generated_op_path}")
message("copy ${generated_op_path}.tmp ${generated_op_path}")
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f "${generated_op_path}")
message("remove ${generated_op_path}")
endif()
if(EXISTS "${generated_sparse_ops_path}.tmp" AND EXISTS
"${generated_sparse_ops_path}")
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${generated_sparse_ops_path}.tmp" "${generated_sparse_ops_path}")
message(
"copy if different ${generated_sparse_ops_path}.tmp ${generated_sparse_ops_path}"
)
elseif(EXISTS "${generated_sparse_ops_path}.tmp")
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy "${generated_sparse_ops_path}.tmp"
"${generated_sparse_ops_path}")
message("copy ${generated_sparse_ops_path}.tmp ${generated_sparse_ops_path}")
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f
"${generated_sparse_ops_path}")
message("remove ${generated_sparse_ops_path}")
endif()
if(EXISTS "${generated_argument_mapping_path}.tmp"
AND EXISTS "${generated_argument_mapping_path}")
execute_process(
COMMAND
${CMAKE_COMMAND} -E copy_if_different
"${generated_argument_mapping_path}.tmp"
"${generated_argument_mapping_path}")
message(
"copy if different ${generated_argument_mapping_path}.tmp ${generated_argument_mapping_path}"
)
elseif(EXISTS "${generated_argument_mapping_path}.tmp")
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy "${generated_argument_mapping_path}.tmp"
"${generated_argument_mapping_path}")
message(
"copy ${generated_argument_mapping_path}.tmp ${generated_argument_mapping_path}"
)
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f
"${generated_argument_mapping_path}")
message("remove ${generated_argument_mapping_path}")
endif()
if(EXISTS "${generated_sparse_argument_mapping_path}.tmp"
AND EXISTS "${generated_sparse_argument_mapping_path}")
execute_process(
COMMAND
${CMAKE_COMMAND} -E copy_if_different
"${generated_sparse_argument_mapping_path}.tmp"
"${generated_sparse_argument_mapping_path}")
message(
"copy if different ${generated_sparse_argument_mapping_path}.tmp ${generated_sparse_argument_mapping_path}"
)
elseif(EXISTS "${generated_sparse_argument_mapping_path}.tmp")
execute_process(
COMMAND
${CMAKE_COMMAND} -E copy "${generated_sparse_argument_mapping_path}.tmp"
"${generated_sparse_argument_mapping_path}")
message(
"copy ${generated_sparse_argument_mapping_path}.tmp ${generated_sparse_argument_mapping_path}"
)
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f
"${generated_sparse_argument_mapping_path}")
message("remove ${generated_sparse_argument_mapping_path}")
endif()
# op extra info file
set(ops_extra_info_gen_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/ops_extra_info_gen.py)
set(op_compat_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml)
set(ops_extra_info_file
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/ops_extra_info.cc)
# generate ops extra info
execute_process(
COMMAND ${PYTHON_EXECUTABLE} ${ops_extra_info_gen_file} --op_compat_yaml_path
${op_compat_yaml_file} --ops_extra_info_path ${ops_extra_info_file})
message("generate ${ops_extra_info_file}")
set(op_utils_header
${PADDLE_BINARY_DIR}/paddle/phi/ops/compat/signatures.h.tmp
CACHE INTERNAL "op_args_fns.cc file")
set(op_utils_header_final
${PADDLE_BINARY_DIR}/paddle/phi/ops/compat/signatures.h)
file(
WRITE ${op_utils_header}
"// Generated by the paddle/fluid/operators/generator/CMakeLists.txt. DO NOT EDIT!\n\n"
)
file(APPEND ${op_utils_header}
"#include \"paddle/phi/core/compat/op_utils.h\"\n\n")
# Automatically generate the registration code of all arg map functions
# and compile the corresponding target to avoid frequent code conflicts
# when writing to same file
register_op_utils(op_compat_infos DEPS op_utils)
copy_if_different(${op_utils_header} ${op_utils_header_final})
......@@ -20,35 +20,35 @@ import yaml
from parse_utils import cross_validate, to_named_dict
def main(forward_api_yaml_paths, backward_api_yaml_paths):
apis = {}
for api_yaml_path in chain(forward_api_yaml_paths, backward_api_yaml_paths):
with open(api_yaml_path, "rt", encoding="utf-8") as f:
api_list = yaml.safe_load(f)
if api_list is not None:
apis.update(to_named_dict((api_list)))
def main(forward_op_yaml_paths, backward_op_yaml_paths):
ops = {}
for op_yaml_path in chain(forward_op_yaml_paths, backward_op_yaml_paths):
with open(op_yaml_path, "rt", encoding="utf-8") as f:
op_list = yaml.safe_load(f)
if op_list is not None:
ops.update(to_named_dict((op_list)))
cross_validate(apis)
cross_validate(ops)
if __name__ == "__main__":
current_dir = Path(__file__).parent / "temp"
parser = argparse.ArgumentParser(
description="Parse api yaml into canonical format."
description="Parse op yaml into canonical format."
)
parser.add_argument(
'--forward_yaml_paths',
type=str,
nargs='+',
default=str(current_dir / "api.parsed.yaml"),
help="forward api yaml file.",
default=str(current_dir / "op .parsed.yaml"),
help="forward op yaml file.",
)
parser.add_argument(
'--backward_yaml_paths',
type=str,
nargs='+',
default=str(current_dir / "backward_api.parsed.yaml"),
help="backward api yaml file.",
default=str(current_dir / "backward_op .parsed.yaml"),
help="backward op yaml file.",
)
args = parser.parse_args()
......
......@@ -102,12 +102,12 @@ def to_pascal_case(s):
def to_input_name(s):
"""find input variable name in api yaml for higher order backward api.
"""find input variable name in op yaml for higher order backward op .
x -> dx
x -> d2x
x -> d3x
NOTE: for first order backward api
NOTE: for first order backward op
x -> x_grad
is more common.
"""
......@@ -137,16 +137,14 @@ def cartesian_prod_attrs(attrs):
return combinations
def cartesian_prod_mapping(api):
kernels = api["kernel"]["func"]
def cartesian_prod_mapping(op):
kernels = op["kernel"]["func"]
inputs = [
x["name"] for x in api["inputs"] if x["name"] in api["kernel"]["param"]
x["name"] for x in op["inputs"] if x["name"] in op["kernel"]["param"]
]
inputs = [to_opmaker_name_cstr(input) for input in inputs]
attrs = cartesian_prod_attrs(api["attrs"])
outputs = [
to_opmaker_name_cstr(output["name"]) for output in api["outputs"]
]
attrs = cartesian_prod_attrs(op["attrs"])
outputs = [to_opmaker_name_cstr(output["name"]) for output in op["outputs"]]
def vec(items):
return "{" + ', '.join(items) + "}"
......
......@@ -26,7 +26,7 @@ from filters import (
to_pascal_case,
)
from tests import (
is_base_api,
is_base_op,
is_vec,
is_scalar,
is_initializer_list,
......@@ -51,7 +51,7 @@ env.filters["to_pascal_case"] = to_pascal_case
env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.tests["base_api"] = is_base_api
env.tests["base_op"] = is_base_op
env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list
......@@ -59,126 +59,127 @@ env.tests["supports_inplace"] = supports_inplace
env.tests["supports_no_need_buffer"] = supports_no_need_buffer
def restruct_io(api):
api["input_dict"] = to_named_dict(api["inputs"])
api["attr_dict"] = to_named_dict(api["attrs"])
api["output_dict"] = to_named_dict(api["outputs"])
return api
def restruct_io(op):
op["input_dict"] = to_named_dict(op["inputs"])
op["attr_dict"] = to_named_dict(op["attrs"])
op["output_dict"] = to_named_dict(op["outputs"])
return op
# replace name of op and params for OpMaker
def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
def get_api_and_op_name(api_item):
names = api_item.split('(')
def replace_compat_name(op_op_map, forward_op_dict, backward_op_dict):
def get_op_and_op_name(op_item):
names = op_item.split('(')
if len(names) == 1:
return names[0].strip(), names[0].strip()
else:
return names[0].strip(), names[1].split(')')[0].strip()
def update_api_attr_name(attrs, attrs_alias_map):
def update_op_attr_name(attrs, attrs_alias_map):
for attr_item in attrs:
if attr_item['name'] in attrs_alias_map:
attr_item['name'] = attrs_alias_map[attr_item['name']]
for api_args in api_op_map:
api_name, op_name = get_api_and_op_name(api_args['op'])
if api_name not in forward_api_dict:
for op_args in op_op_map:
new_op_name, op_name = get_op_and_op_name(op_args['op'])
if new_op_name not in forward_op_dict:
continue
forward_api_item = forward_api_dict[api_name]
has_backward = True if forward_api_item['backward'] else False
forward_op_item = forward_op_dict[new_op_name]
has_backward = True if forward_op_item['backward'] else False
if has_backward:
backward_api_item = backward_api_dict[forward_api_item['backward']]
if api_name != op_name:
forward_api_item['op_name'] = op_name
if 'backward' in api_args and has_backward:
backward_op_list = api_args['backward'].split(',')
bw_api_name, bw_op_name = get_api_and_op_name(backward_op_list[0])
forward_api_item['backward'] = bw_op_name
backward_api_item['op_name'] = bw_op_name
backward_op_item = backward_op_dict[forward_op_item['backward']]
if new_op_name != op_name:
forward_op_item['op_name'] = op_name
if 'backward' in op_args and has_backward:
backward_op_list = op_args['backward'].split(',')
_, bw_op_name = get_op_and_op_name(backward_op_list[0])
forward_op_item['backward'] = bw_op_name
backward_op_item['op_name'] = bw_op_name
# for double grad
if len(backward_op_list) > 1:
double_grad_api_name, double_grad_op_name = get_api_and_op_name(
backward_op_list[1]
)
double_grad_item = backward_api_dict[double_grad_api_name]
backward_api_item['backward'] = double_grad_op_name
(
new_double_grad_op_name,
double_grad_op_name,
) = get_op_and_op_name(backward_op_list[1])
double_grad_item = backward_op_dict[new_double_grad_op_name]
backward_op_item['backward'] = double_grad_op_name
double_grad_item['op_name'] = double_grad_op_name
if 'attrs' in api_args:
update_api_attr_name(
double_grad_item['attrs'], api_args['attrs']
if 'attrs' in op_args:
update_op_attr_name(
double_grad_item['attrs'], op_args['attrs']
)
update_api_attr_name(
double_grad_item['forward']['attrs'], api_args['attrs']
update_op_attr_name(
double_grad_item['forward']['attrs'], op_args['attrs']
)
# for triple grad
if len(backward_op_list) > 2:
(
triple_grad_api_name,
new_triple_grad_op_name,
triple_grad_op_name,
) = get_api_and_op_name(backward_op_list[2])
triple_grad_item = backward_api_dict[triple_grad_api_name]
) = get_op_and_op_name(backward_op_list[2])
triple_grad_item = backward_op_dict[new_triple_grad_op_name]
double_grad_item['backward'] = triple_grad_op_name
triple_grad_item['op_name'] = triple_grad_op_name
if 'attrs' in api_args:
update_api_attr_name(
triple_grad_item['attrs'], api_args['attrs']
if 'attrs' in op_args:
update_op_attr_name(
triple_grad_item['attrs'], op_args['attrs']
)
update_api_attr_name(
update_op_attr_name(
triple_grad_item['forward']['attrs'],
api_args['attrs'],
op_args['attrs'],
)
key_set = ['inputs', 'attrs', 'outputs']
args_map = {}
for key in key_set:
if key in api_args:
args_map.update(api_args[key])
for args_item in forward_api_item[key]:
if args_item['name'] in api_args[key]:
args_item['name'] = api_args[key][args_item['name']]
if key in op_args:
args_map.update(op_args[key])
for args_item in forward_op_item[key]:
if args_item['name'] in op_args[key]:
args_item['name'] = op_args[key][args_item['name']]
if has_backward:
for args_item in backward_api_item['forward'][key]:
if args_item['name'] in api_args[key]:
args_item['name'] = api_args[key][args_item['name']]
forward_api_item['infer_meta']['param'] = [
for args_item in backward_op_item['forward'][key]:
if args_item['name'] in op_args[key]:
args_item['name'] = op_args[key][args_item['name']]
forward_op_item['infer_meta']['param'] = [
args_map[param] if param in args_map else param
for param in forward_api_item['infer_meta']['param']
for param in forward_op_item['infer_meta']['param']
]
forward_api_item['kernel']['param'] = [
forward_op_item['kernel']['param'] = [
args_map[param] if param in args_map else param
for param in forward_api_item['kernel']['param']
for param in forward_op_item['kernel']['param']
]
if forward_api_item['kernel']['data_type']:
forward_api_item['kernel']['data_type']['candidates'] = [
if forward_op_item['kernel']['data_type']:
forward_op_item['kernel']['data_type']['candidates'] = [
args_map[param] if param in args_map else param
for param in forward_api_item['kernel']['data_type'][
for param in forward_op_item['kernel']['data_type'][
'candidates'
]
]
if forward_api_item['kernel']['backend']:
forward_api_item['kernel']['backend']['candidates'] = [
if forward_op_item['kernel']['backend']:
forward_op_item['kernel']['backend']['candidates'] = [
args_map[param] if param in args_map else param
for param in forward_api_item['kernel']['backend']['candidates']
for param in forward_op_item['kernel']['backend']['candidates']
]
if forward_api_item['kernel']['layout']:
forward_api_item['kernel']['layout']['candidates'] = [
if forward_op_item['kernel']['layout']:
forward_op_item['kernel']['layout']['candidates'] = [
args_map[param] if param in args_map else param
for param in forward_api_item['kernel']['layout']['candidates']
for param in forward_op_item['kernel']['layout']['candidates']
]
if forward_api_item['inplace']:
if forward_op_item['inplace']:
inplace_map = {}
for key, val in forward_api_item['inplace'].items():
for key, val in forward_op_item['inplace'].items():
if key in args_map:
key = args_map[key]
if val in args_map:
val = args_map[val]
inplace_map[key] = val
forward_api_item['inplace'] = inplace_map
forward_op_item['inplace'] = inplace_map
if has_backward:
for args_item in backward_api_item['inputs']:
for args_item in backward_op_item['inputs']:
if args_item['name'] in args_map:
args_item['name'] = args_map[args_item['name']]
elif (
......@@ -189,10 +190,10 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
args_map[args_item['name'][:-5]] + '_grad'
)
args_item['name'] = args_map[args_item['name']]
for args_item in backward_api_item['attrs']:
for args_item in backward_op_item['attrs']:
if args_item['name'] in args_map:
args_item['name'] = args_map[args_item['name']]
for args_item in backward_api_item['outputs']:
for args_item in backward_op_item['outputs']:
if (
args_item['name'].endswith('_grad')
and args_item['name'][:-5] in args_map
......@@ -202,73 +203,73 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
)
args_item['name'] = args_map[args_item['name']]
if 'invoke' in backward_api_item:
backward_api_item['invoke']['args'] = [
if 'invoke' in backward_op_item:
backward_op_item['invoke']['args'] = [
args_map[param.strip()]
if param.strip() in args_map
else param.strip()
for param in backward_api_item['invoke']['args'].split(',')
for param in backward_op_item['invoke']['args'].split(',')
]
continue
backward_api_item['infer_meta']['param'] = [
backward_op_item['infer_meta']['param'] = [
args_map[param] if param in args_map else param
for param in backward_api_item['infer_meta']['param']
for param in backward_op_item['infer_meta']['param']
]
backward_api_item['kernel']['param'] = [
backward_op_item['kernel']['param'] = [
args_map[param] if param in args_map else param
for param in backward_api_item['kernel']['param']
for param in backward_op_item['kernel']['param']
]
if backward_api_item['kernel']['data_type']:
backward_api_item['kernel']['data_type']['candidates'] = [
if backward_op_item['kernel']['data_type']:
backward_op_item['kernel']['data_type']['candidates'] = [
args_map[param] if param in args_map else param
for param in backward_api_item['kernel']['data_type'][
for param in backward_op_item['kernel']['data_type'][
'candidates'
]
]
if backward_api_item['kernel']['backend']:
backward_api_item['kernel']['backend']['candidates'] = [
if backward_op_item['kernel']['backend']:
backward_op_item['kernel']['backend']['candidates'] = [
args_map[param] if param in args_map else param
for param in backward_api_item['kernel']['backend'][
for param in backward_op_item['kernel']['backend'][
'candidates'
]
]
if backward_api_item['kernel']['layout']:
backward_api_item['kernel']['layout']['candidates'] = [
if backward_op_item['kernel']['layout']:
backward_op_item['kernel']['layout']['candidates'] = [
args_map[param] if param in args_map else param
for param in backward_api_item['kernel']['layout'][
for param in backward_op_item['kernel']['layout'][
'candidates'
]
]
if backward_api_item['no_need_buffer']:
backward_api_item['no_need_buffer'] = [
if backward_op_item['no_need_buffer']:
backward_op_item['no_need_buffer'] = [
args_map[param] if param in args_map else param
for param in backward_api_item['no_need_buffer']
for param in backward_op_item['no_need_buffer']
]
if backward_api_item['inplace']:
if backward_op_item['inplace']:
inplace_map = {}
for key, val in backward_api_item['inplace'].items():
for key, val in backward_op_item['inplace'].items():
if key in args_map:
key = args_map[key]
if val in args_map:
val = args_map[val]
inplace_map[key] = val
backward_api_item['inplace'] = inplace_map
backward_op_item['inplace'] = inplace_map
def process_invoke_op(forward_api_dict, backward_api_dict):
for bw_api in backward_api_dict.values():
if 'invoke' in bw_api:
invoke_op = bw_api['invoke']['func']
args_list = bw_api['invoke']['args']
def process_invoke_op(forward_op_dict, backward_op_dict):
for bw_op in backward_op_dict.values():
if 'invoke' in bw_op:
invoke_op = bw_op['invoke']['func']
args_list = bw_op['invoke']['args']
args_index = 0
if invoke_op in forward_api_dict:
reuse_op = forward_api_dict[invoke_op]
bw_api['invoke']['inputs'] = []
bw_api['invoke']['attrs'] = []
bw_api['invoke']['outputs'] = []
if invoke_op in forward_op_dict:
reuse_op = forward_op_dict[invoke_op]
bw_op['invoke']['inputs'] = []
bw_op['invoke']['attrs'] = []
bw_op['invoke']['outputs'] = []
for input_item in reuse_op['inputs']:
bw_api['invoke']['inputs'].append(
bw_op['invoke']['inputs'].append(
{
'name': input_item['name'],
'value': args_list[args_index],
......@@ -279,20 +280,20 @@ def process_invoke_op(forward_api_dict, backward_api_dict):
if args_index < len(args_list):
attr_value = (
f"this->GetAttr(\"{args_list[args_index]}\")"
if args_list[args_index] in bw_api['attr_dict']
if args_list[args_index] in bw_op['attr_dict']
else args_list[args_index]
)
bw_api['invoke']['attrs'].append(
bw_op['invoke']['attrs'].append(
{'name': attr['name'], 'value': attr_value}
)
args_index = args_index + 1
else:
break
for idx, output_item in enumerate(reuse_op['outputs']):
bw_api['invoke']['outputs'].append(
bw_op['invoke']['outputs'].append(
{
'name': output_item['name'],
'value': bw_api['outputs'][idx]['name'],
'value': bw_op['outputs'][idx]['name'],
}
)
......@@ -306,47 +307,47 @@ def main(
output_arg_map_path,
):
with open(ops_yaml_path, "rt") as f:
apis = yaml.safe_load(f)
apis = [restruct_io(api) for api in apis]
forward_api_dict = to_named_dict(apis)
ops = yaml.safe_load(f)
ops = [restruct_io(op) for op in ops]
forward_op_dict = to_named_dict(ops)
with open(backward_yaml_path, "rt") as f:
backward_apis = yaml.safe_load(f)
backward_apis = [restruct_io(api) for api in backward_apis]
backward_api_dict = to_named_dict(backward_apis)
backward_ops = yaml.safe_load(f)
backward_ops = [restruct_io(op) for op in backward_ops]
backward_op_dict = to_named_dict(backward_ops)
with open(op_version_yaml_path, "rt") as f:
api_versions = yaml.safe_load(f)
# add api version info into api
for api_version in api_versions:
forward_api_dict[api_version['op']]['version'] = api_version['version']
op_versions = yaml.safe_load(f)
# add op version info into op
for op_version in op_versions:
forward_op_dict[op_version['op']]['version'] = op_version['version']
with open(op_compat_yaml_path, "rt") as f:
api_op_map = yaml.safe_load(f)
op_op_map = yaml.safe_load(f)
for api in apis:
api['op_name'] = api['name']
for bw_api in backward_apis:
bw_api['op_name'] = bw_api['name']
for op in ops:
op['op_name'] = op['name']
for bw_op in backward_ops:
bw_op['op_name'] = bw_op['name']
replace_compat_name(api_op_map, forward_api_dict, backward_api_dict)
replace_compat_name(op_op_map, forward_op_dict, backward_op_dict)
# prepare for invoke case
process_invoke_op(forward_api_dict, backward_api_dict)
process_invoke_op(forward_op_dict, backward_op_dict)
# fill backward field for an api if another api claims it as forward
for name, backward_api in backward_api_dict.items():
forward_name = backward_api["forward"]["name"]
if forward_name in backward_api_dict:
forward_api = backward_api_dict[forward_name]
if forward_api["backward"] is None:
forward_api["backward"] = name
# fill backward field for an op if another op claims it as forward
for name, backward_op in backward_op_dict.items():
forward_name = backward_op["forward"]["name"]
if forward_name in backward_op_dict:
forward_op = backward_op_dict[forward_name]
if forward_op["backward"] is None:
forward_op["backward"] = name
api_dict = {}
api_dict.update(forward_api_dict)
api_dict.update(backward_api_dict)
op_dict = {}
op_dict.update(forward_op_dict)
op_dict.update(backward_op_dict)
if len(apis) == 0 and len(backward_apis) == 0:
if len(ops) == 0 and len(backward_ops) == 0:
if os.path.isfile(output_op_path):
os.remove(output_op_path)
if os.path.isfile(output_arg_map_path):
......@@ -356,19 +357,19 @@ def main(
op_template = env.get_template('op.c.j2')
with open(output_op_path, "wt") as f:
msg = op_template.render(
apis=apis, backward_apis=backward_apis, api_dict=api_dict
ops=ops, backward_ops=backward_ops, op_dict=op_dict
)
f.write(msg)
ks_template = env.get_template('ks.c.j2')
with open(output_arg_map_path, 'wt') as f:
msg = ks_template.render(apis=apis, backward_apis=backward_apis)
msg = ks_template.render(ops=ops, backward_ops=backward_ops)
f.write(msg)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate operator file from api yaml."
description="Generate operator file from op yaml."
)
parser.add_argument(
'--ops_yaml_path', type=str, help="parsed ops yaml file."
......
......@@ -26,7 +26,7 @@ from filters import (
to_pascal_case,
)
from tests import (
is_base_api,
is_base_op,
is_vec,
is_scalar,
is_initializer_list,
......@@ -52,7 +52,7 @@ env.filters["to_pascal_case"] = to_pascal_case
env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.tests["base_api"] = is_base_api
env.tests["base_op"] = is_base_op
env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list
......@@ -60,65 +60,63 @@ env.tests["supports_inplace"] = supports_inplace
env.tests["supports_no_need_buffer"] = supports_no_need_buffer
def restruct_io(api):
api["input_dict"] = to_named_dict(api["inputs"])
api["attr_dict"] = to_named_dict(api["attrs"])
api["output_dict"] = to_named_dict(api["outputs"])
return api
def restruct_io(op):
op["input_dict"] = to_named_dict(op["inputs"])
op["attr_dict"] = to_named_dict(op["attrs"])
op["output_dict"] = to_named_dict(op["outputs"])
return op
SPARSE_OP_PREFIX = 'sparse_'
def main(
api_yaml_path, backward_yaml_path, output_op_path, output_arg_map_path
):
with open(api_yaml_path, "rt") as f:
apis = yaml.safe_load(f)
apis = [restruct_io(api) for api in apis]
forward_api_dict = to_named_dict(apis)
def main(op_yaml_path, backward_yaml_path, output_op_path, output_arg_map_path):
with open(op_yaml_path, "rt") as f:
ops = yaml.safe_load(f)
ops = [restruct_io(op) for op in ops]
forward_op_dict = to_named_dict(ops)
with open(backward_yaml_path, "rt") as f:
backward_apis = yaml.safe_load(f)
backward_apis = [restruct_io(api) for api in backward_apis]
backward_api_dict = to_named_dict(backward_apis)
for api in apis:
api['op_name'] = SPARSE_OP_PREFIX + api['name']
api['name'] = api['op_name']
if api["backward"] is not None:
api["backward"] = SPARSE_OP_PREFIX + api["backward"]
for bw_api in backward_apis:
bw_api['op_name'] = SPARSE_OP_PREFIX + bw_api['name']
bw_api['name'] = bw_api['op_name']
if 'invoke' in bw_api:
bw_api['invoke']['args'] = [
param.strip() for param in bw_api['invoke']['args'].split(',')
backward_ops = yaml.safe_load(f)
backward_ops = [restruct_io(op) for op in backward_ops]
backward_op_dict = to_named_dict(backward_ops)
for op in ops:
op['op_name'] = SPARSE_OP_PREFIX + op['name']
op['name'] = op['op_name']
if op["backward"] is not None:
op["backward"] = SPARSE_OP_PREFIX + op["backward"]
for bw_op in backward_ops:
bw_op['op_name'] = SPARSE_OP_PREFIX + bw_op['name']
bw_op['name'] = bw_op['op_name']
if 'invoke' in bw_op:
bw_op['invoke']['args'] = [
param.strip() for param in bw_op['invoke']['args'].split(',')
]
# prepare for invoke case
process_invoke_op(forward_api_dict, backward_api_dict)
for bw_api in backward_apis:
if 'invoke' in bw_api:
if bw_api['invoke']['func'] in forward_api_dict:
bw_api['invoke']['func'] = (
SPARSE_OP_PREFIX + bw_api['invoke']['func']
process_invoke_op(forward_op_dict, backward_op_dict)
for bw_op in backward_ops:
if 'invoke' in bw_op:
if bw_op['invoke']['func'] in forward_op_dict:
bw_op['invoke']['func'] = (
SPARSE_OP_PREFIX + bw_op['invoke']['func']
)
# fill backward field for an api if another api claims it as forward
for name, backward_api in backward_api_dict.items():
forward_name = backward_api["forward"]["name"]
if forward_name in backward_api_dict:
forward_api = backward_api_dict[forward_name]
if forward_api["backward"] is None:
forward_api["backward"] = name
forward_api["backward"] = SPARSE_OP_PREFIX + forward_api["backward"]
# fill backward field for an op if another op claims it as forward
for name, backward_op in backward_op_dict.items():
forward_name = backward_op["forward"]["name"]
if forward_name in backward_op_dict:
forward_op = backward_op_dict[forward_name]
if forward_op["backward"] is None:
forward_op["backward"] = name
forward_op["backward"] = SPARSE_OP_PREFIX + forward_op["backward"]
api_dict = {}
api_dict.update(forward_api_dict)
api_dict.update(backward_api_dict)
op_dict = {}
op_dict.update(forward_op_dict)
op_dict.update(backward_op_dict)
if len(apis) == 0 and len(backward_apis) == 0:
if len(ops) == 0 and len(backward_ops) == 0:
if os.path.isfile(output_op_path):
os.remove(output_op_path)
if os.path.isfile(output_arg_map_path):
......@@ -128,19 +126,19 @@ def main(
op_template = env.get_template('sparse_op.c.j2')
with open(output_op_path, "wt") as f:
msg = op_template.render(
apis=apis, backward_apis=backward_apis, api_dict=api_dict
ops=ops, backward_ops=backward_ops, op_dict=op_dict
)
f.write(msg)
ks_template = env.get_template('sparse_ks.c.j2')
with open(output_arg_map_path, 'wt') as f:
msg = ks_template.render(apis=apis, backward_apis=backward_apis)
msg = ks_template.render(ops=ops, backward_ops=backward_ops)
f.write(msg)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate operator file from api yaml."
description="Generate operator file from op yaml."
)
parser.add_argument(
'--ops_yaml_path', type=str, help="parsed sparse ops yaml file."
......
......@@ -16,33 +16,33 @@ import argparse
import yaml
from parse_utils import parse_api_entry
from parse_utils import parse_op_entry
def main(api_yaml_path, output_path, backward):
with open(api_yaml_path, "rt") as f:
apis = yaml.safe_load(f)
if apis is None:
apis = []
def main(op_yaml_path, output_path, backward):
with open(op_yaml_path, "rt") as f:
ops = yaml.safe_load(f)
if ops is None:
ops = []
else:
apis = [
parse_api_entry(api, "backward_op" if backward else "op")
for api in apis
ops = [
parse_op_entry(op, "backward_op" if backward else "op")
for op in ops
]
with open(output_path, "wt") as f:
yaml.safe_dump(apis, f, default_flow_style=None, sort_keys=False)
yaml.safe_dump(ops, f, default_flow_style=None, sort_keys=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Parse api yaml into canonical format."
description="Parse op yaml into canonical format."
)
parser.add_argument('--api_yaml_path', type=str, help="api yaml file.")
parser.add_argument('--op_yaml_path', type=str, help="op yaml file.")
parser.add_argument(
"--output_path", type=str, help="path to save parsed yaml file."
)
parser.add_argument("--backward", action="store_true", default=False)
args = parser.parse_args()
main(args.api_yaml_path, args.output_path, args.backward)
main(args.op_yaml_path, args.output_path, args.backward)
......@@ -28,7 +28,7 @@ def to_named_dict(items: List[Dict]) -> Dict[str, Dict]:
return named_dict
def parse_arg(api_name: str, s: str) -> Dict[str, str]:
def parse_arg(op_name: str, s: str) -> Dict[str, str]:
"""parse an argument in following formats:
1. typename name
2. typename name = default_value
......@@ -36,19 +36,19 @@ def parse_arg(api_name: str, s: str) -> Dict[str, str]:
typename, rest = [item.strip() for item in s.split(" ", 1)]
assert (
len(typename) > 0
), f"The arg typename should not be empty. Please check the args of {api_name} in yaml."
), f"The arg typename should not be empty. Please check the args of {op_name} in yaml."
assert (
rest.count("=") <= 1
), f"There is more than 1 = in an arg in {api_name}"
), f"There is more than 1 = in an arg in {op_name}"
if rest.count("=") == 1:
name, default_value = [item.strip() for item in rest.split("=", 1)]
assert (
len(name) > 0
), f"The arg name should not be empty. Please check the args of {api_name} in yaml."
), f"The arg name should not be empty. Please check the args of {op_name} in yaml."
assert (
len(default_value) > 0
), f"The default value should not be empty. Please check the args of {api_name} in yaml."
), f"The default value should not be empty. Please check the args of {op_name} in yaml."
return {
"typename": typename,
"name": name,
......@@ -58,17 +58,17 @@ def parse_arg(api_name: str, s: str) -> Dict[str, str]:
name = rest.strip()
assert (
len(name) > 0
), f"The arg name should not be empty. Please check the args of {api_name} in yaml."
), f"The arg name should not be empty. Please check the args of {op_name} in yaml."
return {"typename": typename, "name": name}
def parse_input_and_attr(
api_name: str, arguments: str
op_name: str, arguments: str
) -> Tuple[List, List, Dict, Dict]:
args_str = arguments.strip()
assert args_str.startswith('(') and args_str.endswith(')'), (
f"Args declaration should start with '(' and end with ')', "
f"please check the args of {api_name} in yaml."
f"please check the args of {op_name} in yaml."
)
args_str = args_str[1:-1]
args = parse_plain_list(args_str)
......@@ -79,13 +79,13 @@ def parse_input_and_attr(
met_attr_with_default_value = False
for arg in args:
item = parse_arg(api_name, arg)
item = parse_arg(op_name, arg)
typename = item["typename"]
name = item["name"]
if is_input(typename):
assert len(attrs) == 0, (
f"The input Tensor should appear before attributes. "
f"please check the position of {api_name}:input({name}) "
f"please check the position of {op_name}:input({name}) "
f"in yaml."
)
inputs.append(item)
......@@ -93,16 +93,16 @@ def parse_input_and_attr(
if met_attr_with_default_value:
assert (
"default_value" in item
), f"{api_name}: Arguments with default value should not precede those without default value"
), f"{op_name}: Arguments with default value should not precede those without default value"
elif "default_value" in item:
met_attr_with_default_value = True
attrs.append(item)
else:
raise KeyError(f"{api_name}: Invalid argument type {typename}.")
raise KeyError(f"{op_name}: Invalid argument type {typename}.")
return inputs, attrs
def parse_output(api_name: str, s: str) -> Dict[str, str]:
def parse_output(op_name: str, s: str) -> Dict[str, str]:
"""parse an output, typename or typename(name)."""
match = re.search(
r"(?P<out_type>[a-zA-Z0-9_[\]]+)\s*(?P<name>\([a-zA-Z0-9_@]+\))?\s*(?P<expr>\{[^\}]+\})?",
......@@ -116,12 +116,12 @@ def parse_output(api_name: str, s: str) -> Dict[str, str]:
size_expr = size_expr[1:-1] if size_expr is not None else None
assert is_output(typename), (
f"Invalid output type: {typename} in api: {api_name}."
f"Invalid output type: {typename} in op : {op_name}."
f"Supported types are Tensor and Tensor[]"
)
if size_expr is not None:
assert is_vec(typename), (
f"Invalid output size: output {name} in api: {api_name} is "
f"Invalid output size: output {name} in op : {op_name} is "
f"not a vector but has size expr"
)
return {"typename": typename, "name": name, "size": size_expr}
......@@ -129,11 +129,11 @@ def parse_output(api_name: str, s: str) -> Dict[str, str]:
return {"typename": typename, "name": name}
def parse_outputs(api_name: str, outputs: str) -> List[Dict]:
def parse_outputs(op_name: str, outputs: str) -> List[Dict]:
outputs = parse_plain_list(outputs, sep=",")
output_items = []
for output in outputs:
output_items.append(parse_output(api_name, output))
output_items.append(parse_output(op_name, output))
return output_items
......@@ -157,9 +157,7 @@ def parse_plain_list(s: str, sep=",") -> List[str]:
return items
def parse_kernel(
api_name: str, kernel_config: Dict[str, Any]
) -> Dict[str, Any]:
def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]:
# kernel :
# func : [], Kernel functions (example: scale, scale_sr)
# param : [], Input params of kernel
......@@ -205,14 +203,14 @@ def parse_kernel(
'selected_rows',
'sparse_coo',
'sparse_csr',
], f"{api_name} : Invalid input tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
], f"{op_name} : Invalid input tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
for item in outputs:
assert item in [
'dense',
'selected_rows',
'sparse_coo',
'sparse_csr',
], f"{api_name} : Invalid output tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
], f"{op_name} : Invalid output tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
return (inputs, outputs)
......@@ -225,7 +223,7 @@ def parse_kernel(
return kernel
def parse_inplace(api_name: str, inplace_cfg: str) -> Dict[str, str]:
def parse_inplace(op_name: str, inplace_cfg: str) -> Dict[str, str]:
inplace_map = {}
inplace_cfg = inplace_cfg.lstrip("(").rstrip(")")
pairs = parse_plain_list(inplace_cfg)
......@@ -235,7 +233,7 @@ def parse_inplace(api_name: str, inplace_cfg: str) -> Dict[str, str]:
return inplace_map
def parse_invoke(api_name: str, invoke_config: str) -> Dict[str, Any]:
def parse_invoke(op_name: str, invoke_config: str) -> Dict[str, Any]:
invoke_config = invoke_config.strip()
func, rest = invoke_config.split("(", 1)
func = func.strip()
......@@ -245,28 +243,28 @@ def parse_invoke(api_name: str, invoke_config: str) -> Dict[str, Any]:
def extract_type_and_name(records: List[Dict]) -> List[Dict]:
"""extract type and name from forward call, it is simpler than forward api."""
"""extract type and name from forward call, it is simpler than forward op ."""
extracted = [
{"name": item["name"], "typename": item["typename"]} for item in records
]
return extracted
def parse_forward(api_name: str, forward_config: str) -> Dict[str, Any]:
# api_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)
def parse_forward(op_name: str, forward_config: str) -> Dict[str, Any]:
# op_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)
result = re.search(
r"(?P<op>[a-z][a-z0-9_]+)\s*(?P<args>\([^\)]+\))\s*->\s*(?P<outputs>.+)",
forward_config,
)
api = result.group("op")
outputs = parse_outputs(api_name, result.group("outputs"))
op = result.group("op")
outputs = parse_outputs(op_name, result.group("outputs"))
outputs = extract_type_and_name(outputs)
inputs, attrs = parse_input_and_attr(api_name, result.group("args"))
inputs, attrs = parse_input_and_attr(op_name, result.group("args"))
inputs = extract_type_and_name(inputs)
attrs = extract_type_and_name(attrs)
forward_cfg = {
"name": api,
"name": op,
"inputs": inputs,
"attrs": attrs,
"outputs": outputs,
......@@ -274,10 +272,10 @@ def parse_forward(api_name: str, forward_config: str) -> Dict[str, Any]:
return forward_cfg
def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
api_name = api_entry[name_field]
inputs, attrs = parse_input_and_attr(api_name, api_entry["args"])
outputs = parse_outputs(api_name, api_entry["output"])
def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
op_name = op_entry[name_field]
inputs, attrs = parse_input_and_attr(op_name, op_entry["args"])
outputs = parse_outputs(op_name, op_entry["output"])
# validate default value of DataType and DataLayout
for attr in attrs:
......@@ -287,14 +285,14 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
if typename == "DataType":
assert (
"DataType" in default_value
), f"invalid DataType default value in {api_name}"
), f"invalid DataType default value in {op_name}"
# remove namespace
default_value = default_value[default_value.find("DataType") :]
attr["default_value"] = default_value
elif typename == "DataLayout":
assert (
"DataLayout" in default_value
), f"invalid DataLayout default value in {api_name}"
), f"invalid DataLayout default value in {op_name}"
default_value = default_value[
default_value.find("DataLayout") :
]
......@@ -307,12 +305,12 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
# add optional tag for every input
for input in inputs:
input["optional"] = False
if "optional" in api_entry:
optional_args = parse_plain_list(api_entry["optional"])
if "optional" in op_entry:
optional_args = parse_plain_list(op_entry["optional"])
for name in optional_args:
assert (
name in input_names
), f"{api_name} has an optional input: '{name}' which is not an input."
), f"{op_name} has an optional input: '{name}' which is not an input."
for input in inputs:
if input["name"] in optional_args:
input["optional"] = True
......@@ -320,12 +318,12 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
# add intermediate tag for every output
for output in outputs:
output["intermediate"] = False
if "intermediate" in api_entry:
intermediate_outs = parse_plain_list(api_entry["intermediate"])
if "intermediate" in op_entry:
intermediate_outs = parse_plain_list(op_entry["intermediate"])
for name in intermediate_outs:
assert (
name in output_names
), f"{api_name} has an intermediate output: '{name}' which is not an output."
), f"{op_name} has an intermediate output: '{name}' which is not an output."
for output in outputs:
if output["name"] in intermediate_outs:
output["intermediate"] = True
......@@ -333,12 +331,12 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
# add no_need_buffer for every input
for input in inputs:
input["no_need_buffer"] = False
if "no_need_buffer" in api_entry:
no_buffer_args = parse_plain_list(api_entry["no_need_buffer"])
if "no_need_buffer" in op_entry:
no_buffer_args = parse_plain_list(op_entry["no_need_buffer"])
for name in no_buffer_args:
assert (
name in input_names
), f"{api_name} has an no buffer input: '{name}' which is not an input."
), f"{op_name} has an no buffer input: '{name}' which is not an input."
for input in inputs:
if input["name"] in no_buffer_args:
input["no_need_buffer"] = True
......@@ -347,34 +345,34 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
# TODO(chenfeiyu): data_transform
api = {
"name": api_name,
op = {
"name": op_name,
"inputs": inputs,
"attrs": attrs,
"outputs": outputs,
"no_need_buffer": no_buffer_args,
}
# invokes another api?
is_base_api = "invoke" not in api_entry
# invokes another op ?
is_base_op = "invoke" not in op_entry
if is_base_api:
if is_base_op:
# kernel
kernel = parse_kernel(api_name, api_entry["kernel"])
kernel = parse_kernel(op_name, op_entry["kernel"])
if kernel["param"] is None:
kernel["param"] = input_names + attr_names
# infer meta
infer_meta = parse_infer_meta(api_entry["infer_meta"])
infer_meta = parse_infer_meta(op_entry["infer_meta"])
if infer_meta["param"] is None:
infer_meta["param"] = copy(kernel["param"])
# inplace
if "inplace" in api_entry:
inplace_pairs = parse_inplace(api_name, api_entry["inplace"])
if "inplace" in op_entry:
inplace_pairs = parse_inplace(op_name, op_entry["inplace"])
else:
inplace_pairs = None
api.update(
op.update(
{
"infer_meta": infer_meta,
"kernel": kernel,
......@@ -383,47 +381,47 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
)
else:
# invoke
invoke = parse_invoke(api_name, api_entry["invoke"])
api["invoke"] = invoke
invoke = parse_invoke(op_name, op_entry["invoke"])
op["invoke"] = invoke
# backward
if "backward" in api_entry:
backward = api_entry["backward"]
if "backward" in op_entry:
backward = op_entry["backward"]
else:
backward = None
api["backward"] = backward
op["backward"] = backward
# forward for backward_apis
is_backward_api = name_field == "backward_op"
if is_backward_api:
if "forward" in api_entry:
forward = parse_forward(api_name, api_entry["forward"])
# forward for backward_ops
is_backward_op = name_field == "backward_op"
if is_backward_op:
if "forward" in op_entry:
forward = parse_forward(op_name, op_entry["forward"])
# validate_fb
validate_backward_inputs(
api_name, forward["inputs"], forward["outputs"], inputs
op_name, forward["inputs"], forward["outputs"], inputs
)
validate_backward_attrs(api_name, forward["attrs"], attrs)
validate_backward_outputs(api_name, forward["inputs"], outputs)
validate_backward_attrs(op_name, forward["attrs"], attrs)
validate_backward_outputs(op_name, forward["inputs"], outputs)
else:
forward = None
api["forward"] = forward
return api
op["forward"] = forward
return op
def validate_backward_attrs(api, forward_attrs, backward_attrs):
def validate_backward_attrs(op, forward_attrs, backward_attrs):
if len(forward_attrs) >= len(backward_attrs):
return
num_exceptional_attrs = len(backward_attrs) - len(forward_attrs)
# this is a not-that-clean trick to allow backward api to has more attrs
# than the forward api, as long as they all have default value
# this is a not-that-clean trick to allow backward op to has more attrs
# than the forward op , as long as they all have default value
for i in range(-num_exceptional_attrs, 0):
assert (
"default_value" in backward_attrs[i]
), f"{api} has exceptional attr without default value"
), f"{op } has exceptional attr without default value"
def validate_backward_inputs(
api, forward_inputs, forward_outputs, backward_inputs
op, forward_inputs, forward_outputs, backward_inputs
):
foward_input_names = [item["name"] for item in forward_inputs]
forward_output_names = [item["name"] for item in forward_outputs]
......@@ -431,47 +429,47 @@ def validate_backward_inputs(
assert len(backward_input_names) <= len(foward_input_names) + 2 * len(
forward_output_names
), f"{api} has too many inputs."
), f"{op } has too many inputs."
def validate_backward_outputs(api, forward_inputs, backward_outputs):
def validate_backward_outputs(op, forward_inputs, backward_outputs):
assert len(backward_outputs) <= len(
forward_inputs
), f"{api} has too many outputs"
), f"{op } has too many outputs"
def cross_validate(apis):
for name, api in apis.items():
if "forward" in api:
fw_call = api["forward"]
def cross_validate(ops):
for name, op in ops.items():
if "forward" in op:
fw_call = op["forward"]
fw_name = fw_call["name"]
if fw_name not in apis:
if fw_name not in ops:
print(
f"Something Wrong here, this backward api({name})'s forward api({fw_name}) does not exist."
f"Something Wrong here, this backward op ({name})'s forward op ({fw_name}) does not exist."
)
else:
fw_api = apis[fw_name]
if "backward" not in fw_api or fw_api["backward"] is None:
fw_op = ops[fw_name]
if "backward" not in fw_op or fw_op["backward"] is None:
print(
f"Something Wrong here, {name}'s forward api({fw_name}) does not claim {name} as its backward."
f"Something Wrong here, {name}'s forward op ({fw_name}) does not claim {name} as its backward."
)
else:
assert (
fw_api["backward"] == name
fw_op["backward"] == name
), f"{name}: backward and forward name mismatch"
assert len(fw_call["inputs"]) <= len(
fw_api["inputs"]
), f"{name}: forward call has more inputs than the api"
for (input, input_) in zip(fw_call["inputs"], fw_api["inputs"]):
fw_op["inputs"]
), f"{name}: forward call has more inputs than the op "
for (input, input_) in zip(fw_call["inputs"], fw_op["inputs"]):
assert (
input["typename"] == input_["typename"]
), f"type mismatch in {name} and {fw_name}"
assert len(fw_call["attrs"]) <= len(
fw_api["attrs"]
), f"{name}: forward call has more attrs than the api"
for (attr, attr_) in zip(fw_call["attrs"], fw_api["attrs"]):
fw_op["attrs"]
), f"{name}: forward call has more attrs than the op "
for (attr, attr_) in zip(fw_call["attrs"], fw_op["attrs"]):
if attr["typename"] == "Scalar":
# special case for Scalar, fw_call can omit the type
assert re.match(
......@@ -483,10 +481,10 @@ def cross_validate(apis):
), f"type mismatch in {name} and {fw_name}"
assert len(fw_call["outputs"]) == len(
fw_api["outputs"]
), f"{name}: forward call has more outputs than the api"
fw_op["outputs"]
), f"{name}: forward call has more outputs than the op "
for (output, output_) in zip(
fw_call["outputs"], fw_api["outputs"]
fw_call["outputs"], fw_op["outputs"]
):
assert (
output["typename"] == output_["typename"]
......
{% from "operator_utils.c.j2" import name_map, register_name_map, register_base_kernel_name %}
// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit.
// this file is generated by paddle/phi/op/yaml/generator/generate_op.py, do not edit.
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/small_vector.h"
namespace phi {
{% for api in apis %}
{% if api is base_api %}
{{name_map(api)}}
{% for op in ops %}
{% if op is base_op %}
{{name_map(op)}}
{% endif %}
{% endfor %}
{% for api in backward_apis %}
{% if api is base_api %}
{{name_map(api)}}
{% for op in backward_ops %}
{% if op is base_op %}
{{name_map(op)}}
{% endif %}
{% endfor %}
} // namespace phi
{% for api in apis + backward_apis %}
{% if api["name"] != api["op_name"] %}
{{register_base_kernel_name(api)}}
{% for op in ops + backward_ops %}
{% if op["name"] != op["op_name"] %}
{{register_base_kernel_name(op)}}
{% endif %}
{% if api is base_api %}
{{register_name_map(api)}}
{% if op is base_op %}
{{register_name_map(op)}}
{% endif %}
{% endfor %}
......@@ -18,32 +18,32 @@ namespace operators {
using paddle::framework::GradVarName;
{% for api in apis %}
{% if api is base_api %}
{% for op in ops %}
{% if op is base_op %}
{{op_maker(api)}}
{{op_maker(op)}}
{{operator(api)}}
{{operator(op)}}
{% endif %}
{% endfor %}
{% for api in backward_apis %}
{% if api is base_api %}
{% for op in backward_ops %}
{% if op is base_op %}
{{backward_op_maker(api, api_dict[api["forward"]["name"]])}}
{{backward_op_maker(op, op_dict[op["forward"]["name"]])}}
{{operator(api)}}
{{operator(op)}}
{% else %}
{{backward_op_reused_maker(api, api_dict[api["forward"]["name"]], api["invoke"])}}
{{backward_op_reused_maker(op, op_dict[op["forward"]["name"]], op["invoke"])}}
{% endif %}
{% endfor %}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
{% for api in apis + backward_apis %}
{% if api is base_api %}
{{register_op_with_components(api)}}
{{register_op_version(api)}}
{% for op in ops + backward_ops %}
{% if op is base_op %}
{{register_op_with_components(op)}}
{{register_op_version(op)}}
{% endif %}
{% endfor %}
{# ----------------------------- op maker ----------------------------------- #}
{% macro op_maker(api) %}
{% set api_name = api["op_name"] %}
class {{api_name | to_pascal_case}}OpMaker : public framework::OpProtoAndCheckerMaker {
{% macro op_maker(op) %}
{% set op_name = op["op_name"] %}
class {{op_name | to_pascal_case}}OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
{% filter indent(4, True) %}
{% for input in api["inputs"] %}
{{add_input(loop.index0, input, api_name)}};
{% for input in op["inputs"] %}
{{add_input(loop.index0, input, op_name)}};
{% endfor %}
{% for output in api["outputs"] %}
{{add_output(loop.index0, output, api_name)}};
{% for output in op["outputs"] %}
{{add_output(loop.index0, output, op_name)}};
{% endfor %}
{% for attr in api["attrs"] %}
{% if attr["name"] in api["kernel"]["param"] %}
{{add_attr(loop.index0, attr, api_name)}};
{% for attr in op["attrs"] %}
{% if attr["name"] in op["kernel"]["param"] %}
{{add_attr(loop.index0, attr, op_name)}};
{% endif %}
{% endfor %}
{% endfilter %}
AddComment(R"DOC(
TODO: Documentation of {{api_name}} op.
TODO: Documentation of {{op_name}} op.
)DOC");
}
};
......@@ -76,7 +76,7 @@ AddAttr<{{typename | to_op_attr_type}}>("{{name}}", "({{typename | to_op_attr_ty
{%- endif %}
{%- endmacro %}
{# process default value for attributes, some attribute has different types and different default values in api & opmaker #}
{# process default value for attributes, some attribute has different types and different default values in op & opmaker #}
{% macro process_default_value(attr) %}{# inline #}
{% set default_value = attr["default_value"] %}
{% set typename = attr["typename"] %}
......@@ -97,22 +97,22 @@ static_cast<int>(phi::Place({{"phi::" if not default_value is initializer_list}}
{# --------------------------------------- name mapping ---------------------------------------------- #}
{% macro name_map(api) %}
KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
{% set kernel_args = api["kernel"]["param"] %}
{{get_input_list(api["inputs"], kernel_args)}};
{% macro name_map(op) %}
KernelSignature {{op["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
{% set kernel_args = op["kernel"]["param"] %}
{{get_input_list(op["inputs"], kernel_args)}};
paddle::small_vector<const char*> attrs;
{% for attr in api["attrs"]%}
{% for attr in op["attrs"]%}
{% filter indent(2)%}
{{get_an_attr(attr)}};
{% endfilter %}
{% endfor %}
{{get_output_list(api["outputs"], kernel_args)}};
{% if api["kernel"]["func"] | length == 1 %}
KernelSignature sig("{{api["kernel"]["func"][0]}}", std::move(inputs), std::move(attrs), std::move(outputs));
{{get_output_list(op["outputs"], kernel_args)}};
{% if op["kernel"]["func"] | length == 1 %}
KernelSignature sig("{{op["kernel"]["func"][0]}}", std::move(inputs), std::move(attrs), std::move(outputs));
return sig;
{% else %}{# it has kernel for selected rows #}
const char* kernel_name = ctx.IsSelectedRowsInput({{kernel_args[0] | to_opmaker_name_cstr}}) ? "{{api["kernel"]["func"][1]}}" : "{{api["kernel"]["func"][0]}}";
const char* kernel_name = ctx.IsSelectedRowsInput({{kernel_args[0] | to_opmaker_name_cstr}}) ? "{{op["kernel"]["func"][1]}}" : "{{op["kernel"]["func"][0]}}";
KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs));
return sig;
{%endif%}
......@@ -121,9 +121,9 @@ KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const Argu
/*
******************************************************************
NOTE: The following codes are for 'get_compat_kernel_signature.py'
All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArgumentMapping:
All possible KernelSignatures returned by {{op["name"] | to_pascal_case }}OpArgumentMapping:
{{api | cartesian_prod_mapping}}
{{op | cartesian_prod_mapping}}
******************************************************************
*/
{% endmacro %}
......@@ -152,20 +152,20 @@ ctx.IsSparseCsrTensorInput("{{input["name"]}}"){{" && " if not loop.last}}
{%- endfor %}
{%- endmacro %}
{% macro sparse_op_name_map(api) %}
KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
{% set kernel_args = api["kernel"]["param"] %}
{{get_input_list(api["inputs"], kernel_args)}};
{% macro sparse_op_name_map(op) %}
KernelSignature {{op["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
{% set kernel_args = op["kernel"]["param"] %}
{{get_input_list(op["inputs"], kernel_args)}};
paddle::small_vector<const char*> attrs;
{% for attr in api["attrs"]%}
{% for attr in op["attrs"]%}
{% filter indent(2)%}
{{get_an_attr(attr)}};
{% endfilter %}
{% endfor %}
{{get_output_list(api["outputs"], kernel_args)}};
{{get_output_list(op["outputs"], kernel_args)}};
const char* kernel_name = "unregistered";
{{get_kernel_dispatch(api["inputs"], api["kernel"])}}
{{get_kernel_dispatch(op["inputs"], op["kernel"])}}
KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs));
return sig;
}
......@@ -173,19 +173,19 @@ KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const Argu
/*
******************************************************************
NOTE: The following codes are for 'get_compat_kernel_signature.py'
All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArgumentMapping:
All possible KernelSignatures returned by {{op["name"] | to_pascal_case }}OpArgumentMapping:
{{api | cartesian_prod_mapping}}
{{op | cartesian_prod_mapping}}
******************************************************************
*/
{% endmacro %}
{% macro register_base_kernel_name(api) %}
PD_REGISTER_BASE_KERNEL_NAME({{api["op_name"]}}, {{api["name"]}});
{% macro register_base_kernel_name(op) %}
PD_REGISTER_BASE_KERNEL_NAME({{op["op_name"]}}, {{op["name"]}});
{%- endmacro %}
{% macro register_name_map(api) %}
PD_REGISTER_ARG_MAPPING_FN({{api["op_name"]}}, phi::{{api["op_name"] | to_pascal_case}}OpArgumentMapping);
{% macro register_name_map(op) %}
PD_REGISTER_ARG_MAPPING_FN({{op["op_name"]}}, phi::{{op["op_name"] | to_pascal_case}}OpArgumentMapping);
{%- endmacro %}
{% macro get_input_list(inputs, kernel_args) %}{# inline #}
......@@ -228,14 +228,14 @@ paddle::small_vector<const char*> outputs {
}
{%- endmacro %}
{% macro get_expected_kernel(api) %}
{% set kernel = api["kernel"] %}
{% macro get_expected_kernel(op) %}
{% set kernel = op["kernel"] %}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
{%if kernel["data_type"] is not none %}{# data type ---------------------------------#}
{% if kernel["data_type"]["candidates"] | length == 1 %}
{% set data_type_arg = kernel["data_type"]["candidates"][0] %}
{% set inputs = api["inputs"] | map(attribute="name") | list %}
{% set inputs = op["inputs"] | map(attribute="name") | list %}
{% if data_type_arg in inputs %}
auto data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_arg | to_opmaker_name}});
{% else %}{# it is an attribute and probably named dtype#}
......@@ -254,68 +254,68 @@ framework::OpKernelType GetExpectedKernelType(
{% endmacro %}
{# --------------------------------------- operator ---------------------------------------------- #}
{% macro operator(api) %}
class {{api["op_name"] | to_pascal_case}}Op : public framework::OperatorWithKernel {
{% macro operator(op) %}
class {{op["op_name"] | to_pascal_case}}Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
{# ----------- get expected kernel type function -------------------------- #}
{% set kernel = api["kernel"] %}
{% set kernel = op["kernel"] %}
{% if kernel["data_type"] is not none %}
protected:
{% filter indent(2, True)%}
{{get_expected_kernel(api)}}
{{get_expected_kernel(op)}}
{% endfilter %}
{% endif %}
};
DECLARE_INFER_SHAPE_FUNCTOR({{api["op_name"]}}, {{api["op_name"] | to_pascal_case}}InferShapeFunctor,
PD_INFER_META(phi::{{api["infer_meta"]["func"]}}));
DECLARE_INFER_SHAPE_FUNCTOR({{op["op_name"]}}, {{op["op_name"] | to_pascal_case}}InferShapeFunctor,
PD_INFER_META(phi::{{op["infer_meta"]["func"]}}));
{# inplace inferer #}
{% if api["inplace"] is not none %}
{% if op["inplace"] is not none %}
{% set inplace_map %}
{% for source, target in api["inplace"].items() %}
{% for source, target in op["inplace"].items() %}
{{"{"}}{{target | to_opmaker_name}}, {{source | to_opmaker_name}}{{"}"}}{{", " if not loop.last}}
{%- endfor %}
{%- endset %}
DECLARE_INPLACE_OP_INFERER({{api["op_name"] | to_pascal_case}}InplaceInferer,
DECLARE_INPLACE_OP_INFERER({{op["op_name"] | to_pascal_case}}InplaceInferer,
{{inplace_map}});
{% endif %}
{# no_need_buffer inferer #}
{% if api["no_need_buffer"] is not none %}
DECLARE_NO_NEED_BUFFER_VARS_INFERER({{api["op_name"] | to_pascal_case}}NoNeedBufferVarInferer,
{{api["no_need_buffer"] | map("to_opmaker_name") | join(", ")}});
{% if op["no_need_buffer"] is not none %}
DECLARE_NO_NEED_BUFFER_VARS_INFERER({{op["op_name"] | to_pascal_case}}NoNeedBufferVarInferer,
{{op["no_need_buffer"] | map("to_opmaker_name") | join(", ")}});
{% endif %}
{% endmacro%}
{% macro register_op_with_components(api) %}
{% set name = api["op_name"] %}
{% macro register_op_with_components(op) %}
{% set name = op["op_name"] %}
REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
{% if not "forward" in api %}{# it is a forward api #}
{% if not "forward" in op %}{# it is a forward op #}
ops::{{name | to_pascal_case}}OpMaker,
{% endif %}
{% if "backward" in api and api["backward"] is not none %}{# backward #}
{% set backward_name = api["backward"] %}
{% if "backward" in op and op["backward"] is not none %}{# backward #}
{% set backward_name = op["backward"] %}
ops::{{backward_name | to_pascal_case}}OpMaker<paddle::framework::OpDesc>,
ops::{{backward_name | to_pascal_case}}OpMaker<paddle::imperative::OpBase>,
{% else %}
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
{% endif %}
{% if api is supports_inplace %}{# inplace#}
{% if op is supports_inplace %}{# inplace#}
ops::{{name | to_pascal_case}}InplaceInferer,
{% endif %}
{% if api is supports_no_need_buffer %}{# no_need_buffer #}
{% if op is supports_no_need_buffer %}{# no_need_buffer #}
ops::{{name | to_pascal_case}}NoNeedBufferVarInferer,
{% endif %}
ops::{{name | to_pascal_case}}InferShapeFunctor);
{% endmacro %}
{% macro register_op_version(api) %}
{% if "version" in api %}
{% set name = api["op_name"] %}
{% macro register_op_version(op) %}
{% if "version" in op %}
{% set name = op["op_name"] %}
REGISTER_OP_VERSION({{name}})
{% for checkpoint in api["version"]%}
{% for checkpoint in op["version"]%}
.AddCheckpoint(
R"ROC({{checkpoint["checkpoint"]}})ROC",
paddle::framework::compatible::OpVersionDesc()
......@@ -354,14 +354,14 @@ REGISTER_OP_VERSION({{name}})
{# --------------------------------------- backward op maker ---------------------------------------------- #}
{% macro backward_op_maker(api, forward_api) %}
{% set name = api["op_name"] %}
{% set forward_input_names = api["forward"]["inputs"] | map(attribute="name") | list %}
{% set forward_output_names = api["forward"]["outputs"] | map(attribute="name") | list %}
{% set forward_attr_names = api["forward"]["attrs"] | map(attribute="name") | list %}
{% set forward_input_orig_names = forward_api["inputs"] | map(attribute="name") | list %}
{% set forward_output_orig_names = forward_api["outputs"] | map(attribute="name") | list %}
{% set forward_attr_orig_names = forward_api["attrs"] | map(attribute="name") | list %}
{% macro backward_op_maker(op, forward_op ) %}
{% set name = op["op_name"] %}
{% set forward_input_names = op["forward"]["inputs"] | map(attribute="name") | list %}
{% set forward_output_names = op["forward"]["outputs"] | map(attribute="name") | list %}
{% set forward_attr_names = op["forward"]["attrs"] | map(attribute="name") | list %}
{% set forward_input_orig_names = forward_op["inputs"] | map(attribute="name") | list %}
{% set forward_output_orig_names = forward_op["outputs"] | map(attribute="name") | list %}
{% set forward_attr_orig_names = forward_op["attrs"] | map(attribute="name") | list %}
template <typename T>
class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> {
public:
......@@ -371,7 +371,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("{{name}}");
{% for input in api["inputs"] %}
{% for input in op["inputs"] %}
grad_op->SetInput({{input["name"] | to_opmaker_name}}, this->{{extract_input_from_forward(
input["name"],
forward_input_names,
......@@ -380,7 +380,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
forward_output_orig_names)}});
{% endfor %}
{% for output in api["outputs"] %}
{% for output in op["outputs"] %}
grad_op->SetOutput({{output["name"] | to_opmaker_name}}, this->{{extract_output_from_forward(
output["name"],
forward_input_names,
......@@ -390,7 +390,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
{% endfor %}
grad_op->SetAttrMap(this->Attrs());
{% for attr in api["attrs"] %}
{% for attr in op["attrs"] %}
{% set attr_name = attr["name"] %}
{% if attr_name in forward_attr_names %}
{% if attr["typename"] == "IntArray" %}
......
......@@ -5,20 +5,20 @@
namespace phi {
{% for api in apis %}
{% if api is base_api %}
{{sparse_op_name_map(api)}}
{% for op in ops %}
{% if op is base_op %}
{{sparse_op_name_map(op)}}
{% endif %}
{% endfor %}
{% for api in backward_apis %}
{% if api is base_api %}
{{sparse_op_name_map(api)}}
{% for op in backward_ops %}
{% if op is base_op %}
{{sparse_op_name_map(op)}}
{% endif %}
{% endfor %}
} // namespace phi
{% for api in apis + backward_apis %}
{% if api is base_api %}
{{register_name_map(api)}}
{% for op in ops + backward_ops %}
{% if op is base_op %}
{{register_name_map(op)}}
{% endif %}
{% endfor %}
......@@ -19,31 +19,31 @@ namespace operators {
using paddle::framework::GradVarName;
{% for api in apis %}
{% if api is base_api %}
{% for op in ops %}
{% if op is base_op %}
{{op_maker(api)}}
{{op_maker(op)}}
{{operator(api)}}
{{operator(op)}}
{% endif %}
{% endfor %}
{% for api in backward_apis %}
{% if api is base_api %}
{% for op in backward_ops %}
{% if op is base_op %}
{{backward_op_maker(api, api_dict[api["forward"]["name"]])}}
{{backward_op_maker(op, op_dict[op["forward"]["name"]])}}
{{operator(api)}}
{{operator(op)}}
{% else %}
{{backward_op_reused_maker(api, api_dict[api["forward"]["name"]], api["invoke"])}}
{{backward_op_reused_maker(op, op_dict[op["forward"]["name"]], op["invoke"])}}
{% endif %}
{% endfor %}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
{% for api in apis + backward_apis %}
{% if api is base_api %}
{{register_op_with_components(api)}}
{% for op in ops + backward_ops %}
{% if op is base_op %}
{{register_op_with_components(op)}}
{% endif %}
{% endfor %}
......@@ -41,20 +41,20 @@ def is_initializer_list(s):
return s == "{}"
def is_base_api(api):
    """Return True when this parsed api entry is a "base" api, i.e. it
    declares both a kernel registration and an infer_meta entry (entries
    without these are handled via other code paths, e.g. `invoke`)."""
    required_keys = ("kernel", "infer_meta")
    return all(key in api for key in required_keys)
def is_base_op(op):
    """Return True when this parsed op entry is a "base" op: it carries both
    a kernel registration and an infer_meta entry."""
    if "kernel" not in op:
        return False
    return "infer_meta" in op
def supports_selected_rows_kernel(api):
    """Return True when a base api registers exactly two kernel functions
    (presumably the dense and SelectedRows variants — TODO confirm against
    the kernel yaml schema).  Non-base apis never qualify."""
    if not is_base_api(api):
        return False
    return len(api["kernel"]["func"]) == 2
def supports_selected_rows_kernel(op):
    """Return True when a base op registers exactly two kernel functions
    (presumably the dense and SelectedRows variants — TODO confirm against
    the kernel yaml schema).  Non-base ops never qualify."""
    is_candidate = is_base_op(op)
    return is_candidate and len(op["kernel"]["func"]) == 2
def supports_inplace(api):
    """Return True when the api declares an inplace mapping (its 'inplace'
    entry is present and not None)."""
    inplace_spec = api['inplace']
    return inplace_spec is not None
def supports_inplace(op):
    """Return True when the op declares an inplace mapping (its 'inplace'
    entry is present and not None)."""
    return not (op['inplace'] is None)
def supports_no_need_buffer(api):
for input in api["inputs"]:
def supports_no_need_buffer(op):
    """Return True when at least one input of the op is flagged
    no_need_buffer (i.e. its buffer can be dropped for the grad op)."""
    # any() short-circuits exactly like the original early-return loop;
    # an op with no inputs yields False.
    return any(entry["no_need_buffer"] for entry in op["inputs"])
......@@ -15,8 +15,6 @@ add_subdirectory(backends)
add_subdirectory(kernels)
# phi infermeta
add_subdirectory(infermeta)
# phi operator definitions
add_subdirectory(ops)
# phi tools
add_subdirectory(tools)
# phi tests
......@@ -36,7 +34,6 @@ set(PHI_DEPS
arg_map_context
infermeta
lod_utils
op_compat_infos
sparse_csr_tensor
sparse_coo_tensor
string_tensor
......
......@@ -94,204 +94,10 @@ set(wrapped_infermeta_header_file
# Generated source for the wrapped infermeta functions.
set(wrapped_infermeta_source_file
${CMAKE_SOURCE_DIR}/paddle/phi/infermeta/generated.cc)
# op extra info file
# Inputs/outputs of the ops-extra-info code generation step:
# - the generator script,
# - the op compatibility yaml it reads,
# - the C++ file it writes (NOTE(review): written into the source tree,
#   not the binary dir — generated files usually belong in the binary dir).
set(ops_extra_info_gen_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/ops_extra_info_gen.py)
set(op_compat_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml)
set(ops_extra_info_file
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/ops_extra_info.cc)
# A Python interpreter is required for the codegen scripts below.
# NOTE(review): find_package(PythonInterp) is deprecated; newer CMake prefers
# find_package(Python3 COMPONENTS Interpreter).
if(NOT PYTHONINTERP_FOUND)
find_package(PythonInterp REQUIRED)
endif()
# install extra dependencies
# Install the Python packages the generators need (pyyaml, jinja2,
# typing-extensions) at configure time.  Interpreters older than 3.6.2 need
# pinned versions because newer jinja2 releases drop support for them.
# NOTE(review): the pip exit status is not checked here; a failed install
# only surfaces later when the generator scripts are executed.
if(${PYTHON_VERSION_STRING} VERSION_LESS "3.6.2")
execute_process(COMMAND ${PYTHON_EXECUTABLE} -m pip install -U pyyaml
typing-extensions>=4.1.1 jinja2==2.11.3)
else()
execute_process(COMMAND ${PYTHON_EXECUTABLE} -m pip install -U pyyaml jinja2
typing-extensions)
endif()
# parse apis
# Destination directory for the normalized ("parsed") yaml files, and the
# four files the codegen step will produce from them: the fluid operator
# definitions (dense + sparse) and the phi argument-mapping sources.
set(parsed_api_dir ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/parsed_apis)
set(generated_op_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op.cc)
set(generated_sparse_ops_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_sparse_op.cc)
set(generated_argument_mapping_path
${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sig.cc)
set(generated_sparse_argument_mapping_path
${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sparse_sig.cc)
# The *_yaml_file variables below are expected to be defined earlier in this
# file (outside this excerpt) — TODO confirm.
message(
"parse api yamls:
- ${api_yaml_file}
- ${legacy_api_yaml_file}
- ${bw_api_yaml_file}
- ${legacy_bw_api_yaml_file}")
# Run parse_api.py over every op yaml (forward/backward, legacy, sparse) to
# produce the normalized *.parsed.yaml files consumed by the generators.
# RESULTS_VARIABLE collects one exit status per COMMAND so we can fail fast
# if any parser invocation failed.
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml
COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_api_dir}
COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path ./ops.yaml
--output_path ./parsed_apis/ops.parsed.yaml
COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./legacy_ops.yaml --output_path ./parsed_apis/legacy_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path ./backward.yaml
--output_path ./parsed_apis/backward_ops.parsed.yaml --backward
COMMAND
${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./legacy_backward.yaml --output_path
./parsed_apis/legacy_backward_ops.parsed.yaml --backward
COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./sparse_ops.yaml --output_path ./parsed_apis/sparse_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./sparse_backward.yaml --output_path
./parsed_apis/sparse_backward.parsed.yaml --backward RESULTS_VARIABLE
_results)
# Fix: the original `foreach(_result in ${_results})` treats the literal
# word "in" as a list element ("in" is not a keyword of the plain foreach
# signature); use the explicit IN LISTS form instead.
# NOTE(review): a non-numeric error string in a result (e.g. "No such file")
# evaluates false in if(); only non-zero numeric exit codes trigger the fatal.
foreach(_result IN LISTS _results)
if(${_result})
message(FATAL_ERROR "api yaml parsing failed, exiting.")
endif()
endforeach()
# validation of api yamls
# Cross-validate forward and backward parsed yamls (dense and sparse) so that
# every backward op matches a declared forward op.
message("validate api yaml:
- ${parsed_api_dir}/ops.parsed.yaml
- ${parsed_api_dir}/backward_ops.parsed.yaml")
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml
COMMAND
${PYTHON_EXECUTABLE} generator/cross_validate.py --forward_yaml_paths
./parsed_apis/ops.parsed.yaml ./parsed_apis/legacy_ops.parsed.yaml
--backward_yaml_paths ./parsed_apis/backward_ops.parsed.yaml
./parsed_apis/legacy_backward_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} generator/cross_validate.py --forward_yaml_paths
./parsed_apis/sparse_ops.parsed.yaml --backward_yaml_paths
./parsed_apis/sparse_backward.parsed.yaml
RESULTS_VARIABLE _results)
# Fix 1: RESULT_VARIABLE only captures the LAST command's exit status, but the
# loop below clearly intends to inspect the status of every COMMAND — use
# RESULTS_VARIABLE (one status per command), matching the parse step above.
# Fix 2: use the IN LISTS foreach signature ("in" is not a keyword in the
# plain form and was iterated over as a list element).
foreach(_result IN LISTS _results)
if(${_result})
message(FATAL_ERROR "ops validation failed, exiting.")
endif()
endforeach()
# code generation for op, op makers, and argument mapping functions
# Each generator writes to "<target>.tmp"; the copy-if-different step that
# follows publishes the .tmp file only when its content changed, avoiding
# needless recompiles.  (Typo fixed: "auto-geneated" -> "auto-generated".)
message(
"create or remove auto-generated operators: ${generated_op_path}.tmp
create or remove auto-generated argument mappings: ${generated_argument_mapping_path}.tmp"
)
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml
COMMAND
${PYTHON_EXECUTABLE} generator/generate_op.py --ops_yaml_path
./parsed_apis/ops.parsed.yaml --backward_yaml_path
./parsed_apis/backward_ops.parsed.yaml --op_version_yaml_path
op_version.yaml --op_compat_yaml_path op_compat.yaml --output_op_path
"${generated_op_path}.tmp" --output_arg_map_path
"${generated_argument_mapping_path}.tmp"
COMMAND
${PYTHON_EXECUTABLE} generator/generate_sparse_op.py --ops_yaml_path
./parsed_apis/sparse_ops.parsed.yaml --backward_ops_yaml_path
./parsed_apis/sparse_backward.parsed.yaml --output_op_path
"${generated_sparse_ops_path}.tmp" --output_arg_map_path
"${generated_sparse_argument_mapping_path}.tmp"
RESULTS_VARIABLE _results)
# Fix: RESULT_VARIABLE only reports the LAST command; use RESULTS_VARIABLE so
# the loop checks both generator invocations, and use the IN LISTS foreach
# signature ("in" is not a keyword in the plain form).
foreach(_result IN LISTS _results)
if(${_result})
message(FATAL_ERROR "operator codegen failed, exiting.")
endif()
endforeach()
# The original code repeated the same 3-branch copy/remove logic four times,
# once per generated file.  Factor it into a helper so the policy lives in
# one place.  Behavior and all message() texts are preserved exactly:
#  - both .tmp and destination exist  -> copy only if content differs
#  - only .tmp exists                 -> plain copy (first generation)
#  - .tmp missing                     -> remove any stale destination
function(copy_generated_file_if_changed tmp_file dst_file)
  if(EXISTS "${tmp_file}" AND EXISTS "${dst_file}")
    execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
                            "${tmp_file}" "${dst_file}")
    message("copy if different ${tmp_file} ${dst_file}")
  elseif(EXISTS "${tmp_file}")
    execute_process(COMMAND ${CMAKE_COMMAND} -E copy "${tmp_file}"
                            "${dst_file}")
    message("copy ${tmp_file} ${dst_file}")
  else()
    execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f "${dst_file}")
    message("remove ${dst_file}")
  endif()
endfunction()

# Publish each freshly generated source from its .tmp staging path.
copy_generated_file_if_changed("${generated_op_path}.tmp"
                               "${generated_op_path}")
copy_generated_file_if_changed("${generated_sparse_ops_path}.tmp"
                               "${generated_sparse_ops_path}")
copy_generated_file_if_changed("${generated_argument_mapping_path}.tmp"
                               "${generated_argument_mapping_path}")
copy_generated_file_if_changed("${generated_sparse_argument_mapping_path}.tmp"
                               "${generated_sparse_argument_mapping_path}")
# generate ops extra info
# Run the extra-info generator at configure time: it reads op_compat.yaml and
# rewrites ops_extra_info.cc in the fluid operators directory.
# NOTE(review): the exit status of this execute_process is not checked.
execute_process(
COMMAND ${PYTHON_EXECUTABLE} ${ops_extra_info_gen_file} --op_compat_yaml_path
${op_compat_yaml_file} --ops_extra_info_path ${ops_extra_info_file})
message("generate ${ops_extra_info_file}")
# generate forward api
add_custom_command(
OUTPUT ${api_header_file} ${api_source_file}
......
# Assemble phi's aggregated argument-mapping header (signatures.h).  The
# content is first written to a .tmp file and only copied over the final
# header when it actually changed, so dependents are not rebuilt needlessly.
# The variable is CACHE INTERNAL because append_op_util_declare() (defined
# elsewhere in this file) appends to it from other directory scopes.
set(op_utils_header
${PADDLE_BINARY_DIR}/paddle/phi/ops/compat/signatures.h.tmp
CACHE INTERNAL "op_args_fns.cc file")
set(op_utils_header_final
${PADDLE_BINARY_DIR}/paddle/phi/ops/compat/signatures.h)
# Start the header with a do-not-edit banner and the base op_utils include.
file(
WRITE ${op_utils_header}
"// Generated by the paddle/phi/ops/compat/CMakeLists.txt. DO NOT EDIT!\n\n")
file(APPEND ${op_utils_header}
"#include \"paddle/phi/core/compat/op_utils.h\"\n\n")
# Automatically generate the registration code of all arg map functions
# and compile the corresponding target to avoid frequent code conflicts
# when writing to same file
register_op_utils(op_compat_infos DEPS op_utils)
# Only touch the final header when the .tmp content differs.
copy_if_different(${op_utils_header} ${op_utils_header_final})
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册