未验证 提交 4652bee4 编写于 作者: Z zyfncg 提交者: GitHub

Split generated_op.cc into 4 src files [generated_op(1-4).cc] (#50985)

* split generated_op.cc into 4 src files

* fix bug

* fix compile on windows
上级 bb5dd203
...@@ -81,7 +81,7 @@ paddle/fluid/pybind/eager_op_function.cc ...@@ -81,7 +81,7 @@ paddle/fluid/pybind/eager_op_function.cc
tools/nvcc_lazy tools/nvcc_lazy
# these files (directories) are generated before build system generation # these files (directories) are generated before build system generation
paddle/fluid/operators/generated_op.cc paddle/fluid/operators/generated_op*.cc
paddle/fluid/operators/generated_sparse_op.cc paddle/fluid/operators/generated_sparse_op.cc
paddle/fluid/operators/generated_static_op.cc paddle/fluid/operators/generated_static_op.cc
paddle/phi/ops/compat/generated_*.cc paddle/phi/ops/compat/generated_*.cc
......
...@@ -97,9 +97,10 @@ endif() ...@@ -97,9 +97,10 @@ endif()
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_utils backward_infermeta sparse_backward_infermeta static_prim_api) set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_utils backward_infermeta sparse_backward_infermeta static_prim_api)
register_operators(EXCLUDES py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op register_operators(EXCLUDES py_func_op warpctc_op dgc_op generated_op1 generated_op2 generated_op3 generated_op4 load_combine_op lstm_op run_program_op eye_op quantize_linear_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
op_library(generated_op UNITY SRCS generated_op1.cc generated_op2.cc generated_op3.cc generated_op4.cc DEPS ${OP_HEADER_DEPS})
op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc run_program_op_npu.cc DEPS executor_cache ${OP_HEADER_DEPS}) op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc run_program_op_npu.cc DEPS executor_cache ${OP_HEADER_DEPS})
target_link_libraries(run_program_op cuda_graph_with_memory_pool) target_link_libraries(run_program_op cuda_graph_with_memory_pool)
op_library(quantize_linear_op DEPS phi) op_library(quantize_linear_op DEPS phi)
...@@ -200,7 +201,7 @@ elseif(WITH_ROCM) ...@@ -200,7 +201,7 @@ elseif(WITH_ROCM)
else() else()
cc_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc DEPS tensor device_context eigen3) cc_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc DEPS tensor device_context eigen3)
endif() endif()
cc_test(share_buffer_op_cpp_test SRCS share_buffer_op_test.cc DEPS lod_tensor device_context generated_op) cc_test(share_buffer_op_cpp_test SRCS share_buffer_op_test.cc DEPS lod_tensor device_context generated_static_op)
cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS}) cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS})
if (WITH_PYTHON) if (WITH_PYTHON)
......
...@@ -30,8 +30,14 @@ endif() ...@@ -30,8 +30,14 @@ endif()
# parse ops # parse ops
set(parsed_op_dir set(parsed_op_dir
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops) ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops)
set(generated_op_path set(generated_op_path_1
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op.cc) ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op1.cc)
set(generated_op_path_2
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op2.cc)
set(generated_op_path_3
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op3.cc)
set(generated_op_path_4
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op4.cc)
set(generated_static_op_path set(generated_static_op_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_static_op.cc) ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_static_op.cc)
set(generated_sparse_ops_path set(generated_sparse_ops_path
...@@ -118,7 +124,7 @@ endif() ...@@ -118,7 +124,7 @@ endif()
# code generation for op, op makers, and argument mapping functions # code generation for op, op makers, and argument mapping functions
message( message(
"create or remove auto-geneated operators: ${generated_op_path}.tmp "create or remove auto-geneated operators: generated_op(1-4).cc.tmp
create or remove auto-geneated argument mappings: ${generated_argument_mapping_path}.tmp" create or remove auto-geneated argument mappings: ${generated_argument_mapping_path}.tmp"
) )
execute_process( execute_process(
...@@ -129,8 +135,9 @@ execute_process( ...@@ -129,8 +135,9 @@ execute_process(
./parsed_ops/backward_ops.parsed.yaml --op_version_yaml_path ./parsed_ops/backward_ops.parsed.yaml --op_version_yaml_path
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_version.yaml ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_version.yaml
--op_compat_yaml_path ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml --op_compat_yaml_path ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml
--output_op_path "${generated_op_path}.tmp" --output_arg_map_path --output_op_path "${generated_op_path_1}.tmp" "${generated_op_path_2}.tmp"
"${generated_argument_mapping_path}.tmp" "${generated_op_path_3}.tmp" "${generated_op_path_4}.tmp"
--output_arg_map_path "${generated_argument_mapping_path}.tmp"
RESULT_VARIABLE _result) RESULT_VARIABLE _result)
if(${_result}) if(${_result})
message(FATAL_ERROR "operator codegen failed, exiting.") message(FATAL_ERROR "operator codegen failed, exiting.")
...@@ -165,7 +172,10 @@ if(${_result}) ...@@ -165,7 +172,10 @@ if(${_result})
endif() endif()
set(generated_static_files set(generated_static_files
"${generated_op_path}" "${generated_op_path_1}"
"${generated_op_path_2}"
"${generated_op_path_3}"
"${generated_op_path_4}"
"${generated_static_op_path}" "${generated_static_op_path}"
"${generated_sparse_ops_path}" "${generated_sparse_ops_path}"
"${generated_argument_mapping_path}" "${generated_argument_mapping_path}"
...@@ -192,6 +202,16 @@ foreach(generated_static_file ${generated_static_files}) ...@@ -192,6 +202,16 @@ foreach(generated_static_file ${generated_static_files})
endif() endif()
endforeach() endforeach()
# Note(zyfncg): The generated file generated_op.cc has been deleted,
# so we need to clear the generated_op.cc and generated_op.cc.tmp cached in develop environment.
set(old_generated_op_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op.cc)
if(EXISTS "${old_generated_op_path}" OR EXISTS "${old_generated_op_path}.tmp")
execute_process(
COMMAND ${CMAKE_COMMAND} -E remove -f "${old_generated_op_path}"
"${old_generated_op_path}.tmp")
endif()
# op extra info file # op extra info file
set(ops_extra_info_gen_file set(ops_extra_info_gen_file
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator/ops_extra_info_gen.py) ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator/ops_extra_info_gen.py)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import math
import os import os
from pathlib import Path from pathlib import Path
...@@ -478,6 +479,29 @@ def parse_drop_empty_grad(op_fluid_list: list, bw_op_dict: dict): ...@@ -478,6 +479,29 @@ def parse_drop_empty_grad(op_fluid_list: list, bw_op_dict: dict):
] = False ] = False
def split_ops_list(ops, backward_op_dict, split_num):
new_ops_list = []
new_bw_ops_list = []
list_size = math.ceil(len(ops) / split_num)
tmp_ops_list = []
tmp_bw_ops_list = []
for idx, op in enumerate(ops):
tmp_ops_list.append(op)
current_op = op
while (
'backward' in current_op
and current_op['backward'] in backward_op_dict
):
tmp_bw_ops_list.append(backward_op_dict[current_op['backward']])
current_op = backward_op_dict[current_op['backward']]
if (idx + 1) % list_size == 0 or idx == len(ops) - 1:
new_ops_list.append(tmp_ops_list)
new_bw_ops_list.append(tmp_bw_ops_list)
tmp_ops_list = []
tmp_bw_ops_list = []
return new_ops_list, new_bw_ops_list
def main( def main(
ops_yaml_path, ops_yaml_path,
backward_yaml_path, backward_yaml_path,
...@@ -548,13 +572,23 @@ def main( ...@@ -548,13 +572,23 @@ def main(
os.remove(output_arg_map_path) os.remove(output_arg_map_path)
return return
op_template = env.get_template('op.c.j2') op_template = env.get_template('op.c.j2')
with open(output_op_path, "wt") as f:
msg = op_template.render( backward_fluid_op_dict = {}
ops=ops, for bw_op in backward_ops:
backward_ops=backward_ops, backward_fluid_op_dict[bw_op['op_name']] = bw_op
op_dict=op_dict, output_op_files_num = len(output_op_path)
) new_ops_list, new_bw_ops_list = split_ops_list(
f.write(msg) ops, backward_fluid_op_dict, output_op_files_num
)
for idx, output_op_file in enumerate(output_op_path):
with open(output_op_file, "wt") as f:
msg = op_template.render(
ops=new_ops_list[idx],
backward_ops=new_bw_ops_list[idx],
op_dict=op_dict,
)
f.write(msg)
ks_template = env.get_template('ks.c.j2') ks_template = env.get_template('ks.c.j2')
with open(output_arg_map_path, 'wt') as f: with open(output_arg_map_path, 'wt') as f:
msg = ks_template.render(ops=ops, backward_ops=backward_ops) msg = ks_template.render(ops=ops, backward_ops=backward_ops)
...@@ -578,7 +612,10 @@ if __name__ == "__main__": ...@@ -578,7 +612,10 @@ if __name__ == "__main__":
'--op_version_yaml_path', type=str, help="ops version yaml file." '--op_version_yaml_path', type=str, help="ops version yaml file."
) )
parser.add_argument( parser.add_argument(
"--output_op_path", type=str, help="path to save generated operators." "--output_op_path",
type=str,
nargs='+',
help="path to save generated operators.",
) )
parser.add_argument( parser.add_argument(
"--output_arg_map_path", "--output_arg_map_path",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册