未验证 提交 83d4b0e8 编写于 作者: Y Yan Chunwei 提交者: GitHub

make model_optimize_tool run on host (#1990)

上级 487d6089
...@@ -82,6 +82,7 @@ lite_option(LITE_WITH_PROFILE "Enable profile mode in lite framework" OFF) ...@@ -82,6 +82,7 @@ lite_option(LITE_WITH_PROFILE "Enable profile mode in lite framework" OFF)
lite_option(LITE_WITH_PRECISION_PROFILE "Enable precision profile in profile mode ON in lite" OFF IF LITE_WITH_PROFILE) lite_option(LITE_WITH_PRECISION_PROFILE "Enable precision profile in profile mode ON in lite" OFF IF LITE_WITH_PROFILE)
lite_option(LITE_SHUTDOWN_LOG "Shutdown log system or not." OFF) lite_option(LITE_SHUTDOWN_LOG "Shutdown log system or not." OFF)
lite_option(LITE_ON_TINY_PUBLISH "Publish tiny predictor lib." OFF) lite_option(LITE_ON_TINY_PUBLISH "Publish tiny predictor lib." OFF)
lite_option(LITE_ON_MODEL_OPTIMIZE_TOOL "Build the model optimize tool" OFF)
# publish options # publish options
lite_option(LITE_BUILD_EXTRA "Enable extra algorithm support in Lite, both kernels and operators" OFF) lite_option(LITE_BUILD_EXTRA "Enable extra algorithm support in Lite, both kernels and operators" OFF)
...@@ -104,6 +105,9 @@ if (LITE_ON_TINY_PUBLISH) ...@@ -104,6 +105,9 @@ if (LITE_ON_TINY_PUBLISH)
endif() endif()
include_directories("${PADDLE_SOURCE_DIR}") include_directories("${PADDLE_SOURCE_DIR}")
# the generated header files.
set(LITE_GENERATED_INCLUDE_DIR "${CMAKE_BINARY_DIR}")
include_directories("${LITE_GENERATED_INCLUDE_DIR}")
# for mobile # for mobile
if (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
...@@ -34,33 +34,6 @@ elseif(SSE3_FOUND) ...@@ -34,33 +34,6 @@ elseif(SSE3_FOUND)
set(SIMD_FLAG ${SSE3_FLAG}) set(SIMD_FLAG ${SSE3_FLAG})
endif() endif()
if(WIN32)
# windows header option for all targets.
add_definitions(-D_XKEYCHECK_H)
# Use symbols instead of absolute path, reduce the cmake link command length.
SET(CMAKE_C_USE_RESPONSE_FILE_FOR_LIBRARIES 1)
SET(CMAKE_CXX_USE_RESPONSE_FILE_FOR_LIBRARIES 1)
SET(CMAKE_C_USE_RESPONSE_FILE_FOR_OBJECTS 1)
SET(CMAKE_CXX_USE_RESPONSE_FILE_FOR_OBJECTS 1)
SET(CMAKE_C_USE_RESPONSE_FILE_FOR_INCLUDES 1)
SET(CMAKE_CXX_USE_RESPONSE_FILE_FOR_INCLUDES 1)
SET(CMAKE_C_RESPONSE_FILE_LINK_FLAG "@")
SET(CMAKE_CXX_RESPONSE_FILE_LINK_FLAG "@")
# Specify the program to use when building static libraries
SET(CMAKE_C_CREATE_STATIC_LIBRARY "<CMAKE_AR> lib <TARGET> <LINK_FLAGS> <OBJECTS>")
SET(CMAKE_CXX_CREATE_STATIC_LIBRARY "<CMAKE_AR> lib <TARGET> <LINK_FLAGS> <OBJECTS>")
# set defination for the dll export
if (NOT MSVC)
message(FATAL "Windows build only support msvc. Which was binded by the nvcc compiler of NVIDIA.")
endif(NOT MSVC)
endif(WIN32)
if(WITH_PSLIB)
add_definitions(-DPADDLE_WITH_PSLIB)
endif()
if(LITE_WITH_CUDA) if(LITE_WITH_CUDA)
add_definitions(-DLITE_WITH_CUDA) add_definitions(-DLITE_WITH_CUDA)
add_definitions(-DEIGEN_USE_GPU) add_definitions(-DEIGEN_USE_GPU)
...@@ -180,3 +153,8 @@ endif() ...@@ -180,3 +153,8 @@ endif()
if (LITE_ON_TINY_PUBLISH) if (LITE_ON_TINY_PUBLISH)
add_definitions("-DLITE_ON_TINY_PUBLISH") add_definitions("-DLITE_ON_TINY_PUBLISH")
endif() endif()
if (LITE_ON_MODEL_OPTIMIZE_TOOL)
add_definitions("-DLITE_ON_MODEL_OPTIMIZE_TOOL")
endif(LITE_ON_MODEL_OPTIMIZE_TOOL)
...@@ -185,6 +185,12 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) ...@@ -185,6 +185,12 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
SET(SOURCE_DIR "${CMAKE_SOURCE_DIR}/third-party/protobuf-host") SET(SOURCE_DIR "${CMAKE_SOURCE_DIR}/third-party/protobuf-host")
IF(BUILD_FOR_HOST) IF(BUILD_FOR_HOST)
# set for server compile.
if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
set(HOST_C_COMPILER "${CMAKE_C_COMPILER}")
set(HOST_CXX_COMPILER "${CMAKE_CXX_COMPILER}")
endif()
SET(OPTIONAL_ARGS SET(OPTIONAL_ARGS
"-DCMAKE_C_COMPILER=${HOST_C_COMPILER}" "-DCMAKE_C_COMPILER=${HOST_C_COMPILER}"
"-DCMAKE_CXX_COMPILER=${HOST_CXX_COMPILER}" "-DCMAKE_CXX_COMPILER=${HOST_CXX_COMPILER}"
...@@ -276,7 +282,11 @@ IF(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) ...@@ -276,7 +282,11 @@ IF(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
ENDIF() ENDIF()
IF(NOT PROTOBUF_FOUND) IF(NOT PROTOBUF_FOUND)
build_protobuf(extern_protobuf FALSE) if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
build_protobuf(extern_protobuf FALSE)
else()
build_protobuf(extern_protobuf TRUE)
endif()
SET(PROTOBUF_INCLUDE_DIR ${extern_protobuf_INCLUDE_DIR} SET(PROTOBUF_INCLUDE_DIR ${extern_protobuf_INCLUDE_DIR}
CACHE PATH "protobuf include directory." FORCE) CACHE PATH "protobuf include directory." FORCE)
......
...@@ -240,6 +240,21 @@ function(add_kernel TARGET device level) ...@@ -240,6 +240,21 @@ function(add_kernel TARGET device level)
return() return()
endif() endif()
if (LITE_ON_MODEL_OPTIMIZE_TOOL)
# the source list will collect for model_optimize_tool to fake kernel generation.
foreach(src ${args_SRCS})
file(APPEND ${kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n")
endforeach()
return()
endif()
# when compiling the model_optimize_tool, a source file with all the fake kernel definitions will be generated,
# no need to continue the compilation of the true kernel source.
if (LITE_ON_MODEL_OPTIMIZE_TOOL)
return()
endif(LITE_ON_MODEL_OPTIMIZE_TOOL)
if ("${device}" STREQUAL "Host") if ("${device}" STREQUAL "Host")
set(host_kernels "${host_kernels};${TARGET}" CACHE INTERNAL "") set(host_kernels "${host_kernels};${TARGET}" CACHE INTERNAL "")
endif() endif()
...@@ -274,6 +289,7 @@ function(add_kernel TARGET device level) ...@@ -274,6 +289,7 @@ function(add_kernel TARGET device level)
set(opencl_kernels "${opencl_kernels};${TARGET}" CACHE INTERNAL "") set(opencl_kernels "${opencl_kernels};${TARGET}" CACHE INTERNAL "")
endif() endif()
# the source list will collect for paddle_use_kernel.h code generation.
foreach(src ${args_SRCS}) foreach(src ${args_SRCS})
file(APPEND ${kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n") file(APPEND ${kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n")
endforeach() endforeach()
......
...@@ -69,12 +69,12 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) ...@@ -69,12 +69,12 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include" COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include"
COMMAND cp "${CMAKE_BINARY_DIR}/libpaddle_api_full_bundled.a" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib" COMMAND cp "${CMAKE_BINARY_DIR}/libpaddle_api_full_bundled.a" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib"
COMMAND cp "${CMAKE_BINARY_DIR}/libpaddle_api_light_bundled.a" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib" COMMAND cp "${CMAKE_BINARY_DIR}/libpaddle_api_light_bundled.a" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib"
COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/model_optimize_tool" "${INFER_LITE_PUBLISH_ROOT}/bin" #COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/model_optimize_tool" "${INFER_LITE_PUBLISH_ROOT}/bin"
COMMAND cp "${CMAKE_BINARY_DIR}/lite/gen_code/paddle_code_generator" "${INFER_LITE_PUBLISH_ROOT}/bin" COMMAND cp "${CMAKE_BINARY_DIR}/lite/gen_code/paddle_code_generator" "${INFER_LITE_PUBLISH_ROOT}/bin"
COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/test_model_bin" "${INFER_LITE_PUBLISH_ROOT}/bin" COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/test_model_bin" "${INFER_LITE_PUBLISH_ROOT}/bin"
) )
if(NOT IOS) if(NOT IOS)
add_dependencies(publish_inference_cxx_lib model_optimize_tool) #add_dependencies(publish_inference_cxx_lib model_optimize_tool)
add_dependencies(publish_inference_cxx_lib paddle_code_generator) add_dependencies(publish_inference_cxx_lib paddle_code_generator)
add_dependencies(publish_inference_cxx_lib bundle_full_api) add_dependencies(publish_inference_cxx_lib bundle_full_api)
add_dependencies(publish_inference_cxx_lib bundle_light_api) add_dependencies(publish_inference_cxx_lib bundle_light_api)
......
...@@ -195,6 +195,14 @@ endif() ...@@ -195,6 +195,14 @@ endif()
if (LITE_ON_TINY_PUBLISH) if (LITE_ON_TINY_PUBLISH)
return() return()
endif() endif()
if (LITE_ON_MODEL_OPTIMIZE_TOOL)
message(STATUS "Compiling model_optimize_tool")
lite_cc_binary(model_optimize_tool SRCS model_optimize_tool.cc cxx_api_impl.cc paddle_api.cc cxx_api.cc
DEPS gflags kernel op optimizer mir_passes utils)
add_dependencies(model_optimize_tool op_list_h kernel_list_h all_kernel_faked_cc)
endif(LITE_ON_MODEL_OPTIMIZE_TOOL)
lite_cc_test(test_paddle_api SRCS paddle_api_test.cc DEPS paddle_api_full paddle_api_light lite_cc_test(test_paddle_api SRCS paddle_api_test.cc DEPS paddle_api_full paddle_api_light
${ops} ${ops}
ARM_DEPS ${arm_kernels} ARM_DEPS ${arm_kernels}
...@@ -209,14 +217,14 @@ endif() ...@@ -209,14 +217,14 @@ endif()
# Some bins # Some bins
if(NOT IOS) if(NOT IOS)
lite_cc_binary(test_model_bin SRCS model_test.cc DEPS paddle_api_full paddle_api_light gflags lite_cc_binary(test_model_bin SRCS model_test.cc DEPS paddle_api_full paddle_api_light gflags utils
${ops} ${ops}
ARM_DEPS ${arm_kernels} ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels} NPU_DEPS ${npu_kernels}
CL_DEPS ${opencl_kernels} CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels} FPGA_DEPS ${fpga_kernels}
X86_DEPS ${x86_kernels}) X86_DEPS ${x86_kernels})
lite_cc_binary(benchmark_bin SRCS benchmark.cc DEPS paddle_api_full paddle_api_light gflags lite_cc_binary(benchmark_bin SRCS benchmark.cc DEPS paddle_api_full paddle_api_light gflags utils
${ops} ${ops}
ARM_DEPS ${arm_kernels} ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels} NPU_DEPS ${npu_kernels}
...@@ -229,7 +237,3 @@ endif() ...@@ -229,7 +237,3 @@ endif()
#X86_DEPS operator #X86_DEPS operator
#DEPS light_api model_parser target_wrapper_host mir_passes #DEPS light_api model_parser target_wrapper_host mir_passes
#ARM_DEPS ${arm_kernels}) NPU_DEPS ${npu_kernels}) #ARM_DEPS ${arm_kernels}) NPU_DEPS ${npu_kernels})
lite_cc_binary(model_optimize_tool SRCS model_optimize_tool.cc
DEPS paddle_api_full gflags
CL_DEPS ${opencl_kernels})
...@@ -16,10 +16,11 @@ ...@@ -16,10 +16,11 @@
#ifdef PADDLE_WITH_TESTING #ifdef PADDLE_WITH_TESTING
#include <gtest/gtest.h> #include <gtest/gtest.h>
#endif #endif
#include "all_kernel_faked.cc" // NOLINT
#include "lite/api/paddle_api.h" #include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h" #include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h" #include "lite/api/paddle_use_passes.h"
#include "lite/core/op_registry.h"
#include "lite/utils/cp_logging.h" #include "lite/utils/cp_logging.h"
#include "lite/utils/string.h" #include "lite/utils/string.h"
...@@ -33,6 +34,7 @@ DEFINE_string( ...@@ -33,6 +34,7 @@ DEFINE_string(
optimize_out_type, optimize_out_type,
"protobuf", "protobuf",
"store type of the output optimized model. protobuf/naive_buffer"); "store type of the output optimized model. protobuf/naive_buffer");
DEFINE_bool(display_kernels, false, "Display kernel information");
DEFINE_string(optimize_out, "", "path of the output optimized model"); DEFINE_string(optimize_out, "", "path of the output optimized model");
DEFINE_string(valid_targets, DEFINE_string(valid_targets,
"arm", "arm",
...@@ -43,12 +45,22 @@ DEFINE_bool(prefer_int8_kernel, false, "Prefer to run model with int8 kernels"); ...@@ -43,12 +45,22 @@ DEFINE_bool(prefer_int8_kernel, false, "Prefer to run model with int8 kernels");
namespace paddle { namespace paddle {
namespace lite_api { namespace lite_api {
//! Display the kernel information.
void DisplayKernels() {
LOG(INFO) << ::paddle::lite::KernelRegistry::Global().DebugString();
}
void Main() { void Main() {
if (!FLAGS_model_file.empty() && !FLAGS_param_file.empty()) { if (!FLAGS_model_file.empty() && !FLAGS_param_file.empty()) {
LOG(WARNING) LOG(WARNING)
<< "Load combined-param model. Option model_dir will be ignored"; << "Load combined-param model. Option model_dir will be ignored";
} }
if (FLAGS_display_kernels) {
DisplayKernels();
exit(0);
}
lite_api::CxxConfig config; lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir); config.set_model_dir(FLAGS_model_dir);
config.set_model_file(FLAGS_model_file); config.set_model_file(FLAGS_model_file);
...@@ -75,6 +87,7 @@ void Main() { ...@@ -75,6 +87,7 @@ void Main() {
CHECK(!valid_places.empty()) CHECK(!valid_places.empty())
<< "At least one target should be set, should set the " << "At least one target should be set, should set the "
"command argument 'valid_targets'"; "command argument 'valid_targets'";
if (FLAGS_prefer_int8_kernel) { if (FLAGS_prefer_int8_kernel) {
LOG(WARNING) << "Int8 mode is only support by ARM target"; LOG(WARNING) << "Int8 mode is only support by ARM target";
valid_places.push_back(Place{TARGET(kARM), PRECISION(kInt8)}); valid_places.push_back(Place{TARGET(kARM), PRECISION(kInt8)});
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define USE_LITE_KERNEL(op_type__, target__, precision__, layout__, alias__) \ #define USE_LITE_KERNEL(op_type__, target__, precision__, layout__, alias__) \
extern int touch_##op_type__##target__##precision__##layout__##alias__(); \ extern int touch_##op_type__##target__##precision__##layout__##alias__(); \
int op_type__##target__##precision__##layout__##alias__ \ int op_type__##target__##precision__##layout__##alias__##__use_lite_kernel \
__attribute__((unused)) = \ __attribute__((unused)) = \
touch_##op_type__##target__##precision__##layout__##alias__(); touch_##op_type__##target__##precision__##layout__##alias__();
......
...@@ -53,8 +53,16 @@ add_custom_command( ...@@ -53,8 +53,16 @@ add_custom_command(
${CMAKE_SOURCE_DIR}/lite/api/paddle_use_ops.h ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_ops.h
OUTPUT ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_ops.h OUTPUT ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_ops.h
) )
# generate fake kernels for memory_optimize_tool
add_custom_command(
COMMAND python ${CMAKE_SOURCE_DIR}/lite/tools/cmake_tools/create_fake_kernel_registry.py
${kernels_src_list}
${CMAKE_BINARY_DIR}/all_kernel_faked.cc
OUTPUT ${CMAKE_BINARY_DIR}/all_kernel_faked.cc
)
add_custom_target(op_list_h DEPENDS ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_ops.h) add_custom_target(op_list_h DEPENDS ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_ops.h)
add_custom_target(kernel_list_h DEPENDS ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_kernels.h) add_custom_target(kernel_list_h DEPENDS ${CMAKE_SOURCE_DIR}/lite/api/paddle_use_kernels.h)
add_custom_target(all_kernel_faked_cc DEPENDS ${CMAKE_BINARY_DIR}/all_kernel_faked.cc)
#----------------------------------------------- NOT CHANGE ----------------------------------------------- #----------------------------------------------- NOT CHANGE -----------------------------------------------
lite_cc_library(kernel SRCS kernel.cc DEPS context type_system target_wrapper any op_params tensor lite_cc_library(kernel SRCS kernel.cc DEPS context type_system target_wrapper any op_params tensor
......
...@@ -356,7 +356,10 @@ class ContextScheduler { ...@@ -356,7 +356,10 @@ class ContextScheduler {
break; break;
#endif #endif
default: default:
#ifndef LITE_ON_MODEL_OPTIMIZE_TOOL
LOG(FATAL) << "unsupported target " << TargetToStr(target); LOG(FATAL) << "unsupported target " << TargetToStr(target);
#endif
break;
} }
return ctx; return ctx;
} }
......
...@@ -15,9 +15,11 @@ ...@@ -15,9 +15,11 @@
#pragma once #pragma once
#include <list> #include <list>
#include <map>
#include <memory> #include <memory>
#include <set> #include <set>
#include <string> #include <string>
#include <tuple>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -26,6 +28,7 @@ ...@@ -26,6 +28,7 @@
#include "lite/core/op_lite.h" #include "lite/core/op_lite.h"
#include "lite/core/target_wrapper.h" #include "lite/core/target_wrapper.h"
#include "lite/utils/all.h" #include "lite/utils/all.h"
#include "lite/utils/macros.h"
using LiteType = paddle::lite::Type; using LiteType = paddle::lite::Type;
...@@ -159,6 +162,10 @@ class KernelRegistry final { ...@@ -159,6 +162,10 @@ class KernelRegistry final {
auto *reg = varient.template get<kernel_registor_t *>(); auto *reg = varient.template get<kernel_registor_t *>();
CHECK(reg) << "Can not be empty of " << name; CHECK(reg) << "Can not be empty of " << name;
reg->Register(name, std::move(creator)); reg->Register(name, std::move(creator));
#ifdef LITE_ON_MODEL_OPTIMIZE_TOOL
kernel_info_map_[name].push_back(
std::make_tuple(Target, Precision, Layout));
#endif // LITE_ON_MODEL_OPTIMIZE_TOOL
} }
template <TargetType Target, template <TargetType Target,
...@@ -190,22 +197,42 @@ class KernelRegistry final { ...@@ -190,22 +197,42 @@ class KernelRegistry final {
} }
std::string DebugString() const { std::string DebugString() const {
#ifndef LITE_ON_MODEL_OPTIMIZE_TOOL
return "No more debug info";
#else // LITE_ON_MODEL_OPTIMIZE_TOOL
STL::stringstream ss; STL::stringstream ss;
ss << "KernelCreator<host, float>:\n"; ss << "\n";
constexpr TargetType tgt = TARGET(kHost); ss << "Count of kernel kinds: ";
constexpr PrecisionType dt = PRECISION(kFloat); int count = 0;
constexpr DataLayoutType lt = DATALAYOUT(kNCHW); for (auto &item : kernel_info_map_) {
constexpr DataLayoutType kany = DATALAYOUT(kAny); for (auto &kernel : item.second) ++count;
using kernel_registor_t = KernelRegistryForTarget<tgt, dt, lt>; }
auto *reg = registries_[GetKernelOffset<tgt, dt, kany>()] ss << count << "\n";
.template get<kernel_registor_t *>();
ss << reg->DebugString() << "\n"; ss << "Count of registered kernels: " << kernel_info_map_.size() << "\n";
for (auto &item : kernel_info_map_) {
ss << "op: " << item.first << "\n";
for (auto &kernel : item.second) {
ss << " - (" << TargetToStr(std::get<0>(kernel)) << ",";
ss << PrecisionToStr(std::get<1>(kernel)) << ",";
ss << DataLayoutToStr(std::get<2>(kernel));
ss << ")";
ss << "\n";
}
}
return ss.str(); return ss.str();
return ""; #endif // LITE_ON_MODEL_OPTIMIZE_TOOL
} }
private: private:
mutable std::vector<any_kernel_registor_t> registries_; mutable std::vector<any_kernel_registor_t> registries_;
#ifndef LITE_ON_TINY_PUBLISH
mutable std::map<
std::string,
std::vector<std::tuple<TargetType, PrecisionType, DataLayoutType>>>
kernel_info_map_;
#endif
}; };
template <TargetType target, template <TargetType target,
......
if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM))
return()
endif()
message(STATUS "compile with lite ARM kernels")
add_kernel(fc_compute_arm ARM basic SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(fc_compute_arm ARM basic SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(activation_compute_arm ARM basic SRCS activation_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(activation_compute_arm ARM basic SRCS activation_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(mul_compute_arm ARM basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(mul_compute_arm ARM basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm)
...@@ -73,6 +67,15 @@ add_kernel(fill_constant_compute_arm ARM extra SRCS fill_constant_compute.cc DEP ...@@ -73,6 +67,15 @@ add_kernel(fill_constant_compute_arm ARM extra SRCS fill_constant_compute.cc DEP
add_kernel(lod_reset_compute_arm ARM extra SRCS lod_reset_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(lod_reset_compute_arm ARM extra SRCS lod_reset_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(is_empty_compute_arm ARM extra SRCS is_empty_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(is_empty_compute_arm ARM extra SRCS is_empty_compute.cc DEPS ${lite_kernel_deps} math_arm)
# NOTE we leave the add_kernel not protected by LITE_WITH_LIGHT_WEIGHT_FRAMEWORK so that all the kernels will be registered
# to the model_optimize_tool.
if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM))
return()
endif()
message(STATUS "compile with lite ARM kernels")
lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm) lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm)
lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
......
...@@ -82,28 +82,25 @@ void UnaryLogicalCompute<Functor>::Run() { ...@@ -82,28 +82,25 @@ void UnaryLogicalCompute<Functor>::Run() {
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(
logical_xor, REGISTER_LITE_KERNEL(logical_xor,
kARM, kARM,
kFloat, kFloat,
kNCHW, kNCHW,
paddle::lite::kernels::arm::BinaryLogicalCompute< paddle::lite::kernels::arm::BinaryLogicalCompute<
paddle::lite::kernels::arm::_LogicalXorFunctor>, paddle::lite::kernels::arm::_LogicalXorFunctor>,
// paddle::lite::kernels::arm::BinaryLogicalCompute<paddle::lite::kernels::arm::_LogicalXorFunctor<bool>>, def)
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(logical_and,
logical_and, kARM,
kARM, kFloat,
kFloat, kNCHW,
kNCHW, paddle::lite::kernels::arm::BinaryLogicalCompute<
// paddle::lite::kernels::arm::BinaryLogicalCompute<paddle::lite::kernels::arm::_LogicalAndFunctor<bool>>, paddle::lite::kernels::arm::_LogicalAndFunctor>,
paddle::lite::kernels::arm::BinaryLogicalCompute< def)
paddle::lite::kernels::arm::_LogicalAndFunctor>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
......
...@@ -5,5 +5,5 @@ add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kerne ...@@ -5,5 +5,5 @@ add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kerne
add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op) add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps}) add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps})
lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any) #lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any)
#lite_cc_test(test_multiclass_nms_compute_host SRCS multiclass_nms_compute_test.cc DEPS multiclass_nms_compute_host any) #lite_cc_test(test_multiclass_nms_compute_host SRCS multiclass_nms_compute_test.cc DEPS multiclass_nms_compute_host any)
if(NOT LITE_WITH_X86)
return()
endif()
# lite_cc_library(activation_compute_x86 SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_op) # lite_cc_library(activation_compute_x86 SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_op)
# lite_cc_library(mean_compute_x86 SRCS mean_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(mean_compute_x86 SRCS mean_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(fill_constant_compute_x86 SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(fill_constant_compute_x86 SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps})
...@@ -38,6 +34,10 @@ add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_ ...@@ -38,6 +34,10 @@ add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_
add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling) add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling)
add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
if(NOT LITE_WITH_X86)
return()
endif()
lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86) lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
lite_cc_test(test_slice_compute_x86 SRCS slice_compute_test.cc DEPS slice_compute_x86) lite_cc_test(test_slice_compute_x86 SRCS slice_compute_test.cc DEPS slice_compute_x86)
lite_cc_test(test_squeeze_compute_x86 SRCS squeeze_compute_test.cc DEPS squeeze_compute_x86) lite_cc_test(test_squeeze_compute_x86 SRCS squeeze_compute_test.cc DEPS squeeze_compute_x86)
......
...@@ -25,20 +25,20 @@ REGISTER_LITE_KERNEL(mul, ...@@ -25,20 +25,20 @@ REGISTER_LITE_KERNEL(mul,
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
#ifdef LITE_WITH_TRAIN // #ifdef LITE_WITH_TRAIN
REGISTER_LITE_KERNEL(mul_grad, // REGISTER_LITE_KERNEL(mul_grad,
kX86, // kX86,
kFloat, // kFloat,
kNCHW, // kNCHW,
paddle::lite::kernels::x86::MulGradCompute<float>, // paddle::lite::kernels::x86::MulGradCompute<float>,
def) // def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) // .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) // .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput(paddle::framework::GradVarName("Out"), // .BindInput(paddle::framework::GradVarName("Out"),
{LiteType::GetTensorTy(TARGET(kX86))}) // {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput(paddle::framework::GradVarName("X"), // .BindOutput(paddle::framework::GradVarName("X"),
{LiteType::GetTensorTy(TARGET(kX86))}) // {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput(paddle::framework::GradVarName("Y"), // .BindOutput(paddle::framework::GradVarName("Y"),
{LiteType::GetTensorTy(TARGET(kX86))}) // {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); // .Finalize();
#endif // #endif
...@@ -224,6 +224,7 @@ function build_test_server { ...@@ -224,6 +224,7 @@ function build_test_server {
build build
test_server test_server
test_model_optimize_tool_compile
} }
function build_test_train { function build_test_train {
...@@ -393,20 +394,27 @@ function test_arm_model { ...@@ -393,20 +394,27 @@ function test_arm_model {
adb -s emulator-${port} shell "${adb_work_dir}/${test_name} --model_dir=$adb_model_path" adb -s emulator-${port} shell "${adb_work_dir}/${test_name} --model_dir=$adb_model_path"
} }
function _test_model_optimize_tool { # function _test_model_optimize_tool {
local port=$1 # local port=$1
local remote_model_path=$ADB_WORK_DIR/lite_naive_model # local remote_model_path=$ADB_WORK_DIR/lite_naive_model
local remote_test=$ADB_WORK_DIR/model_optimize_tool # local remote_test=$ADB_WORK_DIR/model_optimize_tool
local adb="adb -s emulator-${port}" # local adb="adb -s emulator-${port}"
# make model_optimize_tool -j$NUM_CORES_FOR_COMPILE
# local test_path=$(find . -name model_optimize_tool | head -n1)
# local model_path=$(find . -name lite_naive_model | head -n1)
# $adb push ${test_path} ${ADB_WORK_DIR}
# $adb shell mkdir -p $remote_model_path
# $adb push $model_path/* $remote_model_path
# $adb shell $remote_test --model_dir $remote_model_path --optimize_out ${remote_model_path}.opt \
# --valid_targets "arm"
# }
function test_model_optimize_tool_compile {
cd $workspace
cd build
cmake .. -DWITH_LITE=ON -DLITE_ON_MODEL_OPTIMIZE_TOOL=ON -DWITH_TESTING=OFF -DLITE_BUILD_EXTRA=ON
make model_optimize_tool -j$NUM_CORES_FOR_COMPILE make model_optimize_tool -j$NUM_CORES_FOR_COMPILE
local test_path=$(find . -name model_optimize_tool | head -n1)
local model_path=$(find . -name lite_naive_model | head -n1)
$adb push ${test_path} ${ADB_WORK_DIR}
$adb shell mkdir -p $remote_model_path
$adb push $model_path/* $remote_model_path
$adb shell $remote_test --model_dir $remote_model_path --optimize_out ${remote_model_path}.opt \
--valid_targets "arm"
} }
function _test_paddle_code_generator { function _test_paddle_code_generator {
...@@ -558,8 +566,8 @@ function test_arm { ...@@ -558,8 +566,8 @@ function test_arm {
# test finally # test finally
test_arm_api $port test_arm_api $port
_test_model_optimize_tool $port # _test_model_optimize_tool $port
_test_paddle_code_generator $port # _test_paddle_code_generator $port
} }
function prepare_emulator { function prepare_emulator {
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
class SyntaxParser(object):
def __init__(self, str):
self.str = str
self.cur_pos = 0
self.N = len(self.str)
self.token = ''
def eat_char(self):
self.cur_pos += 1
def eat_str(self):
'''
"xx"
'''
self.token = ''
assert self.cur == '"';
self.cur_pos += 1;
assert self.cur_pos < self.N
while self.cur != '"':
self.token += self.cur
self.cur_pos += 1
assert self.cur_pos < self.N
assert self.cur == '"'
self.cur_pos += 1
#logging.warning('get: %s' % self.token)
def eat_word(self):
self.token = ''
str = ''
while self.cur.isalnum() or self.cur in ('_', ':',):
self.token += self.cur
self.forward()
#logging.warning('get: %s' % self.token)
def eat_left_parentheses(self):
'''
(
'''
self.assert_is('(')
self.token = '('
self.forward()
#logging.warning('get: %s' % self.token)
def eat_right_parentheses(self):
'''
)
'''
self.assert_is(')')
self.token = ')'
self.forward()
#logging.warning('get: %s' % self.token)
def eat_left_brace(self):
'''
{
'''
self.assert_is('{')
self.token = '{'
self.forward()
#logging.warning('get: %s' % self.token)
def eat_right_brace(self):
'''
}
'''
self.assert_is('}')
self.token = '}'
self.forward()
#logging.warning('get: %s' % self.token)
def eat_comma(self):
'''
,
'''
self.assert_is(',')
self.token = ','
self.forward()
#logging.warning('get: %s' % self.token)
def eat_spaces(self):
'''
eat space like string.
'''
while self.cur_pos < len(self.str):
if self.cur in (' ', '\t', '\n'):
self.forward()
else:
break
def eat_point(self):
'''
.
'''
self.assert_is('.')
self.token = '.'
self.forward()
#logging.warning('get: %s' % self.token)
def eat_any_but_brace(self):
'''
anything but {}
'''
start = self.cur_pos
while self.cur not in ('{', '}'):
self.cur_pos += 1
self.token = self.str[start:self.cur_pos]
#logging.warning('get: %s' % self.token)
def eat_semicolon(self):
'''
;
'''
self.assert_is(';')
self.token = ';'
self.forward()
#logging.warning('get: %s' % self.token)
def assert_is(self, w):
assert self.cur == w, "token should be %s, but get %s" % (w, self.cur)
@property
def cur(self):
assert self.cur_pos < self.N
return self.str[self.cur_pos]
#logging.warning('get: %s' % self.token)
def forward(self):
self.cur_pos += 1
class IO:
def __init__(self):
self.name = ''
self.type = ''
def __repr__(self):
return "- %s: %s" % (self.name, self.type)
class KernelRegistry:
def __init__(self):
self.op_type = ''
self.target = ''
self.precision = ''
self.data_layout = ''
self.class_ = ''
self.alias = ''
self.inputs = []
self.outputs = []
def __repr__(self):
str = "Kernel({op_type}, {target}, {precision}, {data_layout}, {alias}):".format(
op_type = self.op_type,
target = self.target,
precision = self.precision,
data_layout = self.data_layout,
alias = self.alias,
)
str += '\n' + '\n'.join(repr(io) for io in self.inputs)
str += '\n' + '\n'.join(repr(io) for io in self.outputs)
str += '\n'
return str
class RegisterLiteKernelParser(SyntaxParser):
KEYWORD = 'REGISTER_LITE_KERNEL'
def __init__(self, str):
super(RegisterLiteKernelParser, self).__init__(str)
self.kernels = []
def parse(self):
find_registry_command = False
while self.cur_pos < len(self.str):
start = self.str.find(self.KEYWORD, self.cur_pos)
if start != -1:
#print 'str ', start, self.str[start-2: start]
if start != 0 and '/' in self.str[start-2: start]:
'''
skip commented code
'''
self.cur_pos = start + 1
continue
self.cur_pos = start
k = KernelRegistry()
self.kernels.append(self.parse_register(k))
else:
break
def eat_class(self):
start = self.cur_pos
self.eat_word()
stack = ''
if self.cur == '<':
stack = stack + '<'
self.forward()
while stack:
if self.cur == '<':
stack = stack + '<'
elif self.cur == '>':
stack = stack[1:]
else:
pass
self.forward()
self.token = self.str[start:self.cur_pos]
def parse_register(self, k):
self.eat_word()
assert self.token == self.KEYWORD
self.eat_spaces()
self.eat_left_parentheses()
self.eat_spaces()
self.eat_word()
k.op_type = self.token
self.eat_comma()
self.eat_spaces()
self.eat_word()
k.target = self.token
self.eat_comma()
self.eat_spaces()
self.eat_word()
k.precision = self.token
self.eat_comma()
self.eat_spaces()
self.eat_word()
k.data_layout = self.token
self.eat_comma()
self.eat_spaces()
self.eat_class()
k.class_ = self.token
self.eat_comma()
self.eat_spaces()
self.eat_word()
k.alias = self.token
self.eat_spaces()
self.eat_right_parentheses()
self.eat_spaces()
def eat_io(is_input, io):
self.eat_left_parentheses()
self.eat_str()
io.name = self.token
self.eat_comma()
self.eat_spaces()
self.eat_left_brace()
self.eat_any_but_brace()
io.type = self.token
self.eat_right_brace()
self.eat_spaces()
self.eat_right_parentheses()
self.eat_spaces()
# eat input and output
while self.cur_pos < len(self.str):
self.eat_point()
self.eat_spaces()
self.eat_word()
assert self.token in ('BindInput', 'BindOutput', 'Finalize')
io = IO()
if self.token == 'BindInput':
eat_io(True, io)
k.inputs.append(io)
elif self.token == 'BindOutput':
eat_io(False, io)
k.outputs.append(io)
else:
self.eat_left_parentheses()
self.eat_right_parentheses()
self.eat_semicolon()
self.eat_spaces()
return k
break
if __name__ == '__main__':
with open('/home/chunwei/project2/Paddle-Lite/lite/kernels/arm/activation_compute.cc') as f:
c = f.read()
kernel_parser = RegisterLiteKernelParser(c)
kernel_parser.parse()
for k in kernel_parser.kernels:
print k
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import logging
from ast import RegisterLiteKernelParser
from utils import *
ops_list_path = sys.argv[1]
dest_path = sys.argv[2]
out_lines = [
'#pragma once',
'#include "lite/core/op_registry.h"',
'#include "lite/core/kernel.h"',
'#include "lite/core/type_system.h"',
'',
]
fake_kernel = '''
namespace paddle {
namespace lite {
class %s : public KernelLite<TARGET(%s), PRECISION(%s), DATALAYOUT(%s)> {
public:
void PrepareForRun() override {}
void Run() override {}
virtual ~%s() = default;
};
} // namespace lite
} // namespace paddle
'''
with open(ops_list_path) as f:
paths = set([path for path in f])
for path in paths:
print 'path', path
with open(path.strip()) as g:
c = g.read()
kernel_parser = RegisterLiteKernelParser(c)
kernel_parser.parse()
for k in kernel_parser.kernels:
kernel_name = "{op_type}_{target}_{precision}_{data_layout}_{alias}_class".format(
op_type = k.op_type,
target = k.target,
precision = k.precision,
data_layout = k.data_layout,
alias = k.alias,
)
kernel_define = fake_kernel % (
kernel_name,
k.target,
k.precision,
k.data_layout,
kernel_name,
)
out_lines.append(kernel_define)
out_lines.append("")
key = "REGISTER_LITE_KERNEL(%s, %s, %s, %s, %s, %s)" % (
k.op_type,
k.target,
k.precision,
k.data_layout,
'::paddle::lite::' + kernel_name,
k.alias,
)
out_lines.append(key)
for input in k.inputs:
io = ' .BindInput("%s", {%s})' % (input.name, input.type)
out_lines.append(io)
for output in k.outputs:
io = ' .BindOutput("%s", {%s})' % (output.name, output.type)
out_lines.append(io)
out_lines.append(" .Finalize();")
out_lines.append("")
out_lines.append(gen_use_kernel_statement(k.op_type, k.target, k.precision, k.data_layout, k.alias))
with open(dest_path, 'w') as f:
logging.info("write kernel list to %s" % dest_path)
f.write('\n'.join(out_lines))
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import sys import sys
import logging import logging
from ast import RegisterLiteKernelParser
ops_list_path = sys.argv[1] ops_list_path = sys.argv[1]
dest_path = sys.argv[2] dest_path = sys.argv[2]
...@@ -24,56 +25,25 @@ out_lines = [ ...@@ -24,56 +25,25 @@ out_lines = [
'', '',
] ]
left_pattern = 'REGISTER_LITE_KERNEL('
right_pattern = ')'
def find_right_pattern(context, start):
if start >= len(context): return -1
fake_left_num = 0
while start < len(context):
if context[start] == right_pattern:
if fake_left_num == 0:
return start
else:
fake_left_num -= 1
elif context[start] == '(':
fake_left_num += 1
start += 1
return -1
lines = set()
with open(ops_list_path) as f: with open(ops_list_path) as f:
for line in f: paths = set([path for path in f])
lines.add(line.strip()) for path in paths:
with open(path.strip()) as g:
for line in lines: print 'path: ', path
path = line.strip() c = g.read()
kernel_parser = RegisterLiteKernelParser(c)
status = '' kernel_parser.parse()
with open(path) as g:
context = ''.join([item.strip() for item in g]) for k in kernel_parser.kernels:
index = 0 key = "USE_LITE_KERNEL(%s, %s, %s, %s, %s);" % (
cxt_len = len(context) k.op_type,
while index < cxt_len and index >= 0: k.target,
left_index = context.find(left_pattern, index) k.precision,
if left_index < 0: break k.data_layout,
right_index = find_right_pattern(context, left_index+len(left_pattern)) k.alias,
if right_index < 0: )
raise ValueError("Left Pattern and Right Pattern does not match") out_lines.append(key)
tmp = context[left_index+len(left_pattern) : right_index]
index = right_index + 1
if tmp.startswith('/'): continue
fields = [item.strip() for item in tmp.split(',')]
if len(fields) < 6:
raise ValueError("Invalid REGISTER_LITE_KERNEL format")
op, target, precision, layout = fields[:4]
alias = fields[-1]
key = "USE_LITE_KERNEL(%s, %s, %s, %s, %s);" % (
op, target, precision, layout, alias)
if "_grad" in key: continue
out_lines.append(key)
with open(dest_path, 'w') as f: with open(dest_path, 'w') as f:
logging.info("write kernel list to %s" % dest_path) logging.info("write kernel list to %s" % dest_path)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def gen_use_kernel_statement(op_type, target, precision, layout, alias):
return 'USE_LITE_KERNEL(%s, %s, %s, %s, %s);' %(
op_type, target, precision, layout, alias
)
...@@ -3,23 +3,23 @@ ...@@ -3,23 +3,23 @@
# else() # else()
# endif() # endif()
if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK OR LITE_ON_MODEL_OPTIMIZE_TOOL)
lite_cc_library(logging SRCS logging.cc) lite_cc_library(logging SRCS logging.cc)
set(utils_DEPS logging) set(utils_DEPS logging)
lite_cc_test(test_logging SRCS logging_test.cc DEPS ${utils_DEPS}) lite_cc_test(test_logging SRCS logging_test.cc DEPS ${utils_DEPS})
else() else()
set(utils_DEPS glog) set(utils_DEPS glog)
endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) endif()
lite_cc_test(test_varient SRCS varient_test.cc DEPS utils) lite_cc_test(test_varient SRCS varient_test.cc DEPS utils)
lite_cc_library(any SRCS any.cc) lite_cc_library(any SRCS any.cc)
if(LITE_ON_TINY_PUBLISH) if(LITE_ON_TINY_PUBLISH OR LITE_ON_MODEL_OPTIMIZE_TOOL)
lite_cc_library(stream SRCS replace_stl/stream.cc) lite_cc_library(stream SRCS replace_stl/stream.cc)
endif() endif()
#lite_cc_library(utils SRCS cp_logging.cc string.cc DEPS ${utils_DEPS} any) #lite_cc_library(utils SRCS cp_logging.cc string.cc DEPS ${utils_DEPS} any)
if(LITE_ON_TINY_PUBLISH) if(LITE_ON_TINY_PUBLISH OR LITE_ON_MODEL_OPTIMIZE_TOOL)
lite_cc_library(utils SRCS string.cc DEPS ${utils_DEPS} any stream) lite_cc_library(utils SRCS string.cc DEPS ${utils_DEPS} any stream)
else() else()
lite_cc_library(utils SRCS string.cc DEPS ${utils_DEPS} any) lite_cc_library(utils SRCS string.cc DEPS ${utils_DEPS} any)
......
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #if defined(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || \
defined(LITE_ON_MODEL_OPTIMIZE_TOOL)
#include "lite/utils/logging.h" #include "lite/utils/logging.h"
#else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include <glog/logging.h> #include <glog/logging.h>
......
...@@ -19,7 +19,8 @@ ...@@ -19,7 +19,8 @@
#include "lite/utils/logging.h" #include "lite/utils/logging.h"
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #if defined(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || \
defined(LITE_ON_MODEL_OPTIMIZE_TOOL)
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
namespace paddle { namespace paddle {
...@@ -48,7 +49,7 @@ void gen_log(STL::ostream& log_stream_, ...@@ -48,7 +49,7 @@ void gen_log(STL::ostream& log_stream_,
<< tv.tv_usec / 1000 << " "; << tv.tv_usec / 1000 << " ";
if (len > kMaxLen) { if (len > kMaxLen) {
log_stream_ << "..." << file + len - kMaxLen << " " << func << ":" << lineno log_stream_ << "..." << file + len - kMaxLen << ":" << lineno << " " << func
<< "] "; << "] ";
} else { } else {
log_stream_ << file << " " << func << ":" << lineno << "] "; log_stream_ << file << " " << func << ":" << lineno << "] ";
......
...@@ -81,7 +81,7 @@ void gen_log(STL::ostream& log_stream_, ...@@ -81,7 +81,7 @@ void gen_log(STL::ostream& log_stream_,
const char* func, const char* func,
int lineno, int lineno,
const char* level, const char* level,
const int kMaxLen = 20); const int kMaxLen = 40);
// LogMessage // LogMessage
class LogMessage { class LogMessage {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册