# Header that collects a USE_PASS(...) line for every registered IR pass.
# The content is assembled in a scratch "<header>.tmp" file: the fixed preamble
# is written here, and pass_library() below appends one line per pass.
set(pass_file_final
    ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h)
set(pass_file ${pass_file_final}.tmp)

# file(WRITE) concatenates all content arguments, so the whole preamble is
# emitted with a single call.
file(
  WRITE ${pass_file}
  "// Generated by the paddle/fluid/framework/ir/CMakeLists.txt.  DO NOT EDIT!\n\n"
  "\#pragma once\n"
  "\#include \"paddle/fluid/framework/ir/pass.h\"\n")

# copy_if_different() is a project helper defined outside this file.
# NOTE(review): presumably it refreshes the final header from the scratch file
# only when the content changed - confirm against its definition.
copy_if_different(${pass_file} ${pass_file_final})

# Pass families that are configured from their own subdirectories.
add_subdirectory(fuse_optimizer_ops_pass)
add_subdirectory(memory_optimize_pass)
add_subdirectory(multi_devices_graph_pass)

# fusion_group is only configured for GPU/ROCm builds, and never on macOS or
# Windows.
if((WITH_GPU OR WITH_ROCM) AND NOT (APPLE OR WIN32))
  add_subdirectory(fusion_group)
endif()

# pass_library(TARGET DEST [DEPS dep...] [DIR dir])
#
# Builds the IR pass TARGET with cc_library() and, for selected destinations,
# records it in the generated paddle_inference_pass.h.
#
# Arguments:
#   TARGET  Library name. The source file is <TARGET>.cc, prefixed with DIR
#           when DIR is given.
#   DEST    Destination group. Only "base" and "inference" are recorded: a
#           "USE_PASS(<TARGET>);" line is appended to ${pass_file} and TARGET
#           is appended to the INFER_IR_PASSES cache list. Any other value
#           only builds the library.
#   DEPS    Extra dependencies, added after the default ones
#           (graph_pattern_detector pass fuse_pass_base op_version_registry).
#   DIR     Subdirectory that holds <TARGET>.cc.
#   SRCS    Accepted by the argument parser but not used by this function.
#
# Usage: pass_library(target inference) will append to paddle_inference_pass.h
# Example with a subdirectory: pass_library(fc_mkldnn_pass inference DIR mkldnn)
unset(INFER_IR_PASSES CACHE) # clear the global variable
function(pass_library TARGET DEST)
  set(options "")
  set(oneValueArgs "")
  set(multiValueArgs SRCS DEPS DIR)

  cmake_parse_arguments(pass_library "${options}" "${oneValueArgs}"
                        "${multiValueArgs}" ${ARGN})

  # Only the location of the source file depends on DIR; everything else in
  # the cc_library() call is shared.
  if(pass_library_DIR)
    set(pass_library_source ${pass_library_DIR}/${TARGET}.cc)
  else()
    set(pass_library_source ${TARGET}.cc)
  endif()
  cc_library(
    ${TARGET}
    SRCS ${pass_library_source}
    DEPS graph_pattern_detector pass fuse_pass_base op_version_registry
         ${pass_library_DEPS})

  # add more DEST here, such as train, dist and collect USE_PASS into a file automatically.
  # DEST is quoted so that it is compared as a string: unquoted, a value that
  # happens to name a variable (e.g. "DEPS" when a caller omits DEST) would be
  # dereferenced, and an empty value would make the if() malformed.
  if("${DEST}" STREQUAL "base" OR "${DEST}" STREQUAL "inference")
    if(NOT CMAKE_BUILD_TYPE STREQUAL "Release")
      message(STATUS "add pass ${TARGET} ${DEST}")
    endif()
    file(APPEND ${pass_file} "USE_PASS(${TARGET});\n")
    set(INFER_IR_PASSES
        ${INFER_IR_PASSES} ${TARGET}
        CACHE INTERNAL "")
  endif()
endfunction()

# Graph data structures and helpers.
cc_library(node SRCS node.cc DEPS proto_desc)
cc_library(graph SRCS graph.cc DEPS node pretty_log)
cc_library(graph_helper SRCS graph_helper.cc
           DEPS graph program_utils scale_loss_grad_op_handle)
cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(cost_model SRCS cost_model.cc
           DEPS executor graph profiler proto_desc device_tracer)

# graph_pattern_detector gets gtest as an extra dependency when WITH_TESTING
# is enabled.
set(GRAPH_PATTERN_DETECTOR_DEPS graph graph_helper graph_traits)
if(WITH_TESTING)
  list(APPEND GRAPH_PATTERN_DETECTOR_DEPS gtest)
endif()
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc
           DEPS ${GRAPH_PATTERN_DETECTOR_DEPS})

# Base classes and utilities shared by the individual passes.
cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc
           DEPS graph_pattern_detector op_def_api pass)
cc_library(subgraph_detector SRCS subgraph_detector.cc
           DEPS graph_pattern_detector executor)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS op_compat_sensible_pass)
cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass)

cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc
           DEPS graph graph_helper)

# Passes built on every configuration. Each call compiles <name>.cc through
# pass_library(); with DEST "base" or "inference" it also appends
# "USE_PASS(<name>);" to the generated header and <name> to INFER_IR_PASSES.
# The order of the calls is the order of those generated entries, so keep new
# entries in the intended position.
pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base DEPS string_helper)
pass_library(fc_fuse_pass inference)
pass_library(attention_lstm_fuse_pass inference)
pass_library(vit_attention_fuse_pass inference)
pass_library(fc_lstm_fuse_pass inference)
pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference)
pass_library(seq_concat_fc_fuse_pass inference)
pass_library(multi_batch_merge_pass base)
pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference)
pass_library(seqpool_concat_fuse_pass inference)
pass_library(seqpool_cvm_concat_fuse_pass inference)
pass_library(repeated_fc_relu_fuse_pass inference)
pass_library(squared_mat_sub_fuse_pass inference)
pass_library(is_test_pass base)
pass_library(conv_elementwise_add_act_fuse_pass inference)
pass_library(conv_elementwise_add2_act_fuse_pass inference)
pass_library(conv_elementwise_add_fuse_pass inference)
pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(identity_scale_op_clean_pass base)
pass_library(sync_batch_norm_pass base)
pass_library(runtime_context_cache_pass base)
pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference)
pass_library(delete_quant_dequant_filter_op_pass inference)
pass_library(delete_weight_dequant_linear_op_pass inference)
pass_library(delete_quant_dequant_linear_op_pass inference)
pass_library(delete_dropout_op_pass inference)
pass_library(delete_c_identity_op_pass inference)
pass_library(preln_residual_bias_fuse_pass inference)
pass_library(delete_fill_constant_op_pass inference)
pass_library(constant_folding_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(skip_layernorm_fuse_pass base)
pass_library(multihead_matmul_fuse_pass inference)
pass_library(adaptive_pool2d_convert_global_pass inference)
pass_library(unsqueeze2_eltwise_fuse_pass inference)
pass_library(yolo_box_fuse_pass inference)
pass_library(layer_norm_fuse_pass inference)
pass_library(add_support_int8_pass inference)
pass_library(matmul_scale_fuse_pass inference)
pass_library(gpu_cpu_map_matmul_to_mul_pass inference)
pass_library(dense_fc_to_sparse_pass inference)
pass_library(dense_multihead_matmul_to_sparse_pass inference)
# NOTE(review): pass_library's second positional parameter is DEST, so in the
# call below DEST receives the literal "DEPS" and pass_desc_proto is left as an
# unparsed argument. As a result generate_pass gets no USE_PASS entry and
# pass_desc_proto is not forwarded to cc_library(); the dependency is applied
# only by the explicit target_link_libraries() call. Confirm this is intended
# before "fixing" the argument list, since doing so changes behavior.
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)

# Passes that are only built with TensorRT enabled; every one of them is
# registered for inference. The list order is the registration order.
if(WITH_TENSORRT)
  foreach(
    trt_pass IN
    ITEMS trt_map_matmul_to_mul_pass
          trt_embedding_eltwise_layernorm_fuse_pass
          trt_multihead_matmul_fuse_pass
          trt_skip_layernorm_fuse_pass
          preln_embedding_eltwise_layernorm_fuse_pass
          preln_skip_layernorm_fuse_pass
          set_transformer_input_convert_pass
          remove_padding_recover_padding_pass
          delete_remove_padding_recover_padding_pass)
    pass_library(${trt_pass} inference)
  endforeach()
endif()

# Passes that are only built for GPU or ROCm configurations.
if(WITH_GPU OR WITH_ROCM)
  pass_library(cudnn_placement_pass base DEPS placement_pass_base)
  pass_library(embedding_eltwise_layernorm_fuse_pass inference)
endif()

# MKLDNN passes; their sources live in the mkldnn/ subdirectory. Consecutive
# passes that share the same arguments are registered from a loop, and the
# overall registration order is kept.
if(WITH_MKLDNN)
  pass_library(mkldnn_placement_pass base DEPS placement_pass_base DIR mkldnn)
  pass_library(
    mkldnn_inplace_pass inference
    DEPS mkldnn_placement_pass op_registry elementwise_add_op gelu_op
         activation_op softmax_op softmax
    DIR mkldnn)
  pass_library(depthwise_conv_mkldnn_pass base DIR mkldnn)

  foreach(
    mkldnn_pass IN
    ITEMS conv_affine_channel_mkldnn_fuse_pass
          conv_bias_mkldnn_fuse_pass
          conv_activation_mkldnn_fuse_pass
          conv_elementwise_add_mkldnn_fuse_pass
          int8_scale_calculation_mkldnn_pass
          params_quantization_mkldnn_pass
          fc_elementwise_add_mkldnn_fuse_pass
          scale_matmul_fuse_pass
          cpu_bfloat16_placement_pass
          cpu_bfloat16_pass
          fc_mkldnn_pass
          interpolate_mkldnn_pass
          softplus_activation_mkldnn_fuse_pass
          shuffle_channel_mkldnn_detect_pass
          fc_act_mkldnn_fuse_pass
          elt_act_mkldnn_fuse_pass
          matmul_elementwise_add_mkldnn_fuse_pass
          matmul_activation_mkldnn_fuse_pass)
    pass_library(${mkldnn_pass} inference DIR mkldnn)
  endforeach()

  pass_library(cpu_quantize_placement_pass base DIR mkldnn)

  foreach(
    mkldnn_pass IN
    ITEMS cpu_quantize_pass
          cpu_quantize_squash_pass
          reshape_transpose_matmul_mkldnn_fuse_pass
          matmul_transpose_reshape_mkldnn_fuse_pass
          batch_norm_act_fuse_pass
          multi_gru_fuse_pass
          multi_gru_seq_fuse_pass
          quant_dequant_mkldnn_pass
          compute_propagate_scales_mkldnn_pass)
    pass_library(${mkldnn_pass} inference DIR mkldnn)
  endforeach()
endif()

# IPU passes; their sources live in the ipu/ subdirectory and all of them are
# registered with the "base" destination, in the order listed.
if(WITH_IPU)
  foreach(
    ipu_pass IN
    ITEMS forward_graph_extract_pass
          optimizer_extract_pass
          optimizer_state_align_pass
          ipu_graph_builder_pass
          ipu_runtime_replacer_pass
          inference_process_pass
          inference_postprocess_pass
          popart_canonicalization_pass
          ipu_inplace_pass
          infer_shape_pass
          delete_scale_op_pass
          avg_shard_pass
          inference_dtype_transfer_pass)
    pass_library(${ipu_pass} base DIR ipu)
  endforeach()
endif()

# Fusion passes built directly with cc_library() rather than pass_library(),
# so this file adds no USE_PASS entry for them. All five share the same shape:
# <name>.cc depending on pass and graph_pattern_detector.
foreach(
  fuse_pass IN
  ITEMS fuse_bn_act_pass
        fuse_bn_add_act_pass
        fuse_elewise_add_act_pass
        fuse_gemm_epilogue_pass
        fuse_relu_depthwise_conv_pass)
  cc_library(${fuse_pass} SRCS ${fuse_pass}.cc DEPS pass graph_pattern_detector)
endforeach()

# NOTE(review): PASS_LIBRARY is not assigned anywhere in the visible part of
# this file, so GLOB_PASS_LIB may end up empty - confirm where it is set.
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")

cc_library(pass_builder SRCS pass_builder.cc DEPS pass)
cc_library(pass_test_util SRCS pass_test_util.cc DEPS graph pass)

# Unit tests for the graph infrastructure.
cc_test(node_test SRCS node_test.cc DEPS node)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc
        DEPS graph graph_helper op_registry)
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc
        DEPS graph_to_program_pass)
cc_test(cost_model_test SRCS cost_model_test.cc DEPS cost_model op_registry)
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc
        DEPS graph_pattern_detector)
cc_test(test_op_compat_sensible_pass SRCS op_compat_sensible_pass_tester.cc
        DEPS op_compat_sensible_pass)

# Unit tests for individual passes that are built on every configuration.
cc_test(test_fc_fuse_pass_cc SRCS fc_fuse_pass_tester.cc
        DEPS fc_fuse_pass framework_proto)
cc_test(test_fc_lstm_fuse_pass_cc SRCS fc_lstm_fuse_pass_tester.cc
        DEPS fc_lstm_fuse_pass framework_proto)
cc_test(test_fc_gru_fuse_pass_cc SRCS fc_gru_fuse_pass_tester.cc
        DEPS fc_gru_fuse_pass framework_proto)
cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc
        DEPS seqpool_concat_fuse_pass framework_proto)
cc_test(test_seqpool_cvm_concat_fuse_pass
        SRCS seqpool_cvm_concat_fuse_pass_tester.cc
        DEPS seqpool_cvm_concat_fuse_pass framework_proto)
cc_test(test_repeated_fc_relu_fuse_pass_cc
        SRCS repeated_fc_relu_fuse_pass_tester.cc
        DEPS repeated_fc_relu_fuse_pass framework_proto)
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
cc_test(test_simplify_with_basic_ops_pass
        SRCS simplify_with_basic_ops_pass_tester.cc
        DEPS simplify_with_basic_ops_pass)
cc_test(test_fc_elementwise_layernorm_fuse_pass_cc
        SRCS fc_elementwise_layernorm_fuse_pass_tester.cc
        DEPS fc_elementwise_layernorm_fuse_pass)
cc_test(test_skip_layernorm_fuse_pass SRCS skip_layernorm_fuse_pass_tester.cc
        DEPS skip_layernorm_fuse_pass)
cc_test(test_multihead_matmul_fuse_pass
        SRCS multihead_matmul_fuse_pass_tester.cc
        DEPS multihead_matmul_fuse_pass)
cc_test(test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc
        DEPS conv_bn_fuse_pass)
cc_test(test_adaptive_pool2d_convert_global_pass
        SRCS adaptive_pool2d_convert_global_pass_tester.cc
        DEPS adaptive_pool2d_convert_global_pass)
cc_test(test_unsqueeze2_eltwise_fuse_pass_cc
        SRCS unsqueeze2_eltwise_fuse_pass_tester.cc
        DEPS unsqueeze2_eltwise_fuse_pass)
cc_test(test_generate_pass_cc SRCS generate_pass_tester.cc
        DEPS generate_pass pass_desc_proto)
cc_test(test_delete_dropout_pass_cc SRCS delete_dropout_op_pass_test.cc
        DEPS delete_dropout_op_pass)
# Tests for the passes that only exist in GPU or ROCm builds.
if(WITH_GPU OR WITH_ROCM)
  cc_test(test_embedding_eltwise_layernorm_fuse_pass
          SRCS embedding_eltwise_layernorm_fuse_pass_tester.cc
          DEPS embedding_eltwise_layernorm_fuse_pass)
  cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc
          DEPS cudnn_placement_pass)
endif()

# Tests that are not built on Windows.
if(NOT WIN32)
  cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc
          DEPS sync_batch_norm_pass)
  cc_test(test_dense_fc_to_sparse_pass_cc
          SRCS dense_fc_to_sparse_pass_tester.cc
          DEPS fc_fuse_pass dense_fc_to_sparse_pass framework_proto)
  cc_test(test_dense_multihead_matmul_to_sparse_pass
          SRCS dense_multihead_matmul_to_sparse_pass_tester.cc
          DEPS multihead_matmul_fuse_pass dense_multihead_matmul_to_sparse_pass)
endif()
# Tests for the MKLDNN passes; the test sources live in the mkldnn/
# subdirectory.
if(WITH_MKLDNN)
  cc_test(test_depthwise_conv_mkldnn_pass
          SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc
          DEPS depthwise_conv_mkldnn_pass)
  cc_test(test_conv_bias_mkldnn_fuse_pass_cc
          SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc
          DEPS conv_bias_mkldnn_fuse_pass naive_executor)
  cc_test(test_conv_activation_mkldnn_fuse_pass
          SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
          DEPS conv_activation_mkldnn_fuse_pass)
  cc_test(test_conv_concat_relu_mkldnn_fuse_pass
          SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc
          DEPS conv_activation_mkldnn_fuse_pass)
  cc_test(test_conv_elementwise_add_mkldnn_fuse_pass
          SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
          DEPS conv_elementwise_add_mkldnn_fuse_pass pass_test_util)
  cc_test(test_int8_scale_calculation_mkldnn_pass
          SRCS mkldnn/int8_scale_calculation_mkldnn_pass_tester.cc
          DEPS int8_scale_calculation_mkldnn_pass pass_test_util)
  cc_test(test_params_quantization_mkldnn_pass
          SRCS mkldnn/params_quantization_mkldnn_pass_tester.cc
          DEPS params_quantization_mkldnn_pass)
  cc_test(test_fc_elementwise_add_mkldnn_fuse_pass
          SRCS mkldnn/fc_elementwise_add_mkldnn_fuse_pass_tester.cc
          DEPS fc_elementwise_add_mkldnn_fuse_pass pass_test_util)
  cc_test(test_fc_act_mkldnn_fuse_pass
          SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
          DEPS fc_act_mkldnn_fuse_pass pass_test_util)
  cc_test(test_batch_norm_act_fuse_pass
          SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc
          DEPS batch_norm_act_fuse_pass pass_test_util)

  # Dependencies of the conv + batch-norm test; depthwise_conv is added for
  # GPU or ROCm builds only.
  set(TEST_CONV_BN_PASS_DEPS
      conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op
      math_function im2col vol2col batch_norm_op gelu_op activation_op
      elementwise_add_op concat_and_split naive_executor device_context
      eigen_function)
  if(WITH_GPU OR WITH_ROCM)
    list(APPEND TEST_CONV_BN_PASS_DEPS depthwise_conv)
  endif()
  cc_test(test_conv_batch_norm_mkldnn_fuse_pass
          SRCS mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc
          DEPS ${TEST_CONV_BN_PASS_DEPS})

  cc_test(test_scale_matmul_fuse_pass
          SRCS mkldnn/scale_matmul_fuse_pass_tester.cc
          DEPS scale_matmul_fuse_pass)
  cc_test(test_mkldnn_placement_pass
          SRCS mkldnn/mkldnn_placement_pass_tester.cc
          DEPS mkldnn_placement_pass)
  cc_test(test_mkldnn_inplace_pass SRCS mkldnn/mkldnn_inplace_pass_tester.cc
          DEPS mkldnn_inplace_pass)
  cc_test(test_compute_propagate_scales_mkldnn_pass
          SRCS mkldnn/compute_propagate_scales_mkldnn_pass_tester.cc
          DEPS compute_propagate_scales_mkldnn_pass naive_executor)

  if(WITH_ONNXRUNTIME AND WIN32)
    # Copy onnxruntime for this C++ test on Windows. The test is only built
    # in CI, so the Windows generator is assumed to be Ninja.
    copy_onnx(test_compute_propagate_scales_mkldnn_pass)
  endif()

  cc_test(test_cpu_quantize_placement_pass
          SRCS mkldnn/cpu_quantize_placement_pass_tester.cc
          DEPS cpu_quantize_placement_pass)
  cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc
          DEPS cpu_quantize_pass naive_executor)
  cc_test(test_cpu_quantize_squash_pass
          SRCS mkldnn/cpu_quantize_squash_pass_tester.cc
          DEPS cpu_quantize_squash_pass naive_executor)
  cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass
          SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc
          DEPS reshape_transpose_matmul_mkldnn_fuse_pass)
  cc_test(test_matmul_transpose_reshape_fuse_pass
          SRCS mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass_tester.cc
          DEPS matmul_transpose_reshape_mkldnn_fuse_pass)
  cc_test(test_shuffle_channel_mkldnn_detect_pass
          SRCS mkldnn/shuffle_channel_mkldnn_detect_pass_tester.cc
          DEPS shuffle_channel_mkldnn_detect_pass)
  cc_test(test_cpu_bfloat16_placement_pass
          SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc
          DEPS cpu_bfloat16_placement_pass)
  cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc
          DEPS cpu_bfloat16_pass)
  cc_test(test_multi_gru_fuse_pass SRCS mkldnn/multi_gru_fuse_pass_tester.cc
          DEPS multi_gru_fuse_pass)
  cc_test(test_multi_gru_seq_fuse_pass
          SRCS mkldnn/multi_gru_seq_fuse_pass_tester.cc
          DEPS multi_gru_seq_fuse_pass)

  # Dependencies of the FC + RNN fuse test.
  set(TEST_FC_RNN_PASS_DEPS fc_gru_fuse_pass fc_lstm_fuse_pass
                            mkldnn_placement_pass)
  cc_test(test_fc_rnn_mkldnn_fuse_pass
          SRCS mkldnn/mkldnn_fc_rnn_fuse_pass_tester.cc
          DEPS ${TEST_FC_RNN_PASS_DEPS})
endif()
