未验证 提交 4652bee4 编写于 作者: Z zyfncg 提交者: GitHub

Split generated_op.cc into 4 src files [generated_op(1-4).cc] (#50985)

* split generated_op.cc into 4 src files

* fix bug

* fix compile on windows
上级 bb5dd203
......@@ -81,7 +81,7 @@ paddle/fluid/pybind/eager_op_function.cc
tools/nvcc_lazy
# these files (directories) are generated before build system generation
paddle/fluid/operators/generated_op.cc
paddle/fluid/operators/generated_op*.cc
paddle/fluid/operators/generated_sparse_op.cc
paddle/fluid/operators/generated_static_op.cc
paddle/phi/ops/compat/generated_*.cc
......
......@@ -97,9 +97,10 @@ endif()
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_utils backward_infermeta sparse_backward_infermeta static_prim_api)
register_operators(EXCLUDES py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op
register_operators(EXCLUDES py_func_op warpctc_op dgc_op generated_op1 generated_op2 generated_op3 generated_op4 load_combine_op lstm_op run_program_op eye_op quantize_linear_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
op_library(generated_op UNITY SRCS generated_op1.cc generated_op2.cc generated_op3.cc generated_op4.cc DEPS ${OP_HEADER_DEPS})
op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc run_program_op_npu.cc DEPS executor_cache ${OP_HEADER_DEPS})
target_link_libraries(run_program_op cuda_graph_with_memory_pool)
op_library(quantize_linear_op DEPS phi)
......@@ -200,7 +201,7 @@ elseif(WITH_ROCM)
else()
cc_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc DEPS tensor device_context eigen3)
endif()
cc_test(share_buffer_op_cpp_test SRCS share_buffer_op_test.cc DEPS lod_tensor device_context generated_op)
cc_test(share_buffer_op_cpp_test SRCS share_buffer_op_test.cc DEPS lod_tensor device_context generated_static_op)
cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS})
if (WITH_PYTHON)
......
......@@ -30,8 +30,14 @@ endif()
# parse ops
set(parsed_op_dir
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops)
set(generated_op_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op.cc)
set(generated_op_path_1
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op1.cc)
set(generated_op_path_2
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op2.cc)
set(generated_op_path_3
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op3.cc)
set(generated_op_path_4
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op4.cc)
set(generated_static_op_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_static_op.cc)
set(generated_sparse_ops_path
......@@ -118,7 +124,7 @@ endif()
# code generation for op, op makers, and argument mapping functions
message(
"create or remove auto-geneated operators: ${generated_op_path}.tmp
"create or remove auto-geneated operators: generated_op(1-4).cc.tmp
create or remove auto-geneated argument mappings: ${generated_argument_mapping_path}.tmp"
)
execute_process(
......@@ -129,8 +135,9 @@ execute_process(
./parsed_ops/backward_ops.parsed.yaml --op_version_yaml_path
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_version.yaml
--op_compat_yaml_path ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml
--output_op_path "${generated_op_path}.tmp" --output_arg_map_path
"${generated_argument_mapping_path}.tmp"
--output_op_path "${generated_op_path_1}.tmp" "${generated_op_path_2}.tmp"
"${generated_op_path_3}.tmp" "${generated_op_path_4}.tmp"
--output_arg_map_path "${generated_argument_mapping_path}.tmp"
RESULT_VARIABLE _result)
if(${_result})
message(FATAL_ERROR "operator codegen failed, exiting.")
......@@ -165,7 +172,10 @@ if(${_result})
endif()
set(generated_static_files
"${generated_op_path}"
"${generated_op_path_1}"
"${generated_op_path_2}"
"${generated_op_path_3}"
"${generated_op_path_4}"
"${generated_static_op_path}"
"${generated_sparse_ops_path}"
"${generated_argument_mapping_path}"
......@@ -192,6 +202,16 @@ foreach(generated_static_file ${generated_static_files})
endif()
endforeach()
# Note(zyfncg): The generated file generated_op.cc has been deleted,
# so we need to clear the generated_op.cc and generated_op.cc.tmp cached in develop environment.
set(old_generated_op_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op.cc)
if(EXISTS "${old_generated_op_path}" OR EXISTS "${old_generated_op_path}.tmp")
execute_process(
COMMAND ${CMAKE_COMMAND} -E remove -f "${old_generated_op_path}"
"${old_generated_op_path}.tmp")
endif()
# op extra info file
set(ops_extra_info_gen_file
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator/ops_extra_info_gen.py)
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import argparse
import math
import os
from pathlib import Path
......@@ -478,6 +479,29 @@ def parse_drop_empty_grad(op_fluid_list: list, bw_op_dict: dict):
] = False
def split_ops_list(ops, backward_op_dict, split_num):
new_ops_list = []
new_bw_ops_list = []
list_size = math.ceil(len(ops) / split_num)
tmp_ops_list = []
tmp_bw_ops_list = []
for idx, op in enumerate(ops):
tmp_ops_list.append(op)
current_op = op
while (
'backward' in current_op
and current_op['backward'] in backward_op_dict
):
tmp_bw_ops_list.append(backward_op_dict[current_op['backward']])
current_op = backward_op_dict[current_op['backward']]
if (idx + 1) % list_size == 0 or idx == len(ops) - 1:
new_ops_list.append(tmp_ops_list)
new_bw_ops_list.append(tmp_bw_ops_list)
tmp_ops_list = []
tmp_bw_ops_list = []
return new_ops_list, new_bw_ops_list
def main(
ops_yaml_path,
backward_yaml_path,
......@@ -548,13 +572,23 @@ def main(
os.remove(output_arg_map_path)
return
op_template = env.get_template('op.c.j2')
with open(output_op_path, "wt") as f:
msg = op_template.render(
ops=ops,
backward_ops=backward_ops,
op_dict=op_dict,
)
f.write(msg)
backward_fluid_op_dict = {}
for bw_op in backward_ops:
backward_fluid_op_dict[bw_op['op_name']] = bw_op
output_op_files_num = len(output_op_path)
new_ops_list, new_bw_ops_list = split_ops_list(
ops, backward_fluid_op_dict, output_op_files_num
)
for idx, output_op_file in enumerate(output_op_path):
with open(output_op_file, "wt") as f:
msg = op_template.render(
ops=new_ops_list[idx],
backward_ops=new_bw_ops_list[idx],
op_dict=op_dict,
)
f.write(msg)
ks_template = env.get_template('ks.c.j2')
with open(output_arg_map_path, 'wt') as f:
msg = ks_template.render(ops=ops, backward_ops=backward_ops)
......@@ -578,7 +612,10 @@ if __name__ == "__main__":
'--op_version_yaml_path', type=str, help="ops version yaml file."
)
parser.add_argument(
"--output_op_path", type=str, help="path to save generated operators."
"--output_op_path",
type=str,
nargs='+',
help="path to save generated operators.",
)
parser.add_argument(
"--output_arg_map_path",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册