diff --git a/cmake/operators.cmake b/cmake/operators.cmake index e58dbf77b4c9c83bd131404d61181adf89fec305..8469dc4c02ee37b333254d6d35b0eb48354d4b86 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -335,6 +335,17 @@ function(op_library TARGET) endif() endforeach() + # pybind USE_OP_DEVICE_KERNEL for ROCm + list (APPEND hip_srcs ${hip_cc_srcs}) + # message("hip_srcs ${hip_srcs}") + foreach(hip_src ${hip_srcs}) + set(op_name "") + find_register(${hip_src} "REGISTER_OP_CUDA_KERNEL" op_name) + if(NOT ${op_name} EQUAL "") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDA);\n") + set(pybind_flag 1) + endif() + endforeach() # pybind USE_OP_DEVICE_KERNEL for CUDNN/MIOPEN list(APPEND cudnn_cu_srcs ${cudnn_cu_cc_srcs})