CMakeLists.txt 2.6 KB
Newer Older
L
Luo Tao 已提交
1 2
# Collect the candidate ops: every "*_op.cc" file in this directory, taken
# relative to the directory, with each ".cc" substring removed so the list
# holds bare op names (mul_op, scale_op, ...).
file(GLOB GENERAL_OPS
     RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
     "*_op.cc")
string(REGEX REPLACE "\\.cc" "" GENERAL_OPS "${GENERAL_OPS}")
Y
Yu Yang 已提交
3 4 5 6
# op_library(<target> SRCS <src>... [DEPS <dep>...])
#
# Create an operator library. The interface is the same as cc_library, but
# op_library additionally:
#   * splits SRCS into CPU (.cc) and GPU (.cu) sources; any other extension is
#     a configuration error, and at least one .cc file is required;
#   * builds with nv_library (.cc + .cu sources) when WITH_GPU is on, and with
#     cc_library (.cc sources only) otherwise;
#   * always links the common op dependencies `operator` and `op_registry`;
#   * prepends <target> to OP_LIBRARY in the caller's scope, so the caller can
#     collect every op library declared so far.
# A warning is printed when the op has neither .cu sources nor DEPS.
# Unrecognised arguments (e.g. a misspelled keyword) are a fatal error.
#
# Example:
#   op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
function(op_library TARGET)
    set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} PARENT_SCOPE)

    set(cc_srcs)
    set(cu_srcs)
    set(op_common_deps operator op_registry)

    set(options "")
    set(oneValueArgs "")
    set(multiValueArgs SRCS DEPS)
    cmake_parse_arguments(op_library "${options}" "${oneValueArgs}"
            "${multiValueArgs}" ${ARGN})

    # Without this check, anything cmake_parse_arguments did not recognise
    # would be dropped silently.
    list(LENGTH op_library_UNPARSED_ARGUMENTS unparsed_len)
    if (unparsed_len GREATER 0)
        message(FATAL_ERROR "op_library(${TARGET}): unknown arguments: ${op_library_UNPARSED_ARGUMENTS}")
    endif()

    foreach(src ${op_library_SRCS})
        # Test the variable by name, not as an unquoted `${src}`: the expanded
        # file name would be dereferenced again by if() whenever a variable
        # with that same name happens to be defined.
        if (src MATCHES ".*\\.cu$")
            list(APPEND cu_srcs ${src})
        elseif (src MATCHES ".*\\.cc$")
            list(APPEND cc_srcs ${src})
        else()
            message(FATAL_ERROR "${TARGET} Source file ${src} should only be .cc or .cu")
        endif()
    endforeach()

    list(LENGTH cc_srcs cc_srcs_len)
    if (cc_srcs_len EQUAL 0)
        message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file")
    endif()

    # Warn when the op has no .cu source and no extra DEPS.
    list(LENGTH cu_srcs cu_srcs_len)
    list(LENGTH op_library_DEPS dep_len)
    if (cu_srcs_len EQUAL 0 AND dep_len EQUAL 0)
        message(WARNING "The op library ${TARGET} not support GPU!")
    endif()

    if (WITH_GPU)
        nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
    else()
        cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
    endif()
endfunction()

Q
qijun 已提交
47
# Process the math/ subdirectory's CMakeLists.txt at this point.
# NOTE(review): mul_op below lists math_function in DEPS; presumably that
# target is defined under math/ — confirm.
add_subdirectory("math")
48

49 50 51 52 53 54 55
# Ops built from a single <op>.cc with no .cu source and no extra DEPS.
# The list is also used later to exclude these ops from GENERAL_OPS.
set(ONLYCPU_OPS
    net_op
    gather_op
    scatter_op)
foreach(cpu_op IN LISTS ONLYCPU_OPS)
    op_library(${cpu_op} SRCS ${cpu_op}.cc)
endforeach()
L
liaogang 已提交
56

57 58 59 60 61 62
# Ops whose op_library() call is written out by hand (they need DEPS or
# non-standard sources). The list is also used later to exclude these ops
# from GENERAL_OPS.
set(DEPS_OPS
    identity_op minus_op mul_op
    recurrent_op scale_op)
63
# identity_op: CPU source only, linked against scale_op.
op_library(identity_op
    SRCS identity_op.cc
    DEPS scale_op)
64
# minus_op: CPU and CUDA sources, linked against scale_op.
op_library(minus_op
    SRCS minus_op.cc minus_op.cu
    DEPS scale_op)
Q
qijun 已提交
65
# mul_op: CPU and CUDA sources, linked against math_function.
op_library(mul_op
    SRCS mul_op.cc mul_op.cu
    DEPS math_function)
L
Luo Tao 已提交
66 67
# recurrent_op: built from its own source plus the helpers in rnn/.
op_library(recurrent_op
    SRCS recurrent_op.cc
         rnn/recurrent_op_utils.cc
    DEPS framework_proto tensor operator net_op)
Y
Yu Yang 已提交
68
# scale_op: CPU and CUDA sources, linked against net_op.
op_library(scale_op
    SRCS scale_op.cc scale_op.cu
    DEPS net_op)
69

70 71 72
# Remove the ops that were already declared explicitly (CPU-only and
# custom-DEPS ops), so the generic loop does not declare them a second time.
list(REMOVE_ITEM GENERAL_OPS ${ONLYCPU_OPS} ${DEPS_OPS})
L
Luo Tao 已提交
73 74 75 76 77 78
# Declare every remaining globbed op from <op>.cc, adding <op>.cu only when
# that file exists. Previously the .cu was always passed, so an op without a
# CUDA source failed with a missing-source error. Such an op now builds as
# CPU-only, and op_library emits its "not support GPU" warning.
foreach(src ${GENERAL_OPS})
    set(op_srcs ${src}.cc)
    if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${src}.cu")
        list(APPEND op_srcs ${src}.cu)
    endif()
    op_library(${src} SRCS ${op_srcs})
endforeach()

# Publish the names of all op libraries declared in this directory through
# an internal cache entry.
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")

79 80 81
# Unit tests: gather/scatter run against tensor, net_op_test against net_op.
cc_test(gather_test
        SRCS gather_test.cc
        DEPS tensor)
cc_test(net_op_test
        SRCS net_op_test.cc
        DEPS net_op)
cc_test(scatter_test
        SRCS scatter_test.cc
        DEPS tensor)