# CMake file `unity_build` is used to handle Unity Build compilation.
include(unity_build)

# Start with an empty list; op_library() prepends every `<op>.part.cu` file it
# discovers to this variable (in its caller's scope).
set(PART_CUDA_KERNEL_FILES)

function(find_register FILENAME PATTERN OUTPUT)
    # Find the op_name of the first `PATTERN(op_name, ...)` occurrence in a file,
    # e.g. REGISTER_OPERATOR(op_name, ...), REGISTER_OP_CPU_KERNEL(op_name, ...).
    #
    # Arguments:
    #   FILENAME - file to scan (a relative path is resolved by file(READ)).
    #   PATTERN  - macro name that must be directly followed by `(`.
    #   OUTPUT   - name of the variable, in the caller's scope, that receives the
    #              op name. It is always defined on return and is the empty
    #              string when nothing matched.
    file(READ "${FILENAME}" CONTENT)

    # Only the first match is captured; the name may be preceded by blanks or
    # newlines and must consist of [a-z0-9_] followed by a comma.
    string(REGEX MATCH "${PATTERN}\\([ \t\r\n]*[a-z0-9_]*," register "${CONTENT}")
    if (NOT "${register}" STREQUAL "")
        # Strip the macro prefix and the trailing comma, leaving the bare name.
        string(REPLACE "${PATTERN}(" "" register "${register}")
        string(REPLACE "," "" register "${register}")
        # [ \t\r\n]+ is used for blank characters.
        # Here we use '+' instead of '*' since it is a REPLACE operation.
        string(REGEX REPLACE "[ \t\r\n]+" "" register "${register}")
    endif()

    # Quote the value: `set(<var> PARENT_SCOPE)` without a value would *unset*
    # the caller's variable instead of assigning the empty string.
    set(${OUTPUT} "${register}" PARENT_SCOPE)
endfunction()

# Operator library builder.
function(op_library TARGET)
    # op_library is a function to create op library. The interface is same as
    # cc_library. But it handle split GPU/CPU code and link some common library
    # for ops.
    #
    # Usage:
    #   op_library(<target> [UNITY] [SRCS <file>...] [DEPS <dep>...])
    #
    #   UNITY - build this operator as part of a per-directory unity target
    #           (only honoured when the global WITH_UNITY_BUILD is on).
    #   SRCS  - explicit source list; when omitted, sources are discovered next
    #           to CMAKE_CURRENT_SOURCE_DIR from the target name.
    #   DEPS  - extra dependencies, added on top of the common operator deps.
    #
    # Side effects visible in this function:
    #   * prepends to the cache variable OP_LIBRARY,
    #   * sets DEPS_OPS / PART_CUDA_KERNEL_FILES in the caller's scope,
    #   * appends USE_OP* registration macros to the file named by
    #     `pybind_file` (a variable expected to be provided by the caller).
    set(cc_srcs)
    set(cu_srcs)
    set(hip_srcs)
    set(cu_cc_srcs)
    set(hip_cc_srcs)
    set(xpu_cc_srcs)
    set(npu_cc_srcs)
    set(mlu_cc_srcs)
    set(cudnn_cu_cc_srcs)
    set(miopen_cu_cc_srcs)
    set(cudnn_cu_srcs)
    set(miopen_cu_srcs)
    set(CUDNN_FILE)
    set(MIOPEN_FILE)
    set(mkldnn_cc_srcs)
    set(MKLDNN_FILE)
    set(op_common_deps operator op_registry math_function layer common_infer_shape_functions)
    if (WITH_ASCEND_CL)
      set(op_common_deps ${op_common_deps} npu_op_runner)
    endif()
    if (WITH_MLU)
      set(op_common_deps ${op_common_deps} mlu_baseop)
    endif()

    # Option `UNITY` is used to specify that operator `TARGET` will compiles with Unity Build.
    set(options UNITY)
    set(oneValueArgs "")
    set(multiValueArgs SRCS DEPS)
    set(pybind_flag 0)
    cmake_parse_arguments(op_library "${options}" "${oneValueArgs}"
            "${multiValueArgs}" ${ARGN})

    list(LENGTH op_library_SRCS op_library_SRCS_len)
    if (${op_library_SRCS_len} EQUAL 0)
        # No explicit SRCS: discover sources by naming convention.
        if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc)
            list(APPEND cc_srcs ${TARGET}.cc)
        endif()
        if(WITH_GPU)
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc)
                list(APPEND cu_cc_srcs ${TARGET}.cu.cc)
            endif()
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
                list(APPEND cu_srcs ${TARGET}.cu)
            endif()
            if (WITH_NV_JETSON)
                list(REMOVE_ITEM cu_srcs "decode_jpeg_op.cu")
            endif()
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
                set(PART_CUDA_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu
                        ${PART_CUDA_KERNEL_FILES} PARENT_SCOPE)
                list(APPEND cu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
            endif()
            string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}")
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc)
                list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc)
            endif()
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu)
                list(APPEND cudnn_cu_srcs ${CUDNN_FILE}.cu)
            endif()
        endif()
        if(WITH_ROCM)
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc)
                list(APPEND hip_cc_srcs ${TARGET}.cu.cc)
            endif()
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
                list(APPEND hip_srcs ${TARGET}.cu)
            endif()
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
                set(PART_CUDA_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu
                        ${PART_CUDA_KERNEL_FILES} PARENT_SCOPE)
                list(APPEND hip_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
            endif()
            # The MIOpen sources reuse the `_cudnn_op` file naming.
            string(REPLACE "_op" "_cudnn_op" MIOPEN_FILE "${TARGET}")
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.cu.cc)
                list(APPEND miopen_cu_cc_srcs ${MIOPEN_FILE}.cu.cc)
            endif()
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.cu)
                list(APPEND miopen_cu_srcs ${MIOPEN_FILE}.cu)
            endif()
        endif()
        if(WITH_MKLDNN)
            string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}")
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/mkldnn/${MKLDNN_FILE}.cc)
                list(APPEND mkldnn_cc_srcs mkldnn/${MKLDNN_FILE}.cc)
            endif()
        endif()
        if(WITH_XPU)
            string(REPLACE "_op" "_op_xpu" XPU_FILE "${TARGET}")
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${XPU_FILE}.cc)
                list(APPEND xpu_cc_srcs ${XPU_FILE}.cc)
            endif()
        endif()
        if(WITH_ASCEND_CL)
            string(REPLACE "_op" "_op_npu" NPU_FILE "${TARGET}")
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${NPU_FILE}.cc)
                list(APPEND npu_cc_srcs ${NPU_FILE}.cc)
            endif()
        endif()
        if(WITH_MLU)
            string(REPLACE "_op" "_op_mlu" MLU_FILE "${TARGET}")
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MLU_FILE}.cc)
                list(APPEND mlu_cc_srcs ${MLU_FILE}.cc)
            endif()
        endif()
    else()
        # Explicit SRCS: classify each file by suffix. The order of the checks
        # matters, the more specific suffixes are tested first.
        foreach(src ${op_library_SRCS})
            if(WITH_ROCM AND "${src}" MATCHES ".*_cudnn_op.cu$")
                list(APPEND miopen_cu_srcs ${src})
            elseif(WITH_ROCM AND "${src}" MATCHES ".*\\.cu$")
                list(APPEND hip_srcs ${src})
            elseif(WITH_ROCM AND "${src}" MATCHES ".*_cudnn_op.cu.cc$")
                list(APPEND miopen_cu_cc_srcs ${src})
            elseif(WITH_ROCM AND "${src}" MATCHES ".*\\.cu.cc$")
                list(APPEND hip_cc_srcs ${src})
            elseif(WITH_GPU AND "${src}" MATCHES ".*_cudnn_op.cu$")
                list(APPEND cudnn_cu_srcs ${src})
            elseif (WITH_GPU AND "${src}" MATCHES ".*\\.cu$")
                list(APPEND cu_srcs ${src})
            elseif(WITH_GPU AND "${src}" MATCHES ".*_cudnn_op.cu.cc$")
                list(APPEND cudnn_cu_cc_srcs ${src})
            elseif(WITH_GPU AND "${src}" MATCHES ".*\\.cu.cc$")
                list(APPEND cu_cc_srcs ${src})
            elseif(WITH_MKLDNN AND "${src}" MATCHES ".*_mkldnn_op.cc$")
                list(APPEND mkldnn_cc_srcs ${src})
            elseif(WITH_XPU AND "${src}" MATCHES ".*_op_xpu.cc$")
                list(APPEND xpu_cc_srcs ${src})
            elseif(WITH_ASCEND_CL AND "${src}" MATCHES ".*_op_npu.cc$")
                list(APPEND npu_cc_srcs ${src})
            elseif(WITH_MLU AND "${src}" MATCHES ".*_op_mlu.cc$")
                list(APPEND mlu_cc_srcs ${src})
            elseif("${src}" MATCHES ".*\\.cc$")
                list(APPEND cc_srcs ${src})
            else()
                message(FATAL_ERROR "${TARGET} Source file ${src} should only be .cc or .cu")
            endif()
        endforeach()
    endif()

    list(LENGTH cc_srcs cc_srcs_len)
    if (${cc_srcs_len} EQUAL 0)
        message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file")
    endif()
    if (WIN32)
        # remove windows unsupported op, because windows has no nccl, no warpctc such ops.
        foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op")
            if ("${TARGET}" STREQUAL "${windows_unsupport_op}")
                return()
            endif()
        endforeach()
    endif()

    # Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`.
    if(WITH_UNITY_BUILD AND op_library_UNITY)
        # Generate the unity target name by the directory where source files located.
        string(REPLACE "${PADDLE_SOURCE_DIR}/paddle/fluid/" "" UNITY_TARGET "${CMAKE_CURRENT_SOURCE_DIR}")
        string(REPLACE "/" "_" UNITY_TARGET "${UNITY_TARGET}")
        set(UNITY_TARGET "paddle_${UNITY_TARGET}_unity")
        if(NOT ${UNITY_TARGET} IN_LIST OP_LIBRARY)
            set(OP_LIBRARY ${UNITY_TARGET} ${OP_LIBRARY} CACHE INTERNAL "op libs")
        endif()
    else()
        set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} CACHE INTERNAL "op libs")
    endif()

    list(LENGTH op_library_DEPS op_library_DEPS_len)
    if (${op_library_DEPS_len} GREATER 0)
        set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
    endif()
    if (WITH_GPU)
        # Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`.
        if(WITH_UNITY_BUILD AND op_library_UNITY)
            # Combine the cc and cu source files.
            compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs})
            compose_unity_target_sources(${UNITY_TARGET} cu ${cudnn_cu_srcs} ${cu_srcs})
            if(TARGET ${UNITY_TARGET})
                # If `UNITY_TARGET` exists, add source files to `UNITY_TARGET`.
                target_sources(${UNITY_TARGET} PRIVATE ${unity_target_cc_sources} ${unity_target_cu_sources})
            else()
                # If `UNITY_TARGET` does not exist, create `UNITY_TARGET` with source files.
                nv_library(${UNITY_TARGET} SRCS ${unity_target_cc_sources} ${unity_target_cu_sources} DEPS ${op_library_DEPS} ${op_common_deps})
            endif()
            # Add alias library to handle dependencies.
            add_library(${TARGET} ALIAS ${UNITY_TARGET})
        else()
            nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${cudnn_cu_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
        endif()
    elseif (WITH_ROCM)
        # Sources explicitly dropped from the ROCm build.
        list(REMOVE_ITEM miopen_cu_cc_srcs "affine_grid_cudnn_op.cu.cc")
        list(REMOVE_ITEM miopen_cu_cc_srcs "grid_sampler_cudnn_op.cu.cc")
        list(REMOVE_ITEM hip_srcs "cholesky_op.cu")
        list(REMOVE_ITEM hip_srcs "cholesky_solve_op.cu")
        list(REMOVE_ITEM hip_srcs "lu_op.cu")
        list(REMOVE_ITEM hip_srcs "matrix_rank_op.cu")
        list(REMOVE_ITEM hip_srcs "svd_op.cu")
        list(REMOVE_ITEM hip_srcs "eigvalsh_op.cu")
        list(REMOVE_ITEM hip_srcs "qr_op.cu")
        list(REMOVE_ITEM hip_srcs "eigh_op.cu")
        list(REMOVE_ITEM hip_srcs "lstsq_op.cu")
        list(REMOVE_ITEM hip_srcs "multinomial_op.cu")
        list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu")
        hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
    else()
        # Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`.
        if(WITH_UNITY_BUILD AND op_library_UNITY)
            # Combine the cc source files.
            compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} ${mlu_cc_srcs})
            if(TARGET ${UNITY_TARGET})
                # If `UNITY_TARGET` exists, add source files to `UNITY_TARGET`.
                target_sources(${UNITY_TARGET} PRIVATE ${unity_target_cc_sources})
            else()
                # If `UNITY_TARGET` does not exist, create `UNITY_TARGET` with source files.
                cc_library(${UNITY_TARGET} SRCS ${unity_target_cc_sources} DEPS ${op_library_DEPS} ${op_common_deps})
            endif()
            # Add alias library to handle dependencies.
            add_library(${TARGET} ALIAS ${UNITY_TARGET})
        else()
            cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} ${mlu_cc_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
        endif()
    endif()

    # Lengths used below to decide whether a device-specific pybind section runs.
    list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
    list(LENGTH xpu_cc_srcs xpu_cc_srcs_len)
    list(LENGTH npu_cc_srcs npu_cc_srcs_len)
    list(LENGTH mlu_cc_srcs mlu_cc_srcs_len)

    # Define operators that don't need pybind here.
    foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "bitwise_op" "nccl_op"
            "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op")
        if ("${TARGET}" STREQUAL "${manual_pybind_op}")
            set(pybind_flag 1)
        endif()
    endforeach()

    # The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h.
    # Note that it's enough to just adding one operator to pybind in a *_op.cc file.
    # And for detail pybind information, please see generated paddle/pybind/pybind.h.
    set(ORIGINAL_TARGET ${TARGET})
    string(REGEX REPLACE "_op" "" TARGET "${TARGET}")

    foreach(cc_src ${cc_srcs})
        # pybind USE_OP_ITSELF
        set(op_name "")
        find_register(${cc_src} "REGISTER_OPERATOR" op_name)
        if(NOT "${op_name}" STREQUAL "")
            file(APPEND ${pybind_file} "USE_OP_ITSELF(${op_name});\n")
            # hack: for example, the target in conv_transpose_op.cc is conv2d_transpose, used in mkldnn
            set(TARGET ${op_name})
            set(pybind_flag 1)
        endif()

        set(op_name "")
        find_register(${cc_src} "REGISTER_OP_WITHOUT_GRADIENT" op_name)
        if(NOT "${op_name}" STREQUAL "")
            file(APPEND ${pybind_file} "USE_OP_ITSELF(${op_name});\n")
            # hack: for example, the target in conv_transpose_op.cc is conv2d_transpose, used in mkldnn
            set(TARGET ${op_name})
            set(pybind_flag 1)
        endif()

        # pybind USE_OP_DEVICE_KERNEL for CPU
        set(op_name "")
        find_register(${cc_src} "REGISTER_OP_CPU_KERNEL" op_name)
        if(NOT "${op_name}" STREQUAL "")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CPU);\n")
            # why change TARGET here?
            # when building paddle with on_infer, the REGISTER_OPERATOR(*_grad) will be removed before compiling (see details in remove_grad_op_and_kernel.py)
            # in elementwise_op.cc, it will find REGISTER_OPERATOR(grad_add) and set TARGET to grad_add
            # and, in the following "mkldnn" part, it will add USE_OP_DEVICE_KERNEL(grad_add, MKLDNN) to pybind.h
            # however, grad_add has no mkldnn kernel.
            set(TARGET ${op_name})
            set(pybind_flag 1)
        endif()
    endforeach()

    # pybind USE_OP_DEVICE_KERNEL for CUDA
    list(APPEND cu_srcs ${cu_cc_srcs})
    foreach(cu_src ${cu_srcs})
        set(op_name "")
        find_register(${cu_src} "REGISTER_OP_CUDA_KERNEL" op_name)
        if(NOT "${op_name}" STREQUAL "")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDA);\n")
            set(pybind_flag 1)
        endif()
    endforeach()

    # pybind USE_OP_DEVICE_KERNEL for CUDNN/MIOPEN
    list(APPEND cudnn_cu_srcs ${cudnn_cu_cc_srcs})
    list(APPEND cudnn_cu_srcs ${miopen_cu_cc_srcs})
    list(APPEND cudnn_cu_srcs ${miopen_cu_srcs})
    list(LENGTH cudnn_cu_srcs cudnn_cu_srcs_len)
    if(${cudnn_cu_srcs_len} GREATER 0 AND "${ORIGINAL_TARGET}" STREQUAL "activation_op")
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, CUDNN);\n")
    else()
        foreach(cudnn_src ${cudnn_cu_srcs})
            set(op_name "")
            find_register(${cudnn_src} "REGISTER_OP_KERNEL" op_name)
            if(NOT "${op_name}" STREQUAL "")
                file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDNN);\n")
                set(pybind_flag 1)
            endif()
        endforeach()
    endif()

    # pybind USE_OP_DEVICE_KERNEL for XPU
    if (WITH_XPU AND ${xpu_cc_srcs_len} GREATER 0)
        if("${ORIGINAL_TARGET}" STREQUAL "activation_op")
            file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, XPU);\n")
        else()
            foreach(xpu_src ${xpu_cc_srcs})
                set(op_name "")
                find_register(${xpu_src} "REGISTER_OP_XPU_KERNEL" op_name)
                if(NOT "${op_name}" STREQUAL "")
                    file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, XPU);\n")
                    set(pybind_flag 1)
                endif()
            endforeach()
        endif()
    endif()

    # pybind USE_OP_DEVICE_KERNEL for NPU
    if (WITH_ASCEND_CL AND ${npu_cc_srcs_len} GREATER 0)
        foreach(npu_src ${npu_cc_srcs})
            set(op_name "")
            find_register(${npu_src} "REGISTER_OP_NPU_KERNEL" op_name)
            if(NOT "${op_name}" STREQUAL "")
                file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, NPU);\n")
                set(pybind_flag 1)
            endif()
        endforeach()
    endif()

    # pybind USE_OP_DEVICE_KERNEL for MLU
    if (WITH_MLU AND ${mlu_cc_srcs_len} GREATER 0)
        foreach(mlu_src ${mlu_cc_srcs})
            set(op_name "")
            find_register(${mlu_src} "REGISTER_OP_MLU_KERNEL" op_name)
            if(NOT "${op_name}" STREQUAL "")
                file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, MLU);\n")
                set(pybind_flag 1)
            endif()
        endforeach()
    endif()

    # pybind USE_OP_DEVICE_KERNEL for MKLDNN
    if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
      # MKLDNN_FILE is only filled in by the source auto-discovery branch; it is
      # empty when SRCS was passed explicitly, so it must be quoted here.
      # Append first implemented MKLDNN activation operator
      if ("${MKLDNN_FILE}" STREQUAL "activation_mkldnn_op")
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
      elseif("${MKLDNN_FILE}" STREQUAL "conv_mkldnn_op")
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n")
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n")
      elseif("${MKLDNN_FILE}" STREQUAL "transpose_mkldnn_op")
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, FP32);\n")
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, S8);\n")
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, U8);\n")
      elseif("${MKLDNN_FILE}" STREQUAL "fc_mkldnn_op")
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, FP32);\n")
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, S8);\n")
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, U8);\n")
      else()
        foreach(mkldnn_src ${mkldnn_cc_srcs})
            set(op_name "")
            find_register(${mkldnn_src} "REGISTER_OP_KERNEL" op_name)
            if(NOT "${op_name}" STREQUAL "")
                file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, MKLDNN);\n")
                set(pybind_flag 1)
            endif()
        endforeach()
      endif()
    endif()

    # pybind USE_NO_KERNEL_OP
    # HACK: if REGISTER_OP_CPU_KERNEL presents the operator must have kernel
    # NOTE(review): TARGET_CONTENT is not assigned anywhere in this function;
    # unless an enclosing scope defines it, regex_result is always empty and
    # only pybind_flag decides this branch - confirm whether that is intended.
    string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}")
    string(REPLACE "_op" "" TARGET "${TARGET}")
    if (${pybind_flag} EQUAL 0 AND "${regex_result}" STREQUAL "")
        file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(${TARGET});\n")
        set(pybind_flag 1)
    endif()

    # pybind USE_OP
    if (${pybind_flag} EQUAL 0)
      # NOTE(*): activation uses a macro to register the kernels, set use_op manually.
      if("${TARGET}" STREQUAL "activation")
        file(APPEND ${pybind_file} "USE_OP(relu);\n")
      elseif("${TARGET}" STREQUAL "fake_dequantize")
        file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
      elseif("${TARGET}" STREQUAL "fake_quantize")
        file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n")
      elseif("${TARGET}" STREQUAL "tensorrt_engine_op")
          message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
      else()
        file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
      endif()
    endif()
endfunction()


function(register_operators)
    # Create an op library (via op_library, with the UNITY option) for every
    # `*_op.cc` file found in CMAKE_CURRENT_SOURCE_DIR.
    #
    # Usage:
    #   register_operators([EXCLUDES <op>...] [DEPS <dep>...])
    #
    #   EXCLUDES - operator names (file name without `.cc`) to skip.
    #   DEPS     - dependencies forwarded to every op_library call.
    set(options "")
    set(oneValueArgs "")
    set(multiValueArgs EXCLUDES DEPS)
    cmake_parse_arguments(register_operators "${options}" "${oneValueArgs}"
            "${multiValueArgs}" ${ARGN})

    # Collect candidate files, then fold device-specific variants back onto the
    # base operator name and strip the extension.
    file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
    string(REPLACE "_mkldnn" "" OPS "${OPS}")
    string(REPLACE "_xpu" "" OPS "${OPS}")
    string(REPLACE "_npu" "" OPS "${OPS}")
    string(REPLACE "_mlu" "" OPS "${OPS}")
    string(REPLACE ".cc" "" OPS "${OPS}")
    list(REMOVE_DUPLICATES OPS)

    foreach(src ${OPS})
        if (NOT "${src}" IN_LIST register_operators_EXCLUDES)
            # A `DEPS` keyword followed by no values parses to an empty
            # dependency list, so a single call covers both cases.
            op_library(${src} UNITY DEPS ${register_operators_DEPS})
        endif()
    endforeach()

    # Complete the processing of `UNITY_TARGET`.
    if(WITH_UNITY_BUILD)
        finish_unity_target(cc)
        if(WITH_GPU)
            finish_unity_target(cu)
        endif()
    endif()
endfunction()