# Build helpers for operator libraries: op_library() builds one operator,
# register_operators() builds every operator of a source directory.

# CMake file `unity_build` is used to handle Unity Build compilation.
include(unity_build)

# Collects the `*.part.cu` kernel files that op_library() reports back via
# PARENT_SCOPE; cleared each time this file is included.
set(PART_CUDA_KERNEL_FILES)
# op_library(<TARGET> [UNITY] [SRCS <file>...] [DEPS <dep>...])
#
# Creates the library for one operator. The interface is the same as
# cc_library, but source files are split into per-device buckets
# (CPU, CUDA/cuDNN, ROCm/MIOpen, MKLDNN, XPU, NPU, MLU) and the common
# operator dependencies (`op_common_deps`) are always linked in.
#
#   TARGET  Operator name. When SRCS is empty, sources are discovered in
#           CMAKE_CURRENT_SOURCE_DIR from files named after TARGET
#           (<TARGET>.cc, <TARGET>.cu, <TARGET>.cu.cc, <TARGET>.part.cu, the
#           `_cudnn_op` variants, mkldnn/<..._mkldnn_op>.cc and the
#           `_op_xpu` / `_op_npu` / `_op_mlu` .cc files), each gated by the
#           matching WITH_* option.
#   UNITY   Build into a shared per-directory unity target when
#           WITH_UNITY_BUILD is also ON.
#   SRCS    Explicit source list; each file is classified by its suffix.
#   DEPS    Extra dependencies of the operator library.
#
# Errors out if no `.cc` source is found. On WIN32 the operators `nccl_op`
# and `gen_nccl_id_op` are skipped entirely.
#
# Side effects:
#   * prepends the library (or unity target) name to the cache variable OP_LIBRARY;
#   * sets DEPS_OPS in the caller's scope when DEPS is non-empty;
#   * sets PART_CUDA_KERNEL_FILES in the caller's scope when a <TARGET>.part.cu exists;
#   * appends USE_OP* lines to the file named by ${pybind_file}.
function(op_library TARGET)
    # Source buckets, filled either by naming-convention discovery (no SRCS)
    # or by classifying each entry of SRCS by suffix.
    set(cc_srcs)
    set(cu_srcs)
    set(hip_srcs)
    set(cu_cc_srcs)
    set(hip_cc_srcs)
    set(xpu_cc_srcs)
    set(npu_cc_srcs)
    set(mlu_cc_srcs)
    set(cudnn_cu_cc_srcs)
    set(miopen_cu_cc_srcs)
    set(cudnn_cu_srcs)
    set(miopen_cu_srcs)
    set(CUDNN_FILE)
    set(MIOPEN_FILE)
    set(mkldnn_cc_srcs)
    # Only assigned by discovery below; stays empty when SRCS is explicit.
    set(MKLDNN_FILE)

    set(op_common_deps operator op_registry math_function layer common_infer_shape_functions)
    if (WITH_ASCEND_CL)
        list(APPEND op_common_deps npu_op_runner)
    endif()
    if (WITH_MLU)
        list(APPEND op_common_deps mlu_baseop)
    endif()

    # Option `UNITY` is used to specify that operator `TARGET` will be compiled with Unity Build.
    set(options UNITY)
    set(oneValueArgs "")
    set(multiValueArgs SRCS DEPS)
    set(pybind_flag 0)
    cmake_parse_arguments(op_library "${options}" "${oneValueArgs}"
            "${multiValueArgs}" ${ARGN})

    list(LENGTH op_library_SRCS op_library_SRCS_len)
    if (${op_library_SRCS_len} EQUAL 0)
        # No SRCS given: discover sources from files named after TARGET.
        if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc)
            list(APPEND cc_srcs ${TARGET}.cc)
        endif()
        if (WITH_GPU)
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc)
                list(APPEND cu_cc_srcs ${TARGET}.cu.cc)
            endif()
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
                list(APPEND cu_srcs ${TARGET}.cu)
            endif()
            if (WITH_NV_JETSON)
                list(REMOVE_ITEM cu_srcs "decode_jpeg_op.cu")
            endif()
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
                # Also reported to the caller through PART_CUDA_KERNEL_FILES.
                set(PART_CUDA_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu
                        ${PART_CUDA_KERNEL_FILES} PARENT_SCOPE)
                list(APPEND cu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
            endif()
            string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}")
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc)
                list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc)
            endif()
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu)
                list(APPEND cudnn_cu_srcs ${CUDNN_FILE}.cu)
            endif()
        endif()
        if (WITH_ROCM)
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc)
                list(APPEND hip_cc_srcs ${TARGET}.cu.cc)
            endif()
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
                list(APPEND hip_srcs ${TARGET}.cu)
            endif()
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
                set(PART_CUDA_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu
                        ${PART_CUDA_KERNEL_FILES} PARENT_SCOPE)
                list(APPEND hip_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
            endif()
            # MIOpen kernels use the same `_cudnn_op` file naming as cuDNN.
            string(REPLACE "_op" "_cudnn_op" MIOPEN_FILE "${TARGET}")
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.cu.cc)
                list(APPEND miopen_cu_cc_srcs ${MIOPEN_FILE}.cu.cc)
            endif()
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.cu)
                list(APPEND miopen_cu_srcs ${MIOPEN_FILE}.cu)
            endif()
        endif()
        if (WITH_MKLDNN)
            string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}")
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/mkldnn/${MKLDNN_FILE}.cc)
                list(APPEND mkldnn_cc_srcs mkldnn/${MKLDNN_FILE}.cc)
            endif()
        endif()
        if (WITH_XPU)
            string(REPLACE "_op" "_op_xpu" XPU_FILE "${TARGET}")
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${XPU_FILE}.cc)
                list(APPEND xpu_cc_srcs ${XPU_FILE}.cc)
            endif()
        endif()
        if (WITH_ASCEND_CL)
            string(REPLACE "_op" "_op_npu" NPU_FILE "${TARGET}")
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${NPU_FILE}.cc)
                list(APPEND npu_cc_srcs ${NPU_FILE}.cc)
            endif()
        endif()
        if (WITH_MLU)
            string(REPLACE "_op" "_op_mlu" MLU_FILE "${TARGET}")
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MLU_FILE}.cc)
                list(APPEND mlu_cc_srcs ${MLU_FILE}.cc)
            endif()
        endif()
    else()
        # Explicit SRCS: classify each file by suffix. Order matters -- the
        # specific `_cudnn_op` / `_mkldnn_op` / device patterns are tested
        # before the generic `.cu` / `.cu.cc` / `.cc` ones.
        foreach(src ${op_library_SRCS})
            if (WITH_ROCM AND "${src}" MATCHES ".*_cudnn_op\\.cu$")
                list(APPEND miopen_cu_srcs ${src})
            elseif (WITH_ROCM AND "${src}" MATCHES ".*\\.cu$")
                list(APPEND hip_srcs ${src})
            elseif (WITH_ROCM AND "${src}" MATCHES ".*_cudnn_op\\.cu\\.cc$")
                list(APPEND miopen_cu_cc_srcs ${src})
            elseif (WITH_ROCM AND "${src}" MATCHES ".*\\.cu\\.cc$")
                list(APPEND hip_cc_srcs ${src})
            elseif ("${src}" MATCHES ".*_cudnn_op\\.cu$")
                list(APPEND cudnn_cu_srcs ${src})
            elseif ("${src}" MATCHES ".*\\.cu$")
                list(APPEND cu_srcs ${src})
            elseif ("${src}" MATCHES ".*_cudnn_op\\.cu\\.cc$")
                list(APPEND cudnn_cu_cc_srcs ${src})
            elseif (WITH_MKLDNN AND "${src}" MATCHES ".*_mkldnn_op\\.cc$")
                list(APPEND mkldnn_cc_srcs ${src})
            elseif ("${src}" MATCHES ".*\\.cu\\.cc$")
                list(APPEND cu_cc_srcs ${src})
            elseif (WITH_XPU AND "${src}" MATCHES ".*_op_xpu\\.cc$")
                list(APPEND xpu_cc_srcs ${src})
            elseif (WITH_ASCEND_CL AND "${src}" MATCHES ".*_op_npu\\.cc$")
                list(APPEND npu_cc_srcs ${src})
            elseif (WITH_MLU AND "${src}" MATCHES ".*_op_mlu\\.cc$")
                list(APPEND mlu_cc_srcs ${src})
            elseif ("${src}" MATCHES ".*\\.cc$")
                list(APPEND cc_srcs ${src})
            else()
                message(FATAL_ERROR "${TARGET} Source file ${src} should only be .cc or .cu")
            endif()
        endforeach()
    endif()

    list(LENGTH cc_srcs cc_srcs_len)
    if (${cc_srcs_len} EQUAL 0)
        message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file")
    endif()

    # Skip operators unsupported on Windows (no nccl, no warpctc, ...).
    if (WIN32)
        set(windows_unsupported_ops "nccl_op" "gen_nccl_id_op")
        if ("${TARGET}" IN_LIST windows_unsupported_ops)
            return()
        endif()
    endif()

    # Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`.
    if (WITH_UNITY_BUILD AND op_library_UNITY)
        # Derive the unity target name from the source directory, e.g.
        # ${PADDLE_SOURCE_DIR}/paddle/fluid/operators -> paddle_operators_unity.
        string(REPLACE "${PADDLE_SOURCE_DIR}/paddle/fluid/" "" UNITY_TARGET "${CMAKE_CURRENT_SOURCE_DIR}")
        string(REPLACE "/" "_" UNITY_TARGET "${UNITY_TARGET}")
        set(UNITY_TARGET "paddle_${UNITY_TARGET}_unity")
        if (NOT "${UNITY_TARGET}" IN_LIST OP_LIBRARY)
            set(OP_LIBRARY ${UNITY_TARGET} ${OP_LIBRARY} CACHE INTERNAL "op libs")
        endif()
    else()
        set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} CACHE INTERNAL "op libs")
    endif()

    list(LENGTH op_library_DEPS op_library_DEPS_len)
    if (${op_library_DEPS_len} GREATER 0)
        set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
    endif()

    if (WITH_GPU)
        # Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`.
        if (WITH_UNITY_BUILD AND op_library_UNITY)
            # Combine the cc and cu source files.
            compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs})
            compose_unity_target_sources(${UNITY_TARGET} cu ${cudnn_cu_srcs} ${cu_srcs})
            if (TARGET ${UNITY_TARGET})
                # If `UNITY_TARGET` exists, add source files to `UNITY_TARGET`.
                target_sources(${UNITY_TARGET} PRIVATE ${unity_target_cc_sources} ${unity_target_cu_sources})
            else()
                # If `UNITY_TARGET` does not exist, create `UNITY_TARGET` with source files.
                nv_library(${UNITY_TARGET} SRCS ${unity_target_cc_sources} ${unity_target_cu_sources} DEPS ${op_library_DEPS} ${op_common_deps})
            endif()
            # Add alias library to handle dependencies.
            add_library(${TARGET} ALIAS ${UNITY_TARGET})
        else()
            nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${cudnn_cu_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
        endif()
    elseif (WITH_ROCM)
        # NOTE(review): this branch always builds ${TARGET} and never creates
        # ${UNITY_TARGET}, although UNITY_TARGET may have been added to
        # OP_LIBRARY above when WITH_UNITY_BUILD and UNITY are both set --
        # confirm unity builds are disabled for ROCm.
        # Kernels that are excluded from ROCm builds.
        list(REMOVE_ITEM miopen_cu_cc_srcs
            "affine_grid_cudnn_op.cu.cc"
            "grid_sampler_cudnn_op.cu.cc")
        list(REMOVE_ITEM hip_srcs
            "cholesky_op.cu"
            "cholesky_solve_op.cu"
            "lu_op.cu"
            "matrix_rank_op.cu"
            "svd_op.cu"
            "eigvalsh_op.cu"
            "qr_op.cu"
            "eigh_op.cu"
            "multinomial_op.cu"
            "decode_jpeg_op.cu")
        hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS}
            ${op_common_deps})
    else()
        # Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`.
        if (WITH_UNITY_BUILD AND op_library_UNITY)
            # Combine the cc source files.
            compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} ${mlu_cc_srcs})
            if (TARGET ${UNITY_TARGET})
                # If `UNITY_TARGET` exists, add source files to `UNITY_TARGET`.
                target_sources(${UNITY_TARGET} PRIVATE ${unity_target_cc_sources})
            else()
                # If `UNITY_TARGET` does not exist, create `UNITY_TARGET` with source files.
                cc_library(${UNITY_TARGET} SRCS ${unity_target_cc_sources} DEPS ${op_library_DEPS} ${op_common_deps})
            endif()
            # Add alias library to handle dependencies.
            add_library(${TARGET} ALIAS ${UNITY_TARGET})
        else()
            cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} ${mlu_cc_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
        endif()
    endif()

    # Operators whose pybind registration is done manually. Setting
    # pybind_flag suppresses the generated USE_NO_KERNEL_OP, USE_CPU_ONLY_OP,
    # XPU device-kernel and USE_OP lines below.
    set(manual_pybind_ops
        "compare_all_op" "compare_op" "logical_op" "bitwise_op" "nccl_op"
        "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
        "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
        "sync_batch_norm_op" "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
        "skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op"
        "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op"
        "fused_bn_add_activation_op" "fused_attention_op" "resnet_unit_op" "fused_feedforward_op")
    if ("${TARGET}" IN_LIST manual_pybind_ops)
        set(pybind_flag 1)
    endif()

    # The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h.
    # Note that it's enough to just add one operator to pybind in a *_op.cc file.
    # And for detail pybind information, please see generated paddle/pybind/pybind.h.
    set(ORIGINAL_TARGET ${TARGET})
    file(READ ${TARGET}.cc TARGET_CONTENT)
    # When the content matches `REGISTER_OPERATOR( ... REGISTER_OPERATOR(`
    # (more than one registration), the name from the first one is used;
    # otherwise the name is TARGET with "_op" removed.
    string(REGEX MATCH "REGISTER_OPERATOR\\(.*REGISTER_OPERATOR\\(" multi_register "${TARGET_CONTENT}")
    # [ \t\r\n]* is used for blank characters
    string(REGEX MATCH "REGISTER_OPERATOR\\([ \t\r\n]*[a-z0-9_]*," one_register "${multi_register}")

    if (one_register STREQUAL "")
        string(REPLACE "_op" "" TARGET "${TARGET}")
    else()
        string(REPLACE "REGISTER_OPERATOR(" "" TARGET "${one_register}")
        string(REPLACE "," "" TARGET "${TARGET}")
        # [ \t\r\n]+ is used for blank characters.
        # Here we use '+' instead of '*' since it is a REPLACE operation.
        string(REGEX REPLACE "[ \t\r\n]+" "" TARGET "${TARGET}")
    endif()

    # pybind USE_NO_KERNEL_OP
    # HACK: if REGISTER_OP_CPU_KERNEL presents the operator must have kernel
    string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}")
    string(REPLACE "_op" "" TARGET "${TARGET}")
    if (${pybind_flag} EQUAL 0 AND regex_result STREQUAL "")
        file(APPEND "${pybind_file}" "USE_NO_KERNEL_OP(${TARGET});\n")
        set(pybind_flag 1)
    endif()

    # pybind USE_CPU_ONLY_OP
    list(LENGTH cu_srcs cu_srcs_len)
    list(LENGTH hip_srcs hip_srcs_len)
    list(LENGTH cu_cc_srcs cu_cc_srcs_len)
    list(LENGTH hip_cc_srcs hip_cc_srcs_len)
    list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
    list(LENGTH xpu_cc_srcs xpu_cc_srcs_len)
    list(LENGTH miopen_cu_cc_srcs miopen_cu_cc_srcs_len)
    list(LENGTH npu_cc_srcs npu_cc_srcs_len)
    list(LENGTH mlu_cc_srcs mlu_cc_srcs_len)
    if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND
        ${hip_srcs_len} EQUAL 0 AND ${hip_cc_srcs_len} EQUAL 0 AND ${miopen_cu_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0 AND
        ${npu_cc_srcs_len} EQUAL 0 AND ${mlu_cc_srcs_len} EQUAL 0)
        file(APPEND "${pybind_file}" "USE_CPU_ONLY_OP(${TARGET});\n")
        set(pybind_flag 1)
    endif()

    # pybind USE_OP_DEVICE_KERNEL for CUDNN
    list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len)
    if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0)
        if ("${TARGET}" STREQUAL "activation")
            file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL(relu, CUDNN);\n")
        else()
            file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
        endif()
    endif()

    # pybind USE_OP_DEVICE_KERNEL for MIOPEN (registered under the CUDNN tag).
    # miopen_cu_cc_srcs_len was computed above, after the ROCm removals.
    if (WITH_ROCM AND ${miopen_cu_cc_srcs_len} GREATER 0)
        if ("${TARGET}" STREQUAL "activation")
            file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL(relu, CUDNN);\n")
        else()
            file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
        endif()
    endif()

    # pybind USE_OP_DEVICE_KERNEL for CUDNN
    list(LENGTH cudnn_cu_srcs cudnn_cu_srcs_len)
    if (WITH_GPU AND ${cudnn_cu_srcs_len} GREATER 0)
        file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
    endif()

    # pybind USE_OP_DEVICE_KERNEL for MIOPEN
    list(LENGTH miopen_cu_srcs miopen_cu_srcs_len)
    if (WITH_ROCM AND ${miopen_cu_srcs_len} GREATER 0)
        file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
    endif()

    if (WITH_XPU AND ${pybind_flag} EQUAL 0 AND ${xpu_cc_srcs_len} GREATER 0)
        file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL(${TARGET}, XPU);\n")
    endif()

    if (WITH_ASCEND_CL AND ${npu_cc_srcs_len} GREATER 0)
        file(READ ${ORIGINAL_TARGET}_npu.cc TARGET_NPU_CONTENT)
        # Unlike REGISTER_OPERATOR above, the first REGISTER_OP_NPU_KERNEL( is
        # used even when there is only one occurrence -- be careful.
        string(REGEX MATCH "REGISTER_OP_NPU_KERNEL\\(.*" multi_npu_register "${TARGET_NPU_CONTENT}")
        # [ \t\r\n]* is used for blank characters
        string(REGEX MATCH "REGISTER_OP_NPU_KERNEL\\([ \t\r\n]*[a-z0-9_]*," one_npu_register "${multi_npu_register}")

        if (one_npu_register STREQUAL "")
            string(REPLACE "_op" "" NPU_TARGET "${TARGET}")
        else()
            string(REPLACE "REGISTER_OP_NPU_KERNEL(" "" NPU_TARGET "${one_npu_register}")
            string(REPLACE "," "" NPU_TARGET "${NPU_TARGET}")
            # [ \t\r\n]+ is used for blank characters.
            # Here we use '+' instead of '*' since it is a REPLACE operation.
            string(REGEX REPLACE "[ \t\r\n]+" "" NPU_TARGET "${NPU_TARGET}")
        endif()
        file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL(${NPU_TARGET}, NPU);\n")
    endif()

    if (WITH_MLU AND ${mlu_cc_srcs_len} GREATER 0)
        file(READ ${ORIGINAL_TARGET}_mlu.cc TARGET_MLU_CONTENT)
        # Unlike REGISTER_OPERATOR above, the first REGISTER_OP_MLU_KERNEL( is
        # used even when there is only one occurrence -- be careful.
        string(REGEX MATCH "REGISTER_OP_MLU_KERNEL\\(.*" multi_mlu_register "${TARGET_MLU_CONTENT}")
        # [ \t\r\n]* is used for blank characters
        string(REGEX MATCH "REGISTER_OP_MLU_KERNEL\\([ \t\r\n]*[a-z0-9_]*," one_mlu_register "${multi_mlu_register}")

        if (one_mlu_register STREQUAL "")
            string(REPLACE "_op" "" MLU_TARGET "${TARGET}")
        else()
            string(REPLACE "REGISTER_OP_MLU_KERNEL(" "" MLU_TARGET "${one_mlu_register}")
            string(REPLACE "," "" MLU_TARGET "${MLU_TARGET}")
            # [ \t\r\n]+ is used for blank characters.
            # Here we use '+' instead of '*' since it is a REPLACE operation.
            string(REGEX REPLACE "[ \t\r\n]+" "" MLU_TARGET "${MLU_TARGET}")
        endif()
        file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL(${MLU_TARGET}, MLU);\n")
    endif()

    # pybind USE_OP_DEVICE_KERNEL for MKLDNN
    if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
        # MKLDNN_FILE is empty when SRCS was passed explicitly, so it must be
        # quoted: an unquoted empty expansion turns into `if(STREQUAL ...)`,
        # which is a CMake configure error.
        if ("${MKLDNN_FILE}" STREQUAL "activation_mkldnn_op")
            # Append first implemented MKLDNN activation operator
            file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
        elseif ("${MKLDNN_FILE}" STREQUAL "conv_mkldnn_op")
            file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
            file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n")
            file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n")
        elseif ("${MKLDNN_FILE}" STREQUAL "transpose_mkldnn_op")
            file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, FP32);\n")
            file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, S8);\n")
            file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, U8);\n")
        elseif ("${MKLDNN_FILE}" STREQUAL "fc_mkldnn_op")
            file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, FP32);\n")
            file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, S8);\n")
            file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, U8);\n")
        else()
            file(APPEND "${pybind_file}" "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
        endif()
    endif()

    # pybind USE_OP
    if (${pybind_flag} EQUAL 0)
        # NOTE(*): activation use macro to regist the kernels, set use_op manually.
        if ("${TARGET}" STREQUAL "activation")
            file(APPEND "${pybind_file}" "USE_OP(relu);\n")
        elseif ("${TARGET}" STREQUAL "fake_dequantize")
            file(APPEND "${pybind_file}" "USE_OP(fake_dequantize_max_abs);\n")
        elseif ("${TARGET}" STREQUAL "fake_quantize")
            file(APPEND "${pybind_file}" "USE_OP(fake_quantize_abs_max);\n")
        elseif ("${TARGET}" STREQUAL "tensorrt_engine_op")
            # NOTE(review): TARGET has had "_op" stripped above and
            # tensorrt_engine_op is in manual_pybind_ops, so this branch looks
            # unreachable -- confirm before relying on the message.
            message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
        else()
            file(APPEND "${pybind_file}" "USE_OP(${TARGET});\n")
        endif()
    endif()
endfunction()


# register_operators([EXCLUDES <op>...] [DEPS <dep>...])
#
# Builds every operator of the current source directory with
# `op_library(<op> UNITY ...)`. Operators are found by globbing `*_op.cc`
# in CMAKE_CURRENT_SOURCE_DIR. The substrings `_mkldnn`, `_xpu`, `_npu`,
# `_mlu` and `.cc` are stripped from the names (e.g. foo_mkldnn_op.cc -> foo_op),
# and duplicates are removed.
#
#   EXCLUDES  Operator names (e.g. foo_op) that must not be built here.
#   DEPS      Extra dependencies forwarded to every op_library call.
#
# When WITH_UNITY_BUILD is ON, the `cc` unity target is finished afterwards,
# and the `cu` one as well when WITH_GPU is ON.
#
# NOTE(review): op_library sets DEPS_OPS and PART_CUDA_KERNEL_FILES with
# PARENT_SCOPE, which is this function's scope; they are not propagated to
# register_operators' caller -- confirm that is intended.
function(register_operators)
    set(options "")
    set(oneValueArgs "")
    set(multiValueArgs EXCLUDES DEPS)
    cmake_parse_arguments(register_operators "${options}" "${oneValueArgs}"
            "${multiValueArgs}" ${ARGN})

    file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
    # NOTE(review): the glob only returns names ending in `_op.cc`, so files
    # named like op_library's `<name>_op_xpu.cc` / `_op_npu.cc` / `_op_mlu.cc`
    # are never listed; confirm these three replacements are still needed.
    string(REPLACE "_mkldnn" "" OPS "${OPS}")
    string(REPLACE "_xpu" "" OPS "${OPS}")
    string(REPLACE "_npu" "" OPS "${OPS}")
    string(REPLACE "_mlu" "" OPS "${OPS}")
    string(REPLACE ".cc" "" OPS "${OPS}")
    list(REMOVE_DUPLICATES OPS)
    list(LENGTH register_operators_DEPS register_operators_DEPS_len)

    foreach(src ${OPS})
        if ("${src}" IN_LIST register_operators_EXCLUDES)
            continue()
        endif()
        if (${register_operators_DEPS_len} GREATER 0)
            op_library(${src} UNITY DEPS ${register_operators_DEPS})
        else()
            op_library(${src} UNITY)
        endif()
    endforeach()

    # Complete the processing of `UNITY_TARGET`.
    if (WITH_UNITY_BUILD)
        finish_unity_target(cc)
        if (WITH_GPU)
            finish_unity_target(cu)
        endif()
    endif()
endfunction()