diff --git a/cmake/lite.cmake b/cmake/lite.cmake index 265de3fbf68542f1b1525257887cbfaa4d1c4d62..c7b65d76f9631e39e9b97df18a396f5e0aec8a63 100644 --- a/cmake/lite.cmake +++ b/cmake/lite.cmake @@ -307,6 +307,9 @@ function(add_kernel TARGET device level) if ("${level}" STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA)) return() endif() + if ("${level}" STREQUAL "train" AND (NOT LITE_WITH_TRAIN)) + return() + endif() if ("${device}" STREQUAL "Host") @@ -434,11 +437,13 @@ function(add_operator TARGET level) ARGS) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - if ("${level}" STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA)) return() endif() + if ("${level}" STREQUAL "train" AND (NOT LITE_WITH_TRAIN)) + return() + endif() foreach(src ${args_SRCS}) if(LITE_BUILD_TAILOR) diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 7550d770145d92ebd343f96a82c6f34d72c91ea5..a3b1c3680e283a4425fe22209c443ce7cd958267 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -106,13 +106,12 @@ add_kernel(lstm_arm ARM extra SRCS lstm_compute.cc DEPS ${lite_kernel_deps} math # 4. training kernels add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_deps} math_arm) -if(LITE_WITH_TRAIN) - add_kernel(mean_grad_compute_arm ARM extra SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) - add_kernel(activation_grad_compute_arm ARM basic SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) - add_kernel(elementwise_grad_compute_arm ARM basic SRCS elementwise_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) - add_kernel(mul_grad_compute_arm ARM extra SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) - add_kernel(sgd_compute_arm ARM extra SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm) -endif() + +add_kernel(mean_grad_compute_arm ARM train SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(activation_grad_compute_arm ARM train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(elementwise_grad_compute_arm ARM train SRCS elementwise_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(mul_grad_compute_arm ARM train SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(sgd_compute_arm ARM train SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 48e27560317c089446e8dbc5040786f34ca962c4..ae9ec3ad47fbc00c91ba06c1597bd65e510b629b 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -141,13 +141,12 @@ add_operator(lstm_op extra SRCS lstm_op.cc DEPS ${op_DEPS}) # 4. training op add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS}) -if (LITE_WITH_TRAIN) - add_operator(mean_grad_op extra SRCS mean_grad_op.cc DEPS ${op_DEPS}) - add_operator(activation_grad_ops basic SRCS activation_grad_ops.cc DEPS ${op_DEPS}) - add_operator(elementwise_grad_op extra SRCS elementwise_grad_ops.cc DEPS ${op_DEPS}) - add_operator(mul_grad_op basic SRCS mul_grad_op.cc DEPS ${op_DEPS}) - add_operator(sgd_op extra SRCS sgd_op.cc DEPS ${op_DEPS}) -endif() + +add_operator(mean_grad_op train SRCS mean_grad_op.cc DEPS ${op_DEPS}) +add_operator(activation_grad_ops train SRCS activation_grad_ops.cc DEPS ${op_DEPS}) +add_operator(elementwise_grad_op train SRCS elementwise_grad_ops.cc DEPS ${op_DEPS}) +add_operator(mul_grad_op train SRCS mul_grad_op.cc DEPS ${op_DEPS}) +add_operator(sgd_op train SRCS sgd_op.cc DEPS ${op_DEPS}) if (NOT LITE_WITH_X86) lite_cc_test(test_fc_op SRCS fc_op_test.cc