From 9fb68c7ccafeaa2ceb2d1519cefccc78033d0ee7 Mon Sep 17 00:00:00 2001 From: xiaogang Date: Fri, 3 Apr 2020 11:20:54 +0800 Subject: [PATCH] feat: add train flag for add_kernel and add_operator to control grad_ops (#3316) --- cmake/lite.cmake | 7 ++++++- lite/kernels/arm/CMakeLists.txt | 13 ++++++------- lite/operators/CMakeLists.txt | 13 ++++++------- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/cmake/lite.cmake b/cmake/lite.cmake index 265de3fbf6..c7b65d76f9 100644 --- a/cmake/lite.cmake +++ b/cmake/lite.cmake @@ -307,6 +307,9 @@ function(add_kernel TARGET device level) if ("${level}" STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA)) return() endif() + if ("${level}" STREQUAL "train" AND (NOT LITE_WITH_TRAIN)) + return() + endif() if ("${device}" STREQUAL "Host") @@ -434,11 +437,13 @@ function(add_operator TARGET level) ARGS) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - if ("${level}" STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA)) return() endif() + if ("${level}" STREQUAL "train" AND (NOT LITE_WITH_TRAIN)) + return() + endif() foreach(src ${args_SRCS}) if(LITE_BUILD_TAILOR) diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 7550d77014..a3b1c3680e 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -106,13 +106,12 @@ add_kernel(lstm_arm ARM extra SRCS lstm_compute.cc DEPS ${lite_kernel_deps} math # 4. training kernels add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_deps} math_arm) -if(LITE_WITH_TRAIN) - add_kernel(mean_grad_compute_arm ARM extra SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) - add_kernel(activation_grad_compute_arm ARM basic SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) - add_kernel(elementwise_grad_compute_arm ARM basic SRCS elementwise_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) - add_kernel(mul_grad_compute_arm ARM extra SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) - add_kernel(sgd_compute_arm ARM extra SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm) -endif() + +add_kernel(mean_grad_compute_arm ARM train SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(activation_grad_compute_arm ARM train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(elementwise_grad_compute_arm ARM train SRCS elementwise_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(mul_grad_compute_arm ARM train SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(sgd_compute_arm ARM train SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 48e2756031..ae9ec3ad47 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -141,13 +141,12 @@ add_operator(lstm_op extra SRCS lstm_op.cc DEPS ${op_DEPS}) # 4. training op add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS}) -if (LITE_WITH_TRAIN) - add_operator(mean_grad_op extra SRCS mean_grad_op.cc DEPS ${op_DEPS}) - add_operator(activation_grad_ops basic SRCS activation_grad_ops.cc DEPS ${op_DEPS}) - add_operator(elementwise_grad_op extra SRCS elementwise_grad_ops.cc DEPS ${op_DEPS}) - add_operator(mul_grad_op basic SRCS mul_grad_op.cc DEPS ${op_DEPS}) - add_operator(sgd_op extra SRCS sgd_op.cc DEPS ${op_DEPS}) -endif() + +add_operator(mean_grad_op train SRCS mean_grad_op.cc DEPS ${op_DEPS}) +add_operator(activation_grad_ops train SRCS activation_grad_ops.cc DEPS ${op_DEPS}) +add_operator(elementwise_grad_op train SRCS elementwise_grad_ops.cc DEPS ${op_DEPS}) +add_operator(mul_grad_op train SRCS mul_grad_op.cc DEPS ${op_DEPS}) +add_operator(sgd_op train SRCS sgd_op.cc DEPS ${op_DEPS}) if (NOT LITE_WITH_X86) lite_cc_test(test_fc_op SRCS fc_op_test.cc -- GitLab