diff --git a/paddle/cinn/hlir/dialect/.gitignore b/paddle/cinn/hlir/dialect/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..a21ba08d95acf381ab945cefc294161c6d4b3a8d
--- /dev/null
+++ b/paddle/cinn/hlir/dialect/.gitignore
@@ -0,0 +1 @@
+generated/**
diff --git a/paddle/cinn/hlir/dialect/CMakeLists.txt b/paddle/cinn/hlir/dialect/CMakeLists.txt
index d7c6d787a7fb0437c5727757c2f1a150851d7ea3..68798ede2e5527e31fcf6fb0e9e95cee937df3ed 100755
--- a/paddle/cinn/hlir/dialect/CMakeLists.txt
+++ b/paddle/cinn/hlir/dialect/CMakeLists.txt
@@ -5,16 +5,23 @@ if(NOT CINN_ONLY)
   set(CINN_DIALECT_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/cinn/hlir/dialect")
 
   # Generate cinn_dialect files defining op using op_gen_file
+  set(cinn_op_gen_parsed_yaml_file
+      ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parse_op.py)
+
   set(cinn_op_gen_file
       ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/op_generator/op_gen.py)
   set(cinn_op_compat_yaml_file
       ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml)
 
-  set(cinn_op_forward_yaml_file1
-      ${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/cinn_ops.parsed.yaml)
+  set(cinn_op_yaml_file
+      ${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/cinn_ops.yaml)
+
+  set(parsed_op_dir ${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/generated)
+
+  set(cinn_op_parsed_yaml_file ${parsed_op_dir}/cinn_ops.parsed.yaml)
 
-  set(cinn_op_yaml_files ${cinn_op_forward_yaml_file1})
+  set(cinn_op_parsed_yaml_files ${cinn_op_parsed_yaml_file})
 
   set(cinn_op_namespace cinn,dialect)
   set(cinn_dialect_name cinn)
@@ -23,19 +30,26 @@ if(NOT CINN_ONLY)
   set(cinn_op_header_file_tmp ${cinn_op_header_file}.tmp)
   set(cinn_op_source_file_tmp ${cinn_op_source_file}.tmp)
 
+  add_custom_command(
+    OUTPUT ${cinn_op_parsed_yaml_file}
+    COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_op_dir}
+    COMMAND ${PYTHON_EXECUTABLE} ${cinn_op_gen_parsed_yaml_file} --op_yaml_path
+            ${cinn_op_yaml_file} --output_path ${cinn_op_parsed_yaml_file}
+    VERBATIM)
+
   add_custom_command(
     OUTPUT ${cinn_op_header_file} ${cinn_op_source_file}
     COMMAND ${PYTHON_EXECUTABLE} ${cinn_op_gen_file} --op_yaml_files
-            ${cinn_op_yaml_files} --op_compat_yaml_file ${cinn_op_compat_yaml_file}
-            --namespaces ${cinn_op_namespace} --dialect_name ${cinn_dialect_name}
-            --op_def_h_file ${cinn_op_header_file_tmp} --op_def_cc_file
-            ${cinn_op_source_file_tmp}
+            ${cinn_op_parsed_yaml_files} --op_compat_yaml_file
+            ${cinn_op_compat_yaml_file} --namespaces ${cinn_op_namespace}
+            --dialect_name ${cinn_dialect_name} --op_def_h_file
+            ${cinn_op_header_file_tmp} --op_def_cc_file ${cinn_op_source_file_tmp}
     COMMAND ${CMAKE_COMMAND} -E copy_if_different ${cinn_op_header_file_tmp}
             ${cinn_op_header_file}
     COMMAND ${CMAKE_COMMAND} -E copy_if_different ${cinn_op_source_file_tmp}
             ${cinn_op_source_file}
-    DEPENDS ${cinn_op_gen_file} ${cinn_op_forward_yaml_file1}
+    DEPENDS ${cinn_op_gen_file} ${cinn_op_parsed_yaml_file}
            ${cinn_op_compat_yaml_file}
     VERBATIM)
diff --git a/paddle/cinn/hlir/dialect/cinn_ops.parsed.yaml b/paddle/cinn/hlir/dialect/cinn_ops.parsed.yaml
deleted file mode 100644
index 6297be83e23986b871b29217738b90c258a7ac50..0000000000000000000000000000000000000000
--- a/paddle/cinn/hlir/dialect/cinn_ops.parsed.yaml
+++ /dev/null
@@ -1,78 +0,0 @@
-- name: add
-  inputs:
-  - typename: Tensor
-    name: x
-    optional: false
-    no_need_buffer: false
-    data_transform: {}
-  - typename: Tensor
-    name: y
-    optional: false
-    no_need_buffer: false
-    data_transform: {}
-  attrs: []
-  outputs:
-  - {typename: Tensor, name: out, optional: false, intermediate: false}
-  no_need_buffer: null
-  data_transform: null
-  infer_meta:
-    func: ElementwiseInferMeta
-    param: [x, y]
-  kernel:
-    func: [add]
-    param: [x, y]
-    backend: null
-    layout: null
-    data_type: null
-    dispatch: {add: null}
-    force_backend: null
-  inplace: {out: x}
-  view: null
-  backward: add_grad
-- name: add_grad
-  inputs:
-  - typename: Tensor
-    name: x
-    optional: false
-    no_need_buffer: true
-    data_transform: {}
-  - typename: Tensor
-    name: y
-    optional: false
-    no_need_buffer: true
-    data_transform: {}
-  - typename: Tensor
-    name: out_grad
-    optional: false
-    no_need_buffer: false
-    data_transform: {}
-  attrs:
-  - {typename: int, name: axis, default_value: '-1'}
-  outputs:
-  - {typename: Tensor, name: x_grad, optional: false, intermediate: false}
-  - {typename: Tensor, name: y_grad, optional: false, intermediate: false}
-  no_need_buffer: [x, y]
-  data_transform: null
-  infer_meta:
-    func: GeneralBinaryGradInferMeta
-    param: [x, y]
-  kernel:
-    func: [add_grad]
-    param: [x, y, out_grad, axis]
-    backend: null
-    layout: null
-    data_type: null
-    dispatch: {add_grad: null}
-    force_backend: null
-  inplace: {x_grad: out_grad}
-  view: null
-  composite: {func_name: add_grad, func_args: 'x, y, out_grad, axis, x_grad, y_grad'}
-  backward: add_double_grad
-  forward:
-    name: add
-    inputs:
-    - {name: x, typename: Tensor}
-    - {name: y, typename: Tensor}
-    attrs: []
-    outputs:
-    - {name: out, typename: Tensor}
diff --git a/paddle/cinn/hlir/dialect/cinn_ops.yaml b/paddle/cinn/hlir/dialect/cinn_ops.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..096d2c4e652b17f64e7214b246c72952e6bff244
--- /dev/null
+++ b/paddle/cinn/hlir/dialect/cinn_ops.yaml
@@ -0,0 +1,8 @@
+- op : add
+  args : (Tensor x, Tensor y)
+  output : Tensor(out)
+  infer_meta :
+    func : ElementwiseInferMeta
+  kernel :
+    func : add
+  inplace : (x -> out)
diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py
index b3858b6e6730e6555ca0d7697c66c8c759813344..f3cf7982aa9c67690361a70ce20f4c795b6becc9 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py
@@ -1072,21 +1072,24 @@
             # =================================== #
 
             # generate op vjp function str
-            op_vjp_str = ''
-            # TODO(chenzhiyang) add vjp gen code
-            if (
-                op_info.backward_name
-                and op_info.op_phi_name[0]
-                in vjp_interface_implementation_gen_op_list
-            ):
-                op_vjp_str = gen_op_vjp_str(
-                    op_class_name,
-                    op_info.backward_name,
-                    op_name,
-                    op_info_items[op_info.op_phi_name[0]],
-                    op_info_items[op_info.backward_name],
-                )
+            op_vjp_str = ''
+            if dialect_name == "cinn":
+                logging.warning("cinn does not support the Vjp function yet")
+            else:
+                # TODO(chenzhiyang) add vjp gen code
+                if (
+                    op_info.backward_name
+                    and op_info.op_phi_name[0]
+                    in vjp_interface_implementation_gen_op_list
+                ):
+                    op_vjp_str = gen_op_vjp_str(
+                        op_class_name,
+                        op_info.backward_name,
+                        op_name,
+                        op_info_items[op_info.op_phi_name[0]],
+                        op_info_items[op_info.backward_name],
+                    )
 
             ops_name_list.append(op_class_name)
             ops_declare_list.append(op_declare_str)
@@ -1100,7 +1103,7 @@
             ops_defined_list.append(op_infer_meta_str)
             # NOTE(chenxi67)skip if dialect_name==cinn
             if dialect_name == "cinn":
-                logging.warning("cinn is currently not support Vjp function")
+                pass
             else:
                 ops_vjp_defined_list.append(op_vjp_str)
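Note: after this change `cinn_ops.parsed.yaml` is no longer checked in. It is regenerated into `paddle/cinn/hlir/dialect/generated/` (ignored through the new `.gitignore`) whenever `cinn_ops.yaml` changes, and `op_gen.py` then consumes the regenerated file. For orientation, below is a rough, abbreviated sketch of the `add` entry that `parse_op.py` is expected to emit for the new `cinn_ops.yaml`; it is modeled on the previously checked-in parsed file deleted above rather than on actual generator output, and the exact fields and formatting are decided by `parse_op.py`.

```yaml
# Illustrative only -- the real file is produced by parse_op.py at build time
# and lands in paddle/cinn/hlir/dialect/generated/cinn_ops.parsed.yaml.
- name: add
  inputs:
  - {typename: Tensor, name: x, optional: false, no_need_buffer: false}
  - {typename: Tensor, name: y, optional: false, no_need_buffer: false}
  attrs: []
  outputs:
  - {typename: Tensor, name: out, optional: false, intermediate: false}
  infer_meta:
    func: ElementwiseInferMeta
    param: [x, y]
  kernel:
    func: [add]
    param: [x, y]
  inplace: {out: x}
  backward: null  # cinn_ops.yaml declares no backward op, unlike the deleted file
```

Because the new `cinn_ops.yaml` does not declare a backward op, the regenerated entry should end up with `backward: null` instead of `backward: add_grad`, and the old `add_grad` entry is expected to disappear entirely.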