未验证 提交 9fb68c7c 编写于 作者: X xiaogang 提交者: GitHub

feat: add train flag for add_kernel and add_operator to control grad_ops (#3316)

上级 3b49256a
...@@ -307,6 +307,9 @@ function(add_kernel TARGET device level) ...@@ -307,6 +307,9 @@ function(add_kernel TARGET device level)
if ("${level}" STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA)) if ("${level}" STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA))
return() return()
endif() endif()
if ("${level}" STREQUAL "train" AND (NOT LITE_WITH_TRAIN))
return()
endif()
if ("${device}" STREQUAL "Host") if ("${device}" STREQUAL "Host")
...@@ -434,11 +437,13 @@ function(add_operator TARGET level) ...@@ -434,11 +437,13 @@ function(add_operator TARGET level)
ARGS) ARGS)
cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if ("${level}" STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA)) if ("${level}" STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA))
return() return()
endif() endif()
if ("${level}" STREQUAL "train" AND (NOT LITE_WITH_TRAIN))
return()
endif()
foreach(src ${args_SRCS}) foreach(src ${args_SRCS})
if(LITE_BUILD_TAILOR) if(LITE_BUILD_TAILOR)
......
...@@ -106,13 +106,12 @@ add_kernel(lstm_arm ARM extra SRCS lstm_compute.cc DEPS ${lite_kernel_deps} math ...@@ -106,13 +106,12 @@ add_kernel(lstm_arm ARM extra SRCS lstm_compute.cc DEPS ${lite_kernel_deps} math
# 4. training kernels # 4. training kernels
add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_deps} math_arm)
if(LITE_WITH_TRAIN)
add_kernel(mean_grad_compute_arm ARM extra SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(mean_grad_compute_arm ARM train SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(activation_grad_compute_arm ARM basic SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(activation_grad_compute_arm ARM train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(elementwise_grad_compute_arm ARM basic SRCS elementwise_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(elementwise_grad_compute_arm ARM train SRCS elementwise_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(mul_grad_compute_arm ARM extra SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(mul_grad_compute_arm ARM train SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sgd_compute_arm ARM extra SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(sgd_compute_arm ARM train SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm)
endif()
lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
......
...@@ -141,13 +141,12 @@ add_operator(lstm_op extra SRCS lstm_op.cc DEPS ${op_DEPS}) ...@@ -141,13 +141,12 @@ add_operator(lstm_op extra SRCS lstm_op.cc DEPS ${op_DEPS})
# 4. training op # 4. training op
add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS}) add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS})
if (LITE_WITH_TRAIN)
add_operator(mean_grad_op extra SRCS mean_grad_op.cc DEPS ${op_DEPS}) add_operator(mean_grad_op train SRCS mean_grad_op.cc DEPS ${op_DEPS})
add_operator(activation_grad_ops basic SRCS activation_grad_ops.cc DEPS ${op_DEPS}) add_operator(activation_grad_ops train SRCS activation_grad_ops.cc DEPS ${op_DEPS})
add_operator(elementwise_grad_op extra SRCS elementwise_grad_ops.cc DEPS ${op_DEPS}) add_operator(elementwise_grad_op train SRCS elementwise_grad_ops.cc DEPS ${op_DEPS})
add_operator(mul_grad_op basic SRCS mul_grad_op.cc DEPS ${op_DEPS}) add_operator(mul_grad_op train SRCS mul_grad_op.cc DEPS ${op_DEPS})
add_operator(sgd_op extra SRCS sgd_op.cc DEPS ${op_DEPS}) add_operator(sgd_op train SRCS sgd_op.cc DEPS ${op_DEPS})
endif()
if (NOT LITE_WITH_X86) if (NOT LITE_WITH_X86)
lite_cc_test(test_fc_op SRCS fc_op_test.cc lite_cc_test(test_fc_op SRCS fc_op_test.cc
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册