CMakeLists.txt 2.4 KB
Newer Older
L
Luo Tao 已提交
1 2
# Discover the "generic" operators: every "*_op.cc" file in this directory,
# reported relative to it, with the ".cc" suffix stripped so that each entry
# is a bare op name (e.g. "foo_op.cc" -> "foo_op").
# NOTE: file(GLOB) runs at configure time; a newly added *_op.cc file is only
# picked up after CMake is re-run.
file(GLOB GENERAL_OPS
     RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
     "${CMAKE_CURRENT_SOURCE_DIR}/*_op.cc")
string(REGEX REPLACE "\\.cc" "" GENERAL_OPS "${GENERAL_OPS}")
Y
Yu Yang 已提交
3 4 5 6
# op_library(<target> SRCS <src>... [DEPS <dep>...])
#
# Create an operator library. The interface mirrors cc_library(), but this
# function also:
#   * splits SRCS by extension: ".cc" files are always built, while ".cu" files
#     are only handed to the build when WITH_GPU is on (through nv_library());
#   * always links the common operator deps (operator, op_registry);
#   * prepends <target> to OP_LIBRARY in the caller's scope.
#
# Fails configuration when a source is neither .cc nor .cu, when no .cc source
# is given, or when arguments appear outside the SRCS/DEPS keywords.
#
# Example:
#   op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
function(op_library TARGET)
    # Register this op with the caller. Inside a function, OP_LIBRARY starts
    # as a copy of the caller's value, so repeated calls from one directory
    # accumulate (newest first).
    set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} PARENT_SCOPE)

    set(cc_srcs)
    set(cu_srcs)
    set(op_common_deps operator op_registry)

    set(options "")
    set(oneValueArgs "")
    set(multiValueArgs SRCS DEPS)
    cmake_parse_arguments(op_library "${options}" "${oneValueArgs}"
            "${multiValueArgs}" ${ARGN})

    # Arguments placed before the first keyword would otherwise be dropped
    # without any diagnostic.
    if (op_library_UNPARSED_ARGUMENTS)
        message(FATAL_ERROR "op_library(${TARGET}): unexpected arguments: ${op_library_UNPARSED_ARGUMENTS}")
    endif()

    foreach(src ${op_library_SRCS})
        # Quote the subject so if() does not re-dereference a file name that
        # happens to also be a variable name.
        if ("${src}" MATCHES ".*\\.cu$")
            list(APPEND cu_srcs ${src})
        elseif("${src}" MATCHES ".*\\.cc$")
            list(APPEND cc_srcs ${src})
        else()
            message(FATAL_ERROR "${TARGET} Source file ${src} should only be .cc or .cu")
        endif()
    endforeach()

    list(LENGTH cc_srcs cc_srcs_len)
    if (cc_srcs_len EQUAL 0)
        message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file")
    endif()

    # With neither a .cu source nor extra deps, a WITH_GPU build of this op is
    # compiled from .cc files only. An op that lists DEPS is exempt from the
    # warning (presumably because a dep may supply GPU code -- TODO confirm).
    list(LENGTH cu_srcs cu_srcs_len)
    list(LENGTH op_library_DEPS dep_len)
    if (cu_srcs_len EQUAL 0 AND dep_len EQUAL 0)
        message(WARNING "The op library ${TARGET} not support GPU!")
    endif()

    if (WITH_GPU)
        nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
    else()
        # CPU-only build: any .cu sources are ignored.
        cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
    endif()
endfunction()

Q
qijun 已提交
47
# Configure the math/ subdirectory before the op libraries below; mul_op lists
# a math_function dependency, presumably a target defined there (verify).
add_subdirectory("math")
48

L
Luo Tao 已提交
49 50 51 52 53 54
# These ops need custom sources and/or dependencies and are declared
# explicitly with op_library() below, so drop them from the generic list to
# keep the default loop from defining them a second time.
list(REMOVE_ITEM GENERAL_OPS net_op minus_op mul_op recurrent_op scale_op)
L
liaogang 已提交
55

56 57
# net_op is built from a single .cc source and lists no .cu file.
op_library(net_op
    SRCS net_op.cc)
# minus_op links against scale_op, which is declared further down this file.
op_library(minus_op
    SRCS minus_op.cc minus_op.cu
    DEPS scale_op)
Q
qijun 已提交
58
# mul_op needs the math_function library in addition to the common op deps.
op_library(mul_op
    SRCS mul_op.cc mul_op.cu
    DEPS math_function)
L
Luo Tao 已提交
59 60
# recurrent_op has no .cu source; it is compiled together with the helper
# file rnn/recurrent_op_utils.cc.
op_library(recurrent_op
    SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
    DEPS framework_proto tensor operator net_op)
Y
Yu Yang 已提交
61
# scale_op links against net_op.
op_library(scale_op
    SRCS scale_op.cc scale_op.cu
    DEPS net_op)
62

L
Luo Tao 已提交
63 64 65 66 67 68
# Every op still in GENERAL_OPS follows the default layout: <op>.cc plus an
# optional <op>.cu. Only pass the .cu file when it actually exists; otherwise
# a WITH_GPU configure would hand nv_library() a nonexistent source. An op
# without a .cu file now gets op_library()'s "not support GPU" warning
# instead of failing configuration.
foreach(op_name ${GENERAL_OPS})
    set(op_srcs ${op_name}.cc)
    if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${op_name}.cu")
        list(APPEND op_srcs ${op_name}.cu)
    endif()
    op_library(${op_name} SRCS ${op_srcs})
endforeach()
unset(op_srcs)

# Publish every op target registered through op_library() as an INTERNAL
# cache entry, making the list visible outside this directory (consumers are
# not shown here -- verify).
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")

69 70 71
# Unit tests built from sources in this directory.
cc_test(gather_test
    SRCS gather_test.cc
    DEPS tensor)
cc_test(net_op_test
    SRCS net_op_test.cc
    DEPS net_op)
cc_test(scatter_test
    SRCS scatter_test.cc
    DEPS tensor)