# CMake file `unity_build` is used to handle Unity Build compilation.
include(unity_build)

# Absolute paths of `<op>.part.cu` kernel files. op_library() prepends to this
# list (in its caller's scope) whenever such a file exists next to an operator.
# Start from an unset list.
set(PART_CUDA_KERNEL_FILES)

function(find_register FILENAME PATTERN OUTPUT)
    # find_register(FILENAME PATTERN OUTPUT)
    #
    # Read FILENAME and find the first `<PATTERN>(op_name,` occurrence, e.g.
    # `REGISTER_OPERATOR(op_name, ...)` or `REGISTER_OP_CPU_KERNEL(op_name, ...)`.
    # Blank characters (space, tab, CR, LF) are allowed between `(` and the
    # name; the name itself may contain only [a-z0-9_].
    #
    # Arguments:
    #   FILENAME - file to scan (relative paths are resolved by file(READ)
    #              against the current source directory).
    #   PATTERN  - registration macro name; it is inserted verbatim into a
    #              regular expression.
    #   OUTPUT   - name of the caller variable that receives the op name, or an
    #              empty string when no registration is found.
    file(READ "${FILENAME}" CONTENT)
    # Capture the op name directly with group 1 instead of stripping the macro
    # name, `(`, `,` and blanks out of the whole match afterwards.
    string(REGEX MATCH "${PATTERN}\\([ \t\r\n]*([a-z0-9_]*)," register "${CONTENT}")
    set(found_name "")
    if(NOT register STREQUAL "")
        set(found_name "${CMAKE_MATCH_1}")
    endif()
    # Quoted on purpose: an unquoted empty value would turn this into
    # `set(<OUTPUT> PARENT_SCOPE)`, which unsets the caller's variable instead
    # of setting it to an empty string.
    set(${OUTPUT} "${found_name}" PARENT_SCOPE)
endfunction()

# Append `USE_OP_DEVICE_KERNEL(<op_name>, <DEVICE>);` to ${pybind_file} for each
# source file in ARGN that contains a `<PATTERN>(op_name, ...)` registration.
#
# Arguments:
#   PATTERN  - registration macro searched for via find_register().
#   DEVICE   - device tag written into the generated line (CUDA, CUDNN, XPU, ...).
#   FLAG_VAR - name of a caller variable that is set to 1 in the caller's scope
#              when at least one registration is found; left untouched otherwise.
#   ARGN     - source files to scan.
# Reads `pybind_file` from the calling scope.
function(op_library_use_device_kernels PATTERN DEVICE FLAG_VAR)
    set(found_kernel 0)
    foreach(kernel_src ${ARGN})
        set(op_name "")
        find_register(${kernel_src} "${PATTERN}" op_name)
        if(NOT "${op_name}" STREQUAL "")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, ${DEVICE});\n")
            set(found_kernel 1)
        endif()
    endforeach()
    if(found_kernel)
        set(${FLAG_VAR} 1 PARENT_SCOPE)
    endif()
endfunction()

function(op_library TARGET)
    # op_library(TARGET [UNITY] [SRCS src...] [DEPS dep...])
    #
    # op_library is a function to create op library. The interface is same as
    # cc_library. But it handles split GPU/CPU code and links some common
    # libraries for ops.
    #
    # When SRCS is omitted, sources are discovered in ${CMAKE_CURRENT_SOURCE_DIR}
    # from files named after TARGET (`foo_op.cc`, `foo_op.cu.cc`, `foo_op.cu`,
    # `foo_op.part.cu`, `foo_cudnn_op.cu[.cc]`, `mkldnn/foo_mkldnn_op.cc`,
    # `foo_op_xpu.cc`, `foo_op.xpu`, `foo_op_npu.cc`, `foo_op_mlu.cc`), depending
    # on which WITH_* backends are enabled.
    #
    # Side effects:
    #   - prepends the library (or, for unity builds, the unity target if it is
    #     not already listed) to the cached OP_LIBRARY list;
    #   - prepends TARGET to DEPS_OPS in the caller's scope when DEPS is non-empty;
    #   - prepends `<TARGET>.part.cu` to PART_CUDA_KERNEL_FILES in the caller's
    #     scope when that file exists;
    #   - appends USE_OP* registration lines to ${pybind_file}.
    # On WIN32, nccl_op and gen_nccl_id_op return without creating a library.
    set(cc_srcs)
    set(cu_srcs)
    set(hip_srcs)
    set(cu_cc_srcs)
    set(hip_cc_srcs)
    set(xpu_cc_srcs)
    set(xpu_kp_cc_srcs)
    set(npu_cc_srcs)
    set(mlu_cc_srcs)
    set(cudnn_cu_cc_srcs)
    set(miopen_cu_cc_srcs)
    set(cudnn_cu_srcs)
    set(miopen_cu_srcs)
    set(CUDNN_FILE)
    set(MIOPEN_FILE)
    set(mkldnn_cc_srcs)
    set(MKLDNN_FILE)
    set(op_common_deps operator op_registry math_function layer common_infer_shape_functions)
    if(WITH_ASCEND_CL)
        set(op_common_deps ${op_common_deps} npu_op_runner)
    endif()
    if(WITH_MLU)
        set(op_common_deps ${op_common_deps} mlu_baseop)
    endif()

    # Option `UNITY` is used to specify that operator `TARGET` will compile with Unity Build.
    set(options UNITY)
    set(oneValueArgs "")
    set(multiValueArgs SRCS DEPS)
    set(pybind_flag 0)
    cmake_parse_arguments(op_library "${options}" "${oneValueArgs}"
            "${multiValueArgs}" ${ARGN})

    list(LENGTH op_library_SRCS op_library_SRCS_len)
    if(${op_library_SRCS_len} EQUAL 0)
        # No explicit SRCS: discover sources by naming convention.
        if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc)
            list(APPEND cc_srcs ${TARGET}.cc)
        endif()
        if(WITH_GPU)
            if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc)
                list(APPEND cu_cc_srcs ${TARGET}.cu.cc)
            endif()
            if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
                list(APPEND cu_srcs ${TARGET}.cu)
            endif()
            if(WITH_NV_JETSON)
                list(REMOVE_ITEM cu_srcs "decode_jpeg_op.cu")
            endif()
            if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
                set(PART_CUDA_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu
                        ${PART_CUDA_KERNEL_FILES} PARENT_SCOPE)
                list(APPEND cu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
            endif()
            string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}")
            if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc)
                list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc)
            endif()
            if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu)
                list(APPEND cudnn_cu_srcs ${CUDNN_FILE}.cu)
            endif()
        endif()
        if(WITH_ROCM)
            if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc)
                list(APPEND hip_cc_srcs ${TARGET}.cu.cc)
            endif()
            if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
                list(APPEND hip_srcs ${TARGET}.cu)
            endif()
            if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
                set(PART_CUDA_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu
                        ${PART_CUDA_KERNEL_FILES} PARENT_SCOPE)
                list(APPEND hip_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
            endif()
            # ROCm reuses the `_cudnn_op` file names for its MIOpen kernels.
            string(REPLACE "_op" "_cudnn_op" MIOPEN_FILE "${TARGET}")
            if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.cu.cc)
                list(APPEND miopen_cu_cc_srcs ${MIOPEN_FILE}.cu.cc)
            endif()
            if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.cu)
                list(APPEND miopen_cu_srcs ${MIOPEN_FILE}.cu)
            endif()
        endif()
        if(WITH_MKLDNN)
            string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}")
            if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/mkldnn/${MKLDNN_FILE}.cc)
                list(APPEND mkldnn_cc_srcs mkldnn/${MKLDNN_FILE}.cc)
            endif()
        endif()
        if(WITH_XPU)
            string(REPLACE "_op" "_op_xpu" XPU_FILE "${TARGET}")
            if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${XPU_FILE}.cc)
                list(APPEND xpu_cc_srcs ${XPU_FILE}.cc)
            endif()
        endif()
        if(WITH_XPU_KP)
            if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.xpu)
                list(APPEND xpu_kp_cc_srcs ${TARGET}.xpu)
            endif()
        endif()
        if(WITH_ASCEND_CL)
            string(REPLACE "_op" "_op_npu" NPU_FILE "${TARGET}")
            if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${NPU_FILE}.cc)
                list(APPEND npu_cc_srcs ${NPU_FILE}.cc)
            endif()
        endif()
        if(WITH_MLU)
            string(REPLACE "_op" "_op_mlu" MLU_FILE "${TARGET}")
            if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MLU_FILE}.cc)
                list(APPEND mlu_cc_srcs ${MLU_FILE}.cc)
            endif()
        endif()
    else()
        # Explicit SRCS: classify each file by suffix. Order matters: the more
        # specific `_cudnn_op` patterns must be tested before the generic ones.
        foreach(src ${op_library_SRCS})
            if(WITH_ROCM AND "${src}" MATCHES ".*_cudnn_op.cu$")
                list(APPEND miopen_cu_srcs ${src})
            elseif(WITH_ROCM AND "${src}" MATCHES ".*\\.cu$")
                list(APPEND hip_srcs ${src})
            elseif(WITH_ROCM AND "${src}" MATCHES ".*_cudnn_op.cu.cc$")
                list(APPEND miopen_cu_cc_srcs ${src})
            elseif(WITH_ROCM AND "${src}" MATCHES ".*\\.cu.cc$")
                list(APPEND hip_cc_srcs ${src})
            elseif(WITH_GPU AND "${src}" MATCHES ".*_cudnn_op.cu$")
                list(APPEND cudnn_cu_srcs ${src})
            elseif(WITH_GPU AND "${src}" MATCHES ".*\\.cu$")
                list(APPEND cu_srcs ${src})
            elseif(WITH_GPU AND "${src}" MATCHES ".*_cudnn_op.cu.cc$")
                list(APPEND cudnn_cu_cc_srcs ${src})
            elseif(WITH_GPU AND "${src}" MATCHES ".*\\.cu.cc$")
                list(APPEND cu_cc_srcs ${src})
            elseif(WITH_MKLDNN AND "${src}" MATCHES ".*_mkldnn_op.cc$")
                list(APPEND mkldnn_cc_srcs ${src})
            elseif(WITH_XPU AND "${src}" MATCHES ".*_op_xpu.cc$")
                list(APPEND xpu_cc_srcs ${src})
            elseif(WITH_XPU_KP AND "${src}" MATCHES ".*\\.xpu$")
                list(APPEND xpu_kp_cc_srcs ${src})
            elseif(WITH_ASCEND_CL AND "${src}" MATCHES ".*_op_npu.cc$")
                list(APPEND npu_cc_srcs ${src})
            elseif(WITH_MLU AND "${src}" MATCHES ".*_op_mlu.cc$")
                list(APPEND mlu_cc_srcs ${src})
            elseif("${src}" MATCHES ".*\\.cc$")
                list(APPEND cc_srcs ${src})
            else()
                message(FATAL_ERROR "${TARGET} Source file ${src} should only be .cc or .cu or .xpu")
            endif()
        endforeach()
    endif()

    list(LENGTH xpu_cc_srcs xpu_cc_srcs_len)
    list(LENGTH xpu_kp_cc_srcs xpu_kp_cc_srcs_len)
    list(LENGTH cc_srcs cc_srcs_len)
    if(${cc_srcs_len} EQUAL 0)
        message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file")
    endif()
    if(WIN32)
        # remove windows unsupported op, because windows has no nccl, no warpctc such ops.
        foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op")
            if("${TARGET}" STREQUAL "${windows_unsupport_op}")
                return()
            endif()
        endforeach()
    endif()

    # Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`.
    if(WITH_UNITY_BUILD AND op_library_UNITY)
        # Generate the unity target name by the directory where source files located.
        string(REPLACE "${PADDLE_SOURCE_DIR}/paddle/fluid/" "" UNITY_TARGET ${CMAKE_CURRENT_SOURCE_DIR})
        string(REPLACE "/" "_" UNITY_TARGET ${UNITY_TARGET})
        set(UNITY_TARGET "paddle_${UNITY_TARGET}_unity")
        if(NOT ${UNITY_TARGET} IN_LIST OP_LIBRARY)
            set(OP_LIBRARY ${UNITY_TARGET} ${OP_LIBRARY} CACHE INTERNAL "op libs")
        endif()
    else()
        set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} CACHE INTERNAL "op libs")
    endif()

    list(LENGTH op_library_DEPS op_library_DEPS_len)
    if(${op_library_DEPS_len} GREATER 0)
        set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
    endif()

    # ---- Create the library for the active backend. ----
    if(WITH_GPU)
        if(WITH_UNITY_BUILD AND op_library_UNITY)
            # Combine the cc and cu source files. compose_unity_target_sources (from
            # unity_build) presumably fills unity_target_cc_sources /
            # unity_target_cu_sources, which are consumed below.
            compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs})
            compose_unity_target_sources(${UNITY_TARGET} cu ${cudnn_cu_srcs} ${cu_srcs})
            if(TARGET ${UNITY_TARGET})
                # If `UNITY_TARGET` exists, add source files to `UNITY_TARGET`.
                target_sources(${UNITY_TARGET} PRIVATE ${unity_target_cc_sources} ${unity_target_cu_sources})
            else()
                # If `UNITY_TARGET` does not exist, create `UNITY_TARGET` with source files.
                nv_library(${UNITY_TARGET} SRCS ${unity_target_cc_sources} ${unity_target_cu_sources} DEPS ${op_library_DEPS} ${op_common_deps})
            endif()
            # Add alias library to handle dependencies.
            add_library(${TARGET} ALIAS ${UNITY_TARGET})
        else()
            nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${cudnn_cu_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
        endif()
    elseif(WITH_ROCM)
        # Kernels that are not built for ROCm.
        list(REMOVE_ITEM miopen_cu_cc_srcs "affine_grid_cudnn_op.cu.cc" "grid_sampler_cudnn_op.cu.cc")
        list(REMOVE_ITEM hip_srcs
            "cholesky_op.cu"
            "cholesky_solve_op.cu"
            "lu_op.cu"
            "matrix_rank_op.cu"
            "svd_op.cu"
            "eigvalsh_op.cu"
            "qr_op.cu"
            "eigh_op.cu"
            "lstsq_op.cu"
            "multinomial_op.cu"
            "decode_jpeg_op.cu")
        hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
    elseif(WITH_XPU_KP AND ${xpu_kp_cc_srcs_len} GREATER 0)
        xpu_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${xpu_kp_cc_srcs} DEPS ${op_library_DEPS} ${op_common_deps})
    else()
        # Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`.
        if(WITH_UNITY_BUILD AND op_library_UNITY)
            # Combine the cc source files.
            compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} ${mlu_cc_srcs})
            if(TARGET ${UNITY_TARGET})
                # If `UNITY_TARGET` exists, add source files to `UNITY_TARGET`.
                target_sources(${UNITY_TARGET} PRIVATE ${unity_target_cc_sources})
            else()
                # If `UNITY_TARGET` does not exist, create `UNITY_TARGET` with source files.
                cc_library(${UNITY_TARGET} SRCS ${unity_target_cc_sources} DEPS ${op_library_DEPS} ${op_common_deps})
            endif()
            # Add alias library to handle dependencies.
            add_library(${TARGET} ALIAS ${UNITY_TARGET})
        else()
            cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} ${mlu_cc_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
        endif()
    endif()

    list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
    list(LENGTH npu_cc_srcs npu_cc_srcs_len)
    list(LENGTH mlu_cc_srcs mlu_cc_srcs_len)

    # ---- Pybind registration. ----
    # Define operators that don't need pybind here.
    foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "bitwise_op" "nccl_op"
            "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op")
        if("${TARGET}" STREQUAL "${manual_pybind_op}")
            set(pybind_flag 1)
        endif()
    endforeach()

    # The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h.
    # Note that it's enough to just adding one operator to pybind in a *_op.cc file.
    # And for detail pybind information, please see generated paddle/pybind/pybind.h.
    set(ORIGINAL_TARGET ${TARGET})
    string(REGEX REPLACE "_op" "" TARGET "${TARGET}")

    foreach(cc_src ${cc_srcs})
        # pybind USE_OP_ITSELF
        foreach(op_macro "REGISTER_OPERATOR" "REGISTER_OP_WITHOUT_GRADIENT")
            set(op_name "")
            find_register(${cc_src} "${op_macro}" op_name)
            if(NOT "${op_name}" STREQUAL "")
                file(APPEND ${pybind_file} "USE_OP_ITSELF(${op_name});\n")
                # hack: for example, the target in conv_transpose_op.cc is conv2d_transpose, used in mkldnn
                set(TARGET ${op_name})
                set(pybind_flag 1)
            endif()
        endforeach()

        # pybind USE_OP_DEVICE_KERNEL for CPU
        set(op_name "")
        find_register(${cc_src} "REGISTER_OP_CPU_KERNEL" op_name)
        if(NOT "${op_name}" STREQUAL "")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CPU);\n")
            # why change TARGET here?
            # when building paddle with on_infer, the REGISTER_OPERATOR(*_grad) will be removed before compiling (see details in remove_grad_op_and_kernel.py)
            # in elementwise_op.cc, it will find REGISTER_OPERATOR(grad_add) and set TARGET to grad_add
            # and, in the following "mkldnn" part, it will add USE_OP_DEVICE_KERNEL(grad_add, MKLDNN) to pybind.h
            # however, grad_add has no mkldnn kernel.
            set(TARGET ${op_name})
            set(pybind_flag 1)
        endif()
    endforeach()

    # pybind USE_OP_DEVICE_KERNEL for CUDA
    list(APPEND cu_srcs ${cu_cc_srcs})
    op_library_use_device_kernels("REGISTER_OP_CUDA_KERNEL" CUDA pybind_flag ${cu_srcs})

    # pybind USE_OP_DEVICE_KERNEL for ROCm (registered with the CUDA macro and tag)
    list(APPEND hip_srcs ${hip_cc_srcs})
    op_library_use_device_kernels("REGISTER_OP_CUDA_KERNEL" CUDA pybind_flag ${hip_srcs})

    # pybind USE_OP_DEVICE_KERNEL for CUDNN/MIOPEN
    list(APPEND cudnn_cu_srcs ${cudnn_cu_cc_srcs})
    list(APPEND cudnn_cu_srcs ${miopen_cu_cc_srcs})
    list(APPEND cudnn_cu_srcs ${miopen_cu_srcs})
    list(LENGTH cudnn_cu_srcs cudnn_cu_srcs_len)
    if(${cudnn_cu_srcs_len} GREATER 0 AND "${ORIGINAL_TARGET}" STREQUAL "activation_op")
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, CUDNN);\n")
    else()
        op_library_use_device_kernels("REGISTER_OP_KERNEL" CUDNN pybind_flag ${cudnn_cu_srcs})
    endif()

    # pybind USE_OP_DEVICE_KERNEL for XPU
    if(WITH_XPU AND ${xpu_cc_srcs_len} GREATER 0)
        if("${ORIGINAL_TARGET}" STREQUAL "activation_op")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, XPU);\n")
        else()
            op_library_use_device_kernels("REGISTER_OP_XPU_KERNEL" XPU pybind_flag ${xpu_cc_srcs})
        endif()
    endif()

    # pybind USE_OP_DEVICE_KERNEL for XPU KP
    if(WITH_XPU_KP AND ${xpu_kp_cc_srcs_len} GREATER 0)
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, KP);\n")
    endif()

    # pybind USE_OP_DEVICE_KERNEL for NPU
    if(WITH_ASCEND_CL AND ${npu_cc_srcs_len} GREATER 0)
        op_library_use_device_kernels("REGISTER_OP_NPU_KERNEL" NPU pybind_flag ${npu_cc_srcs})
    endif()

    # pybind USE_OP_DEVICE_KERNEL for MLU
    if(WITH_MLU AND ${mlu_cc_srcs_len} GREATER 0)
        op_library_use_device_kernels("REGISTER_OP_MLU_KERNEL" MLU pybind_flag ${mlu_cc_srcs})
    endif()

    # pybind USE_OP_DEVICE_KERNEL for MKLDNN
    if(WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
        # MKLDNN_FILE is only derived when sources are auto-discovered, so it is
        # empty for explicit SRCS; quoting keeps these comparisons well-formed.
        # Append first implemented MKLDNN activation operator
        if("${MKLDNN_FILE}" STREQUAL "activation_mkldnn_op")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
        elseif("${MKLDNN_FILE}" STREQUAL "conv_mkldnn_op")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n")
        elseif("${MKLDNN_FILE}" STREQUAL "transpose_mkldnn_op")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, FP32);\n")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, S8);\n")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, U8);\n")
        elseif("${MKLDNN_FILE}" STREQUAL "fc_mkldnn_op")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, FP32);\n")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, S8);\n")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, U8);\n")
        else()
            op_library_use_device_kernels("REGISTER_OP_KERNEL" MKLDNN pybind_flag ${mkldnn_cc_srcs})
        endif()
    endif()

    # pybind USE_NO_KERNEL_OP
    # HACK: if REGISTER_OP_CPU_KERNEL presents the operator must have kernel
    # NOTE(review): TARGET_CONTENT is never assigned in this function, so unless a
    # caller defines it, regex_result is always empty and only pybind_flag decides.
    string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}")
    string(REPLACE "_op" "" TARGET "${TARGET}")
    if(${pybind_flag} EQUAL 0 AND regex_result STREQUAL "")
        file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(${TARGET});\n")
        set(pybind_flag 1)
    endif()

    # pybind USE_OP
    if(${pybind_flag} EQUAL 0)
        # NOTE(*): activation use macro to regist the kernels, set use_op manually.
        if("${TARGET}" STREQUAL "activation")
            file(APPEND ${pybind_file} "USE_OP(relu);\n")
        elseif("${TARGET}" STREQUAL "fake_dequantize")
            file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
        elseif("${TARGET}" STREQUAL "fake_quantize")
            file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n")
        elseif("${TARGET}" STREQUAL "tensorrt_engine")
            # TARGET has had `_op` stripped above, so compare against the stripped
            # name; comparing against "tensorrt_engine_op" could never match.
            message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
        else()
            file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
        endif()
    endif()
endfunction()

function(register_operators)
    # register_operators([EXCLUDES op...] [DEPS dep...])
    #
    # Build an operator library (with the UNITY option) for every `*_op.cc`
    # file in the current source directory.
    #   - Before use, the substrings `_mkldnn`, `_xpu`, `_npu`, `_mlu` and `.cc`
    #     are removed from the globbed file names, and duplicate names are dropped.
    #   - Names listed in EXCLUDES are skipped.
    #   - DEPS is forwarded to every op_library() call.
    # With WITH_UNITY_BUILD on, the pending unity targets are finalized at the
    # end: `cc` always, and `cu` as well when WITH_GPU is on.
    cmake_parse_arguments(register_operators "" "" "EXCLUDES;DEPS" ${ARGN})

    file(GLOB op_names RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
    # Applied in sequence on the whole list string, same order as listed.
    foreach(strip_token "_mkldnn" "_xpu" "_npu" "_mlu" ".cc")
        string(REPLACE "${strip_token}" "" op_names "${op_names}")
    endforeach()
    list(REMOVE_DUPLICATES op_names)

    foreach(op_target ${op_names})
        if(NOT op_target IN_LIST register_operators_EXCLUDES)
            # An empty DEPS list parses the same as omitting the keyword.
            op_library(${op_target} UNITY DEPS ${register_operators_DEPS})
        endif()
    endforeach()

    # Complete the processing of `UNITY_TARGET`.
    if(NOT WITH_UNITY_BUILD)
        return()
    endif()
    finish_unity_target(cc)
    if(WITH_GPU)
        finish_unity_target(cu)
    endif()
endfunction()