提交 47ebc8c4 编写于 作者: J jiweibo

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle-Lite into stream_manage

...@@ -98,6 +98,7 @@ lite_option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" OF ...@@ -98,6 +98,7 @@ lite_option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" OF
lite_option(LITE_WITH_PROFILE "Enable profile mode in lite framework" OFF) lite_option(LITE_WITH_PROFILE "Enable profile mode in lite framework" OFF)
lite_option(LITE_WITH_PRECISION_PROFILE "Enable precision profile in profile mode ON in lite" OFF) lite_option(LITE_WITH_PRECISION_PROFILE "Enable precision profile in profile mode ON in lite" OFF)
lite_option(LITE_WITH_LOG "Enable log printing or not." ON) lite_option(LITE_WITH_LOG "Enable log printing or not." ON)
lite_option(LITE_WITH_EXCEPTION "Enable throwing the exception when error occurs in lite" OFF)
lite_option(LITE_WITH_NVTX "Enable nvtx or not, please enable LITE_WITH_CUDA first." OFF) lite_option(LITE_WITH_NVTX "Enable nvtx or not, please enable LITE_WITH_CUDA first." OFF)
lite_option(LITE_ON_TINY_PUBLISH "Publish tiny predictor lib." OFF) lite_option(LITE_ON_TINY_PUBLISH "Publish tiny predictor lib." OFF)
lite_option(LITE_ON_MODEL_OPTIMIZE_TOOL "Build the model optimize tool" OFF) lite_option(LITE_ON_MODEL_OPTIMIZE_TOOL "Build the model optimize tool" OFF)
......
...@@ -190,6 +190,10 @@ if (LITE_WITH_LOG) ...@@ -190,6 +190,10 @@ if (LITE_WITH_LOG)
add_definitions("-DLITE_WITH_LOG") add_definitions("-DLITE_WITH_LOG")
endif() endif()
if (LITE_WITH_EXCEPTION)
add_definitions("-DLITE_WITH_EXCEPTION")
endif()
if (LITE_ON_TINY_PUBLISH) if (LITE_ON_TINY_PUBLISH)
add_definitions("-DLITE_ON_TINY_PUBLISH") add_definitions("-DLITE_ON_TINY_PUBLISH")
endif() endif()
......
...@@ -80,6 +80,17 @@ if (ARM_TARGET_LANG STREQUAL "clang") ...@@ -80,6 +80,17 @@ if (ARM_TARGET_LANG STREQUAL "clang")
elseif(ARM_TARGET_ARCH_ABI STREQUAL "armv7") elseif(ARM_TARGET_ARCH_ABI STREQUAL "armv7")
set(triple arm-v7a-linux-android) set(triple arm-v7a-linux-android)
set(LITE_WITH_OPENMP OFF CACHE STRING "Due to libomp's bug(For ARM64, it has been fixed by https://reviews.llvm.org/D19879, but still exists on ARM32), disable OpenMP on armv7 when cross-compiling using Clang" FORCE) set(LITE_WITH_OPENMP OFF CACHE STRING "Due to libomp's bug(For ARM64, it has been fixed by https://reviews.llvm.org/D19879, but still exists on ARM32), disable OpenMP on armv7 when cross-compiling using Clang" FORCE)
if(ANDROID_STL_TYPE MATCHES "^c\\+\\+_")
# Use CMAKE_CXX_STANDARD_LIBRARIES_INIT to ensure libunwind and libc++ is linked in the right order
set(CMAKE_CXX_STANDARD_LIBRARIES_INIT "${CMAKE_CXX_STANDARD_LIBRARIES_INIT} ${ANDROID_NDK}/sources/cxx-stl/llvm-libc++/libs/${ANDROID_ARCH_ABI}/libunwind.a")
if(ANDROID_STL_TYPE STREQUAL "c++_shared")
set(CMAKE_CXX_STANDARD_LIBRARIES_INIT "${CMAKE_CXX_STANDARD_LIBRARIES_INIT} ${ANDROID_NDK}/sources/cxx-stl/llvm-libc++/libs/${ANDROID_ARCH_ABI}/libc++_shared.so")
elseif(ANDROID_STL_TYPE STREQUAL "c++_static")
set(CMAKE_CXX_STANDARD_LIBRARIES_INIT "${CMAKE_CXX_STANDARD_LIBRARIES_INIT} ${ANDROID_NDK}/sources/cxx-stl/llvm-libc++/libs/${ANDROID_ARCH_ABI}/libc++_static.a")
else()
message(FATAL_ERROR "Invalid Android STL TYPE: ${ANDROID_STL_TYPE}.")
endif()
endif()
else() else()
message(FATAL_ERROR "Clang do not support this ${ARM_TARGET_ARCH_ABI}, use armv8 or armv7") message(FATAL_ERROR "Clang do not support this ${ARM_TARGET_ARCH_ABI}, use armv8 or armv7")
endif() endif()
......
...@@ -23,6 +23,21 @@ if(ANDROID) ...@@ -23,6 +23,21 @@ if(ANDROID)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -llog -fPIC") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -llog -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -llog -fPIC") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -llog -fPIC")
# Don't re-export libgcc symbols
set(REMOVE_ATOMIC_GCC_SYMBOLS "-Wl,--exclude-libs,libatomic.a -Wl,--exclude-libs,libgcc.a")
set(CMAKE_SHARED_LINKER_FLAGS "${REMOVE_ATOMIC_GCC_SYMBOLS} ${CMAKE_SHARED_LINKER_FLAGS}")
set(CMAKE_MODULE_LINKER_FLAGS "${REMOVE_ATOMIC_GCC_SYMBOLS} ${CMAKE_MODULE_LINKER_FLAGS}")
set(CMAKE_EXE_LINKER_FLAGS "${REMOVE_ATOMIC_GCC_SYMBOLS} ${CMAKE_EXE_LINKER_FLAGS}")
# Only the libunwind.a from clang(with libc++) provide C++ exception handling support for 32-bit ARM
# Refer to https://android.googlesource.com/platform/ndk/+/master/docs/BuildSystemMaintainers.md#Unwinding
if (ARM_TARGET_LANG STREQUAL "clang" AND ARM_TARGET_ARCH_ABI STREQUAL "armv7" AND ANDROID_STL_TYPE MATCHES "^c\\+\\+_")
set(REMOVE_UNWIND_SYMBOLS "-Wl,--exclude-libs,libunwind.a")
set(CMAKE_SHARED_LINKER_FLAGS "${REMOVE_UNWIND_SYMBOLS} ${CMAKE_SHARED_LINKER_FLAGS}")
set(CMAKE_MODULE_LINKER_FLAGS "${REMOVE_UNWIND_SYMBOLS} ${CMAKE_MODULE_LINKER_FLAGS}")
set(CMAKE_EXE_LINKER_FLAGS "${REMOVE_UNWIND_SYMBOLS} ${CMAKE_EXE_LINKER_FLAGS}")
endif()
endif() endif()
if(ARMLINUX) if(ARMLINUX)
...@@ -59,14 +74,13 @@ function(check_linker_flag) ...@@ -59,14 +74,13 @@ function(check_linker_flag)
endfunction() endfunction()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
if((LITE_WITH_OPENCL AND (ARM_TARGET_LANG STREQUAL "clang")) OR LITE_WITH_PYTHON OR LITE_WITH_EXCEPTION OR (NOT LITE_ON_TINY_PUBLISH))
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions -fasynchronous-unwind-tables -funwind-tables")
else ()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions -fno-asynchronous-unwind-tables -fno-unwind-tables")
endif()
if (LITE_ON_TINY_PUBLISH) if (LITE_ON_TINY_PUBLISH)
if((NOT LITE_WITH_PYTHON)) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ffast-math -Ofast -Os -fomit-frame-pointer")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions")
endif()
if(LITE_WITH_OPENCL AND (ARM_TARGET_LANG STREQUAL "clang"))
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions")
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ffast-math -Ofast -Os -fomit-frame-pointer -fno-asynchronous-unwind-tables -fno-unwind-tables")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden -fvisibility-inlines-hidden -ffunction-sections") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden -fvisibility-inlines-hidden -ffunction-sections")
check_linker_flag(-Wl,--gc-sections) check_linker_flag(-Wl,--gc-sections)
endif() endif()
......
...@@ -54,6 +54,11 @@ find_library(NPU_DDK_IR_BUILD_FILE NAMES hiai_ir_build ...@@ -54,6 +54,11 @@ find_library(NPU_DDK_IR_BUILD_FILE NAMES hiai_ir_build
PATHS ${NPU_DDK_ROOT}/${NPU_SUB_LIB_PATH} PATHS ${NPU_DDK_ROOT}/${NPU_SUB_LIB_PATH}
NO_DEFAULT_PATH) NO_DEFAULT_PATH)
# Added in HiAI DDK 320 or later version
find_library(NPU_DDK_HCL_FILE NAMES hcl
PATHS ${NPU_DDK_ROOT}/${NPU_SUB_LIB_PATH}
NO_DEFAULT_PATH)
if(NOT NPU_DDK_HIAI_FILE) if(NOT NPU_DDK_HIAI_FILE)
message(FATAL_ERROR "Can not find NPU_DDK_HIAI_FILE in ${NPU_DDK_ROOT}") message(FATAL_ERROR "Can not find NPU_DDK_HIAI_FILE in ${NPU_DDK_ROOT}")
else() else()
...@@ -78,5 +83,13 @@ else() ...@@ -78,5 +83,13 @@ else()
set_property(TARGET npu_ddk_ir_build PROPERTY IMPORTED_LOCATION ${NPU_DDK_IR_BUILD_FILE}) set_property(TARGET npu_ddk_ir_build PROPERTY IMPORTED_LOCATION ${NPU_DDK_IR_BUILD_FILE})
endif() endif()
set(npu_runtime_libs npu_ddk_hiai CACHE INTERNAL "npu ddk runtime libs") if(NOT NPU_DDK_HCL_FILE)
# message(FATAL_ERROR "Can not find NPU_DDK_HCL_FILE in ${NPU_DDK_ROOT}")
else()
message(STATUS "Found NPU_DDK HCL Library: ${NPU_DDK_HCL_FILE}")
add_library(npu_ddk_hcl SHARED IMPORTED GLOBAL)
set_property(TARGET npu_ddk_hcl PROPERTY IMPORTED_LOCATION ${NPU_DDK_HCL_FILE})
endif()
set(npu_runtime_libs npu_ddk_hiai npu_ddk_hcl CACHE INTERNAL "npu ddk runtime libs")
set(npu_builder_libs npu_ddk_ir npu_ddk_ir_build CACHE INTERNAL "npu ddk builder libs") set(npu_builder_libs npu_ddk_ir npu_ddk_ir_build CACHE INTERNAL "npu ddk builder libs")
...@@ -45,6 +45,7 @@ if (WITH_TESTING) ...@@ -45,6 +45,7 @@ if (WITH_TESTING)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4_simple.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4_simple.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "MobileNetV1_quant.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "MobileNetV1_quant.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "transformer_with_mask_fp32.tar.gz")
endif() endif()
if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "GoogleNet_inference.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "GoogleNet_inference.tar.gz")
......
...@@ -37,8 +37,7 @@ void Predictor::SaveModel(const std::string &dir, ...@@ -37,8 +37,7 @@ void Predictor::SaveModel(const std::string &dir,
if (!program_) { if (!program_) {
GenRuntimeProgram(); GenRuntimeProgram();
} }
program_->SaveOpInfosToProgram(program_desc_.get()); program_->SaveToProgram(program_desc_);
program_->UpdateVarsOfProgram(program_desc_.get());
switch (model_type) { switch (model_type) {
case lite_api::LiteModelType::kProtobuf: case lite_api::LiteModelType::kProtobuf:
SaveModelPb(dir, *program_->exec_scope(), *program_desc_.get(), true); SaveModelPb(dir, *program_->exec_scope(), *program_desc_.get(), true);
...@@ -58,18 +57,22 @@ void Predictor::SaveModel(const std::string &dir, ...@@ -58,18 +57,22 @@ void Predictor::SaveModel(const std::string &dir,
void Predictor::SaveOpKernelInfo(const std::string &model_dir) { void Predictor::SaveOpKernelInfo(const std::string &model_dir) {
std::set<std::string> ops_info; std::set<std::string> ops_info;
std::set<std::string> kernels_info; std::set<std::string> kernels_info;
const auto &instructions_ = program_->instructions(); auto block_size = program_->block_size();
for (auto &node : instructions_) { for (size_t block_idx = 0; block_idx < block_size; ++block_idx) {
const auto &insts = program_->instructions(block_idx);
for (auto &inst : insts) {
// parse op type infomation // parse op type infomation
auto op = node.op()->op_info(); auto op = inst.op()->op_info();
ops_info.insert(op->Type()); ops_info.insert(op->Type());
// parse kernel type information // parse kernel type information
std::string kernel_type_str = std::string kernel_type_str =
node.kernel()->op_type() + "," + TargetRepr(node.kernel()->target()) + inst.kernel()->op_type() + "," + TargetRepr(inst.kernel()->target()) +
"," + PrecisionRepr(node.kernel()->precision()) + "," + "," + PrecisionRepr(inst.kernel()->precision()) + "," +
DataLayoutRepr(node.kernel()->layout()) + "," + node.kernel()->alias(); DataLayoutRepr(inst.kernel()->layout()) + "," +
inst.kernel()->alias();
kernels_info.insert(kernel_type_str); kernels_info.insert(kernel_type_str);
} }
}
// get souce_file name from op type and kernel type // get souce_file name from op type and kernel type
auto op2pathmap = OpKernelInfoCollector::Global().GetOp2PathDict(); auto op2pathmap = OpKernelInfoCollector::Global().GetOp2PathDict();
...@@ -170,9 +173,9 @@ void Predictor::PrepareFeedFetch() { ...@@ -170,9 +173,9 @@ void Predictor::PrepareFeedFetch() {
std::vector<const cpp::OpDesc *> feeds; std::vector<const cpp::OpDesc *> feeds;
std::vector<const cpp::OpDesc *> fetchs; std::vector<const cpp::OpDesc *> fetchs;
const auto &insts = program_->instructions(); const auto &insts = program_->instructions(kRootBlockIdx);
for (size_t i = 0; i < program_->num_instructions(); i++) { for (auto &inst : insts) {
const auto &op = insts[i].op()->op_info(); const auto &op = inst.op()->op_info();
if (op->Type() == "feed") { if (op->Type() == "feed") {
feeds.push_back(op); feeds.push_back(op);
} else if (op->Type() == "fetch") { } else if (op->Type() == "fetch") {
...@@ -255,7 +258,6 @@ void Predictor::Build(const lite_api::CxxConfig &config, ...@@ -255,7 +258,6 @@ void Predictor::Build(const lite_api::CxxConfig &config,
} else { } else {
LOG(INFO) << "Load model from file."; LOG(INFO) << "Load model from file.";
} }
Build(model_path, Build(model_path,
model_file, model_file,
param_file, param_file,
...@@ -296,10 +298,10 @@ void Predictor::Build(const std::string &model_path, ...@@ -296,10 +298,10 @@ void Predictor::Build(const std::string &model_path,
Build(program_desc_, valid_places, passes); Build(program_desc_, valid_places, passes);
} }
void Predictor::Build(const std::shared_ptr<cpp::ProgramDesc> &desc, void Predictor::Build(const std::shared_ptr<cpp::ProgramDesc> &program_desc,
const std::vector<Place> &valid_places, const std::vector<Place> &valid_places,
const std::vector<std::string> &passes) { const std::vector<std::string> &passes) {
program_desc_ = desc; program_desc_ = program_desc;
// `inner_places` is used to optimize passes // `inner_places` is used to optimize passes
std::vector<Place> inner_places = valid_places; std::vector<Place> inner_places = valid_places;
for (auto &valid_place : valid_places) { for (auto &valid_place : valid_places) {
...@@ -336,7 +338,7 @@ void Predictor::Build(const std::shared_ptr<cpp::ProgramDesc> &desc, ...@@ -336,7 +338,7 @@ void Predictor::Build(const std::shared_ptr<cpp::ProgramDesc> &desc,
Place{TARGET(kARM), PRECISION(kInt8)}); Place{TARGET(kARM), PRECISION(kInt8)});
} }
Program program(*desc.get(), scope_, inner_places); Program program(program_desc_, scope_, inner_places);
valid_places_ = inner_places; valid_places_ = inner_places;
core::KernelPickFactor factor; core::KernelPickFactor factor;
......
...@@ -58,13 +58,12 @@ class LITE_API Predictor { ...@@ -58,13 +58,12 @@ class LITE_API Predictor {
// Create a predictor with the weight variable scope set. // Create a predictor with the weight variable scope set.
explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope) explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope)
: scope_(root_scope) {} : scope_(root_scope) {}
Predictor(const std::shared_ptr<cpp::ProgramDesc>& desc, Predictor(const std::shared_ptr<cpp::ProgramDesc>& program_desc,
const std::shared_ptr<Scope>& root, const std::shared_ptr<Scope>& root_scope,
const std::vector<Place>& valid_places, const std::vector<Place>& valid_places,
const std::vector<std::string>& var_names = {}) const std::vector<std::string>& vars_to_clone = {})
: program_desc_(desc), scope_(root) { : program_desc_(program_desc), scope_(root_scope) {
Program program(*desc.get(), scope_, valid_places, var_names); Program program(program_desc_, scope_, valid_places, vars_to_clone);
// TODO(wilber): rethink a new way to associate config and passes.
optimizer_ = Optimizer(std::move(program), valid_places); optimizer_ = Optimizer(std::move(program), valid_places);
exec_scope_ = optimizer_.exec_scope(); exec_scope_ = optimizer_.exec_scope();
valid_places_ = valid_places; valid_places_ = valid_places;
...@@ -86,30 +85,28 @@ class LITE_API Predictor { ...@@ -86,30 +85,28 @@ class LITE_API Predictor {
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf, lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf,
bool memory_from_memory = false); bool memory_from_memory = false);
void Build(const std::shared_ptr<cpp::ProgramDesc>& desc, void Build(const std::shared_ptr<cpp::ProgramDesc>& program_desc,
const std::vector<Place>& valid_places, const std::vector<Place>& valid_places,
const std::vector<std::string>& passes = {}); const std::vector<std::string>& passes = {});
std::shared_ptr<Predictor> Clone() const { std::shared_ptr<Predictor> Clone() const {
auto predictor = return std::make_shared<Predictor>(program_desc_, scope_, valid_places_);
std::make_shared<Predictor>(program_desc_, scope_, valid_places_);
return predictor;
} }
std::shared_ptr<Predictor> Clone( std::shared_ptr<Predictor> Clone(
const std::vector<std::string>& var_names) const { const std::vector<std::string>& vars_to_clone) const {
CHECK(program_desc_) << "Both program and scope of current predicotr " CHECK(program_desc_) << "Both program and scope of current predicotr "
"should be not be nullptr in Clone mode."; "should be not be nullptr in Clone mode.";
CHECK(scope_) << "Both program and scope of current predicotr should be " CHECK(scope_) << "Both program and scope of current predicotr should be "
"not be nullptr in Clone mode."; "not be nullptr in Clone mode.";
auto predictor = std::make_shared<Predictor>( auto predictor = std::make_shared<Predictor>(
program_desc_, scope_, valid_places_, var_names); program_desc_, scope_, valid_places_, vars_to_clone);
for (auto i : var_names) { for (auto var_name : vars_to_clone) {
predictor->exec_scope_->LocalVar(i); predictor->exec_scope_->LocalVar(var_name);
auto* tensor = predictor->scope_->Var(i)->GetMutable<lite::Tensor>(); auto* tensor = predictor->scope_->Var(var_name)->GetMutable<Tensor>();
auto* sub_tensor = auto* sub_tensor =
predictor->exec_scope_->Var(i)->GetMutable<lite::Tensor>(); predictor->exec_scope_->Var(var_name)->GetMutable<Tensor>();
sub_tensor->CopyDataFrom(*tensor); sub_tensor->CopyDataFrom(*tensor);
} }
return predictor; return predictor;
...@@ -147,6 +144,7 @@ class LITE_API Predictor { ...@@ -147,6 +144,7 @@ class LITE_API Predictor {
// get a const tensor according to its name // get a const tensor according to its name
const lite::Tensor* GetTensor(const std::string& name) const; const lite::Tensor* GetTensor(const std::string& name) const;
const RuntimeProgram& runtime_program() const; const RuntimeProgram& runtime_program() const;
Scope* scope() { return scope_.get(); }
// This method is disabled in mobile, for unnecessary dependencies required. // This method is disabled in mobile, for unnecessary dependencies required.
void SaveModel( void SaveModel(
......
...@@ -75,8 +75,10 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { ...@@ -75,8 +75,10 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
mode_ = config.power_mode(); mode_ = config.power_mode();
threads_ = config.threads(); threads_ = config.threads();
#ifdef LITE_WITH_NPU #ifdef LITE_WITH_NPU
// Store the model-level configuration into scope for kernels, and use
// exe_scope to store the execution-level configuration
Context<TargetType::kNPU>::SetSubgraphModelCacheDir( Context<TargetType::kNPU>::SetSubgraphModelCacheDir(
config.subgraph_model_cache_dir()); raw_predictor_->scope(), config.subgraph_model_cache_dir());
#endif #endif
#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \ #if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \
!(defined LITE_ON_MODEL_OPTIMIZE_TOOL) !(defined LITE_ON_MODEL_OPTIMIZE_TOOL)
......
...@@ -22,16 +22,16 @@ namespace lite { ...@@ -22,16 +22,16 @@ namespace lite {
void LightPredictor::Build(const std::string& lite_model_file, void LightPredictor::Build(const std::string& lite_model_file,
bool model_from_memory) { bool model_from_memory) {
if (model_from_memory) { if (model_from_memory) {
LoadModelNaiveFromMemory(lite_model_file, scope_.get(), &cpp_program_desc_); LoadModelNaiveFromMemory(
lite_model_file, scope_.get(), program_desc_.get());
} else { } else {
LoadModelNaiveFromFile(lite_model_file, scope_.get(), &cpp_program_desc_); LoadModelNaiveFromFile(lite_model_file, scope_.get(), program_desc_.get());
} }
// For weight quantization of post training, load the int8/16 weights // For weight quantization of post training, load the int8/16 weights
// for optimized model, and dequant it to fp32. // for optimized model, and dequant it to fp32.
DequantizeWeight(); DequantizeWeight();
BuildRuntimeProgram(program_desc_);
BuildRuntimeProgram(cpp_program_desc_);
PrepareFeedFetch(); PrepareFeedFetch();
} }
...@@ -43,15 +43,15 @@ void LightPredictor::Build(const std::string& model_dir, ...@@ -43,15 +43,15 @@ void LightPredictor::Build(const std::string& model_dir,
switch (model_type) { switch (model_type) {
#ifndef LITE_ON_TINY_PUBLISH #ifndef LITE_ON_TINY_PUBLISH
case lite_api::LiteModelType::kProtobuf: case lite_api::LiteModelType::kProtobuf:
LoadModelPb(model_dir, "", "", scope_.get(), &cpp_program_desc_); LoadModelPb(model_dir, "", "", scope_.get(), program_desc_.get());
break; break;
#endif #endif
case lite_api::LiteModelType::kNaiveBuffer: { case lite_api::LiteModelType::kNaiveBuffer: {
if (model_from_memory) { if (model_from_memory) {
LoadModelNaiveFromMemory( LoadModelNaiveFromMemory(
model_buffer, param_buffer, scope_.get(), &cpp_program_desc_); model_buffer, param_buffer, scope_.get(), program_desc_.get());
} else { } else {
LoadModelNaive(model_dir, scope_.get(), &cpp_program_desc_); LoadModelNaive(model_dir, scope_.get(), program_desc_.get());
} }
break; break;
} }
...@@ -60,7 +60,7 @@ void LightPredictor::Build(const std::string& model_dir, ...@@ -60,7 +60,7 @@ void LightPredictor::Build(const std::string& model_dir,
} }
DequantizeWeight(); DequantizeWeight();
BuildRuntimeProgram(cpp_program_desc_); BuildRuntimeProgram(program_desc_);
PrepareFeedFetch(); PrepareFeedFetch();
} }
...@@ -109,15 +109,17 @@ std::vector<std::string> LightPredictor::GetOutputNames() { ...@@ -109,15 +109,17 @@ std::vector<std::string> LightPredictor::GetOutputNames() {
} }
// append the names of inputs and outputs into input_names_ and output_names_ // append the names of inputs and outputs into input_names_ and output_names_
void LightPredictor::PrepareFeedFetch() { void LightPredictor::PrepareFeedFetch() {
auto current_block = cpp_program_desc_.GetBlock<cpp::BlockDesc>(0); std::vector<const cpp::OpDesc*> feeds;
std::vector<cpp::OpDesc*> feeds; std::vector<const cpp::OpDesc*> fetchs;
std::vector<cpp::OpDesc*> fetchs; std::shared_ptr<const cpp::ProgramDesc> program_desc = program_desc_;
for (size_t i = 0; i < current_block->OpsSize(); i++) { auto main_block = program_desc->GetBlock<cpp::BlockDesc>(kRootBlockIdx);
auto op = current_block->GetOp<cpp::OpDesc>(i); auto op_size = main_block->OpsSize();
if (op->Type() == "feed") { for (size_t op_idx = 0; op_idx < op_size; ++op_idx) {
feeds.push_back(op); auto op_desc = main_block->GetOp<cpp::OpDesc>(op_idx);
} else if (op->Type() == "fetch") { if (op_desc->Type() == "feed") {
fetchs.push_back(op); feeds.push_back(op_desc);
} else if (op_desc->Type() == "fetch") {
fetchs.push_back(op_desc);
} }
} }
input_names_.resize(feeds.size()); input_names_.resize(feeds.size());
...@@ -132,54 +134,35 @@ void LightPredictor::PrepareFeedFetch() { ...@@ -132,54 +134,35 @@ void LightPredictor::PrepareFeedFetch() {
} }
} }
void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) { void LightPredictor::BuildRuntimeProgram(
std::vector<Instruction> insts; const std::shared_ptr<const cpp::ProgramDesc>& program_desc) {
// 1. Create op first auto* exe_scope = &scope_->NewScope();
Program program(prog, scope_, {}); // Prepare workspace
scope_->Var("feed")->GetMutable<std::vector<lite::Tensor>>();
// 2. Create Instructs scope_->Var("fetch")->GetMutable<std::vector<lite::Tensor>>();
#ifdef LITE_WITH_OPENCL CHECK(program_desc);
using OpenCLContext = Context<TargetType::kOpenCL>; auto block_size = program_desc->BlocksSize();
std::unique_ptr<KernelContext> local_ctx(new KernelContext()); CHECK(block_size);
local_ctx->As<OpenCLContext>().InitOnce(); for (size_t block_idx = 0; block_idx < block_size; ++block_idx) {
#endif auto block_desc = program_desc->GetBlock<cpp::BlockDesc>(block_idx);
auto var_size = block_desc->VarsSize();
// Create the kernels of the target places, and filter out the specific for (size_t var_idx = 0; var_idx < var_size; ++var_idx) {
// kernel with the target alias. auto var_desc = block_desc->GetVar<cpp::VarDesc>(var_idx);
for (auto& op : program.ops()) { if (!var_desc->Persistable()) {
auto kernel_type = op->op_info()->GetAttr<std::string>(kKernelTypeAttr); exe_scope->Var(var_desc->Name());
std::string op_type, alias;
Place place;
KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place);
auto kernels = op->CreateKernels({place});
// filter out a kernel
auto it = std::find_if(
kernels.begin(), kernels.end(), [&](std::unique_ptr<KernelBase>& it) {
return it->alias() == alias;
});
CHECK(it != kernels.end());
#ifdef LITE_WITH_OPENCL
if ((*it)->target() == TARGET(kOpenCL)) {
std::unique_ptr<KernelContext> ctx(new KernelContext());
(*local_ctx).As<OpenCLContext>().CopySharedTo(&ctx->As<OpenCLContext>());
(*it)->SetContext(std::move(ctx));
} else { } else {
(*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target())); if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") continue;
scope_->Var(var_desc->Name());
} }
#else
(*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target()));
#endif
insts.emplace_back(op, std::move(*it));
} }
program_.reset(new RuntimeProgram(std::move(insts))); }
// Only extracting the ops and generate the runtime program from the main
CHECK(program.exec_scope()); // block desc
program_->set_exec_scope(program.exec_scope()); program_.reset(new RuntimeProgram(program_desc, exe_scope, kRootBlockIdx));
} }
void LightPredictor::DequantizeWeight() { void LightPredictor::DequantizeWeight() {
std::shared_ptr<const cpp::ProgramDesc> program_desc = program_desc_;
#define PROCESS_CONV2D_DATA() \ #define PROCESS_CONV2D_DATA() \
for (int64_t i = 0; i < ch; ++i) { \ for (int64_t i = 0; i < ch; ++i) { \
for (int64_t j = 0; j < offset; ++j) { \ for (int64_t j = 0; j < offset; ++j) { \
...@@ -205,10 +188,9 @@ void LightPredictor::DequantizeWeight() { ...@@ -205,10 +188,9 @@ void LightPredictor::DequantizeWeight() {
} }
return result; return result;
}; };
Tensor tmp_tensor; Tensor tmp_tensor;
for (size_t i = 0; i < cpp_program_desc_.BlocksSize(); i++) { for (size_t i = 0; i < program_desc->BlocksSize(); i++) {
auto* block = cpp_program_desc_.GetBlock<cpp::BlockDesc>(i); auto* block = program_desc->GetBlock<cpp::BlockDesc>(i);
for (size_t k = 0; k < block->OpsSize(); ++k) { for (size_t k = 0; k < block->OpsSize(); ++k) {
auto* op_desc = block->GetOp<cpp::OpDesc>(k); auto* op_desc = block->GetOp<cpp::OpDesc>(k);
if (is_weight_quantized_op(op_desc)) { if (is_weight_quantized_op(op_desc)) {
......
...@@ -46,6 +46,7 @@ class LITE_API LightPredictor { ...@@ -46,6 +46,7 @@ class LITE_API LightPredictor {
LightPredictor(const std::string& lite_model_file, LightPredictor(const std::string& lite_model_file,
bool model_from_memory = false) { bool model_from_memory = false) {
scope_ = std::make_shared<Scope>(); scope_ = std::make_shared<Scope>();
program_desc_ = std::make_shared<cpp::ProgramDesc>();
Build(lite_model_file, model_from_memory); Build(lite_model_file, model_from_memory);
} }
...@@ -57,6 +58,7 @@ class LITE_API LightPredictor { ...@@ -57,6 +58,7 @@ class LITE_API LightPredictor {
lite_api::LiteModelType model_type = lite_api::LiteModelType model_type =
lite_api::LiteModelType::kNaiveBuffer) { lite_api::LiteModelType::kNaiveBuffer) {
scope_ = std::make_shared<Scope>(); scope_ = std::make_shared<Scope>();
program_desc_ = std::make_shared<cpp::ProgramDesc>();
Build(model_dir, model_buffer, param_buffer, model_type, model_from_memory); Build(model_dir, model_buffer, param_buffer, model_type, model_from_memory);
} }
...@@ -78,6 +80,7 @@ class LITE_API LightPredictor { ...@@ -78,6 +80,7 @@ class LITE_API LightPredictor {
std::vector<std::string> GetInputNames(); std::vector<std::string> GetInputNames();
std::vector<std::string> GetOutputNames(); std::vector<std::string> GetOutputNames();
void PrepareFeedFetch(); void PrepareFeedFetch();
Scope* scope() { return scope_.get(); }
private: private:
void Build(const std::string& lite_model_file, void Build(const std::string& lite_model_file,
...@@ -91,14 +94,15 @@ class LITE_API LightPredictor { ...@@ -91,14 +94,15 @@ class LITE_API LightPredictor {
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf, lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf,
bool model_from_memory = false); bool model_from_memory = false);
void BuildRuntimeProgram(const cpp::ProgramDesc& prog); void BuildRuntimeProgram(
const std::shared_ptr<const cpp::ProgramDesc>& program_desc);
void DequantizeWeight(); void DequantizeWeight();
private: private:
std::shared_ptr<Scope> scope_; std::shared_ptr<Scope> scope_;
std::unique_ptr<RuntimeProgram> program_; std::unique_ptr<RuntimeProgram> program_;
cpp::ProgramDesc cpp_program_desc_; std::shared_ptr<cpp::ProgramDesc> program_desc_;
std::vector<std::string> input_names_; std::vector<std::string> input_names_;
std::vector<std::string> output_names_; std::vector<std::string> output_names_;
}; };
......
...@@ -38,8 +38,10 @@ void LightPredictorImpl::Init(const lite_api::MobileConfig& config) { ...@@ -38,8 +38,10 @@ void LightPredictorImpl::Init(const lite_api::MobileConfig& config) {
threads_ = config.threads(); threads_ = config.threads();
#ifdef LITE_WITH_NPU #ifdef LITE_WITH_NPU
// Store the model-level configuration into scope for kernels, and use
// exe_scope to store the execution-level configuration
Context<TargetType::kNPU>::SetSubgraphModelCacheDir( Context<TargetType::kNPU>::SetSubgraphModelCacheDir(
config.subgraph_model_cache_dir()); raw_predictor_->scope(), config.subgraph_model_cache_dir());
#endif #endif
} }
......
...@@ -28,6 +28,7 @@ USE_MIR_PASS(graph_visualize_pass); ...@@ -28,6 +28,7 @@ USE_MIR_PASS(graph_visualize_pass);
USE_MIR_PASS(remove_tf_redundant_ops_pass); USE_MIR_PASS(remove_tf_redundant_ops_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass); USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(lite_conv_conv_fuse_pass);
USE_MIR_PASS(lite_fc_fuse_pass); USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(lite_shuffle_channel_fuse_pass); USE_MIR_PASS(lite_shuffle_channel_fuse_pass);
USE_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass); USE_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass);
...@@ -53,6 +54,7 @@ USE_MIR_PASS(mlu_postprocess_pass); ...@@ -53,6 +54,7 @@ USE_MIR_PASS(mlu_postprocess_pass);
USE_MIR_PASS(weight_quantization_preprocess_pass); USE_MIR_PASS(weight_quantization_preprocess_pass);
USE_MIR_PASS(apu_subgraph_pass); USE_MIR_PASS(apu_subgraph_pass);
USE_MIR_PASS(quantized_op_attributes_inference_pass); USE_MIR_PASS(quantized_op_attributes_inference_pass);
USE_MIR_PASS(control_flow_op_unused_inputs_and_outputs_eliminate_pass)
USE_MIR_PASS(lite_scale_activation_fuse_pass); USE_MIR_PASS(lite_scale_activation_fuse_pass);
USE_MIR_PASS(__xpu__resnet_fuse_pass); USE_MIR_PASS(__xpu__resnet_fuse_pass);
USE_MIR_PASS(__xpu__resnet_cbam_fuse_pass); USE_MIR_PASS(__xpu__resnet_cbam_fuse_pass);
......
...@@ -234,7 +234,7 @@ void beam_search(const Tensor *pre_ids, ...@@ -234,7 +234,7 @@ void beam_search(const Tensor *pre_ids,
selected_ids->Resize(dims); selected_ids->Resize(dims);
selected_scores->Resize(dims); selected_scores->Resize(dims);
if (parent_idx) { if (parent_idx) {
parent_idx->Resize(dims); parent_idx->Resize({static_cast<int64_t>(num_instances)});
} }
auto *selected_ids_data = selected_ids->mutable_data<int64_t>(); auto *selected_ids_data = selected_ids->mutable_data<int64_t>();
auto *selected_scores_data = selected_scores->mutable_data<float>(); auto *selected_scores_data = selected_scores->mutable_data<float>();
......
...@@ -139,6 +139,71 @@ static bool conv_trans_weights_numc(const dtype* din, ...@@ -139,6 +139,71 @@ static bool conv_trans_weights_numc(const dtype* din,
} }
return true; return true;
} }
// for example: m = 4, n = 4
// din = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9 , 10 ,11], [12, 13, 14, 15]]
// dout = [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]
/*
m = 8 n = 8: 0 1 2 3 4 5 6 7 0 8 16 24 32 40 48 56
16 17 18 19 20 21 22 23 2 10 18 26 34 42 50 58
24 25 26 27 28 29 30 31 3 11 19 27 35 43 51 59
32 33 34 35 36 37 38 39 4 12 20 28 36 44 52 60 ...
}
}
*/
template <typename Dtype>
void local_transpose(const Dtype* din, Dtype* dout, int m, int n) {
// n % 4 == 0 && m % 4 == 0
// n * m ==> n * m data trans
int offset_m = m << 2;
const Dtype* din_ptr = din;
Dtype* dout_ptr = dout;
for (int i = 0; i < n; i += 4) {
Dtype* out_ptr0 = dout_ptr;
Dtype* out_ptr1 = dout_ptr + m;
Dtype* out_ptr2 = out_ptr1 + m;
Dtype* out_ptr3 = out_ptr2 + m;
const Dtype* in_ptr0 = din_ptr;
const Dtype* in_ptr1 = din_ptr + m;
const Dtype* in_ptr2 = in_ptr1 + m;
const Dtype* in_ptr3 = in_ptr2 + m;
for (int j = 0; j < m; j += 4) {
float32x4_t vin0 = vld1q_f32(in_ptr0);
float32x4_t vin1 = vld1q_f32(in_ptr1);
float32x4_t vin2 = vld1q_f32(in_ptr2);
float32x4_t vin3 = vld1q_f32(in_ptr3);
// a00 b00 a02 b02 a01 b01 a03 b03
float32x4x2_t tmp0 = vtrnq_f32(vin0, vin1);
// c00 d00 c02 d02 c01 d01 c03 d03
float32x4x2_t tmp2 = vtrnq_f32(vin2, vin3);
in_ptr0 = in_ptr3 + m;
in_ptr1 = in_ptr3 + 2 * m;
float tmp_val1 = tmp0.val[0][2];
float tmp_val2 = tmp0.val[0][3];
tmp0.val[0][2] = tmp2.val[0][0];
tmp0.val[0][3] = tmp2.val[0][1];
float tmp_val3 = tmp0.val[1][2];
float tmp_val4 = tmp0.val[1][3];
tmp2.val[0][0] = tmp_val1;
tmp2.val[0][1] = tmp_val2;
tmp0.val[1][2] = tmp2.val[1][0];
tmp0.val[1][3] = tmp2.val[1][1];
tmp2.val[1][0] = tmp_val3;
tmp2.val[1][1] = tmp_val4;
in_ptr2 = in_ptr1 + m;
in_ptr3 = in_ptr1 + 2 * m;
vst1q_f32(out_ptr0, tmp0.val[0]);
vst1q_f32(out_ptr1, tmp0.val[1]);
out_ptr0 += 4;
out_ptr1 += 4;
vst1q_f32(out_ptr2, tmp2.val[0]);
vst1q_f32(out_ptr3, tmp2.val[1]);
out_ptr2 += 4;
out_ptr3 += 4;
}
dout_ptr += offset_m;
din_ptr += 4;
}
}
template <typename Dtype> template <typename Dtype>
void transpose(const Dtype* din, Dtype* dout, int m, int n) { void transpose(const Dtype* din, Dtype* dout, int m, int n) {
// nxm == mxn // nxm == mxn
......
...@@ -747,6 +747,16 @@ void elementwise_mul<int>(const int* dinx, ...@@ -747,6 +747,16 @@ void elementwise_mul<int>(const int* dinx,
} }
} }
template <>
void elementwise_mul<int64_t>(const int64_t* dinx,
const int64_t* diny,
int64_t* dout,
int num) {
for (int i = 0; i < num; i++) {
dout[i] = dinx[i] * diny[i];
}
}
template <> template <>
void elementwise_mul_relu<float>(const float* dinx, void elementwise_mul_relu<float>(const float* dinx,
const float* diny, const float* diny,
...@@ -801,6 +811,17 @@ void elementwise_mul_relu<float>(const float* dinx, ...@@ -801,6 +811,17 @@ void elementwise_mul_relu<float>(const float* dinx,
} }
} }
template <>
void elementwise_mul_relu<int64_t>(const int64_t* dinx,
const int64_t* diny,
int64_t* dout,
int num) {
for (int i = 0; i < num; i++) {
int64_t tmp = dinx[i] * diny[i];
dout[i] = tmp > 0 ? tmp : 0;
}
}
template <> template <>
void elementwise_mul_broadcast<float>(const float* dinx, void elementwise_mul_broadcast<float>(const float* dinx,
const float* diny, const float* diny,
...@@ -935,6 +956,29 @@ void elementwise_mul_broadcast<int>(const int* dinx, ...@@ -935,6 +956,29 @@ void elementwise_mul_broadcast<int>(const int* dinx,
} }
} }
template <>
void elementwise_mul_broadcast<int64_t>(const int64_t* dinx,
const int64_t* diny,
int64_t* dout,
int batch,
int channels,
int num) {
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const int64_t* dinx_ptr = dinx + offset;
const int64_t diny_data = diny[j];
int64_t* dout_ptr = dout + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *dinx_ptr * diny_data;
dout_ptr++;
dinx_ptr++;
}
}
}
}
template <> template <>
void elementwise_mul_relu_broadcast<float>(const float* dinx, void elementwise_mul_relu_broadcast<float>(const float* dinx,
const float* diny, const float* diny,
...@@ -1014,6 +1058,30 @@ void elementwise_mul_relu_broadcast<float>(const float* dinx, ...@@ -1014,6 +1058,30 @@ void elementwise_mul_relu_broadcast<float>(const float* dinx,
} }
} }
template <>
void elementwise_mul_relu_broadcast<int64_t>(const int64_t* dinx,
const int64_t* diny,
int64_t* dout,
int batch,
int channels,
int num) {
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const int64_t* dinx_ptr = dinx + offset;
const int64_t diny_data = diny[j];
int64_t* dout_ptr = dout + offset;
for (int k = 0; k < num; ++k) {
int64_t tmp = *dinx_ptr * diny_data;
*dout_ptr = tmp > 0 ? tmp : 0;
dout_ptr++;
dinx_ptr++;
}
}
}
}
template <> template <>
void elementwise_max<float>(const float* dinx, void elementwise_max<float>(const float* dinx,
const float* diny, const float* diny,
......
...@@ -21,7 +21,7 @@ namespace lite { ...@@ -21,7 +21,7 @@ namespace lite {
namespace arm { namespace arm {
namespace math { namespace math {
const int MALLOC_ALIGN = 64; const int MALLOC_ALIGN = 16;
void* fast_malloc(size_t size) { void* fast_malloc(size_t size) {
size_t offset = sizeof(void*) + MALLOC_ALIGN - 1; size_t offset = sizeof(void*) + MALLOC_ALIGN - 1;
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
const int MALLOC_ALIGN = 64; const int MALLOC_ALIGN = 16;
void* TargetWrapper<TARGET(kHost)>::Malloc(size_t size) { void* TargetWrapper<TARGET(kHost)>::Malloc(size_t size) {
size_t offset = sizeof(void*) + MALLOC_ALIGN - 1; size_t offset = sizeof(void*) + MALLOC_ALIGN - 1;
...@@ -30,7 +30,6 @@ void* TargetWrapper<TARGET(kHost)>::Malloc(size_t size) { ...@@ -30,7 +30,6 @@ void* TargetWrapper<TARGET(kHost)>::Malloc(size_t size) {
void* r = reinterpret_cast<void*>(reinterpret_cast<size_t>(p + offset) & void* r = reinterpret_cast<void*>(reinterpret_cast<size_t>(p + offset) &
(~(MALLOC_ALIGN - 1))); (~(MALLOC_ALIGN - 1)));
static_cast<void**>(r)[-1] = p; static_cast<void**>(r)[-1] = p;
memset(r, 0, size);
return r; return r;
} }
void TargetWrapper<TARGET(kHost)>::Free(void* ptr) { void TargetWrapper<TARGET(kHost)>::Free(void* ptr) {
......
...@@ -33,7 +33,7 @@ std::shared_ptr<hiai::AiModelMngerClient> Device::Load( ...@@ -33,7 +33,7 @@ std::shared_ptr<hiai::AiModelMngerClient> Device::Load(
// Check HiAI DDK version // Check HiAI DDK version
const char* ddk_version = model_client->GetVersion(); const char* ddk_version = model_client->GetVersion();
if (ddk_version) { if (ddk_version) {
LOG(INFO) << "[NPU] HiAI DDK version: " << ddk_version; VLOG(3) << "[NPU] HiAI DDK version: " << ddk_version;
} else { } else {
LOG(WARNING) << "[NPU] Unable to get HiAI DDK version!"; LOG(WARNING) << "[NPU] Unable to get HiAI DDK version!";
} }
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include "lite/backends/xpu/xpu_header_sitter.h" #include "lite/backends/xpu/target_wrapper.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -82,8 +82,8 @@ void DumpXPUMem(const T* ptr, ...@@ -82,8 +82,8 @@ void DumpXPUMem(const T* ptr,
size_t item_per_line = 30) { size_t item_per_line = 30) {
size_t after_stride_len = (len + stride - 1) / stride; size_t after_stride_len = (len + stride - 1) / stride;
std::unique_ptr<T[]> cpu_mem(new T[len]); std::unique_ptr<T[]> cpu_mem(new T[len]);
xpu_memcpy( XPU_CALL(xpu_memcpy(
cpu_mem.get(), ptr, len * sizeof(T), XPUMemcpyKind::XPU_DEVICE_TO_HOST); cpu_mem.get(), ptr, len * sizeof(T), XPUMemcpyKind::XPU_DEVICE_TO_HOST));
std::unique_ptr<T[]> after_stride(new T[after_stride_len]); std::unique_ptr<T[]> after_stride(new T[after_stride_len]);
for (size_t i = 0; i < after_stride_len; ++i) { for (size_t i = 0; i < after_stride_len; ++i) {
after_stride[i] = cpu_mem[i * stride]; after_stride[i] = cpu_mem[i * stride];
......
...@@ -19,11 +19,11 @@ namespace lite { ...@@ -19,11 +19,11 @@ namespace lite {
void* TargetWrapperXPU::Malloc(size_t size) { void* TargetWrapperXPU::Malloc(size_t size) {
void* ptr{nullptr}; void* ptr{nullptr};
xpu_malloc(&ptr, size); XPU_CALL(xpu_malloc(&ptr, size));
return ptr; return ptr;
} }
void TargetWrapperXPU::Free(void* ptr) { xpu_free(ptr); } void TargetWrapperXPU::Free(void* ptr) { XPU_CALL(xpu_free(ptr)); }
void TargetWrapperXPU::MemcpySync(void* dst, void TargetWrapperXPU::MemcpySync(void* dst,
const void* src, const void* src,
...@@ -31,10 +31,10 @@ void TargetWrapperXPU::MemcpySync(void* dst, ...@@ -31,10 +31,10 @@ void TargetWrapperXPU::MemcpySync(void* dst,
IoDirection dir) { IoDirection dir) {
switch (dir) { switch (dir) {
case IoDirection::HtoD: case IoDirection::HtoD:
xpu_memcpy(dst, src, size, XPU_HOST_TO_DEVICE); XPU_CALL(xpu_memcpy(dst, src, size, XPU_HOST_TO_DEVICE));
break; break;
case IoDirection::DtoH: case IoDirection::DtoH:
xpu_memcpy(dst, src, size, XPU_DEVICE_TO_HOST); XPU_CALL(xpu_memcpy(dst, src, size, XPU_DEVICE_TO_HOST));
break; break;
default: default:
LOG(FATAL) << "Unsupported IoDirection " << static_cast<int>(dir); LOG(FATAL) << "Unsupported IoDirection " << static_cast<int>(dir);
...@@ -49,7 +49,7 @@ XPUScratchPadGuard TargetWrapperXPU::MallocScratchPad(size_t size, ...@@ -49,7 +49,7 @@ XPUScratchPadGuard TargetWrapperXPU::MallocScratchPad(size_t size,
} else { } else {
ptr = TargetWrapperXPU::Malloc(size); ptr = TargetWrapperXPU::Malloc(size);
} }
CHECK(ptr != nullptr); CHECK(ptr != nullptr) << "size = " << size << ", use_l3 = " << use_l3;
return XPUScratchPadGuard(new XPUScratchPad(ptr, use_l3)); return XPUScratchPadGuard(new XPUScratchPad(ptr, use_l3));
} }
......
...@@ -16,11 +16,23 @@ ...@@ -16,11 +16,23 @@
#include <memory> // std::unique_ptr #include <memory> // std::unique_ptr
#include "lite/backends/xpu/xpu_header_sitter.h" // xpu_free #include "lite/backends/xpu/xpu_header_sitter.h" // xpu_free
#include "lite/core/target_wrapper.h" #include "lite/core/target_wrapper.h" // TargetWrapper
#include "lite/utils/cp_logging.h" // CHECK_EQ
#define XPU_CALL(func) \
{ \
auto e = (func); \
CHECK_EQ(e, 0) << "XPU: (" << #func << ") returns " << e; \
}
namespace paddle { namespace paddle {
namespace lite { namespace lite {
// MAX(lod.size()) = 64
const int XPU_MAX_LOD_SIZE = 64;
// MAX(lod[i + 1] - lod[i]) = 512
const int XPU_MAX_LOD_SEQ_LEN = 512;
using TargetWrapperXPU = TargetWrapper<TARGET(kXPU)>; using TargetWrapperXPU = TargetWrapper<TARGET(kXPU)>;
struct XPUScratchPad { struct XPUScratchPad {
...@@ -33,7 +45,7 @@ struct XPUScratchPad { ...@@ -33,7 +45,7 @@ struct XPUScratchPad {
struct XPUScratchPadDeleter { struct XPUScratchPadDeleter {
void operator()(XPUScratchPad* sp) const { void operator()(XPUScratchPad* sp) const {
if (!sp->is_l3_) { if (!sp->is_l3_) {
xpu_free(sp->addr_); XPU_CALL(xpu_free(sp->addr_));
} }
delete sp; delete sp;
} }
...@@ -55,7 +67,7 @@ class TargetWrapper<TARGET(kXPU)> { ...@@ -55,7 +67,7 @@ class TargetWrapper<TARGET(kXPU)> {
size_t size, size_t size,
IoDirection dir); IoDirection dir);
static XPUScratchPadGuard MallocScratchPad(size_t size, bool use_l3 = true); static XPUScratchPadGuard MallocScratchPad(size_t size, bool use_l3 = false);
static xdnn::Context* GetRawContext() { static xdnn::Context* GetRawContext() {
if (tls_raw_ctx_ == nullptr) { if (tls_raw_ctx_ == nullptr) {
...@@ -77,11 +89,10 @@ class TargetWrapper<TARGET(kXPU)> { ...@@ -77,11 +89,10 @@ class TargetWrapper<TARGET(kXPU)> {
static void SetDev(int dev_no = 0) { static void SetDev(int dev_no = 0) {
const char* dev_env = getenv("LITE_XPU_DEV"); const char* dev_env = getenv("LITE_XPU_DEV");
if (dev_env) { if (dev_env) {
xpu_set_device(atoi(dev_env)); dev_no = atoi(dev_env);
return;
} }
xpu_set_device(dev_no); XPU_CALL(xpu_set_device(dev_no));
} }
static std::string multi_encoder_precision; // NOLINT static std::string multi_encoder_precision; // NOLINT
......
...@@ -32,25 +32,27 @@ void TestCase::CreateInstruction() { ...@@ -32,25 +32,27 @@ void TestCase::CreateInstruction() {
#endif #endif
if (enable_subgraph_op) { if (enable_subgraph_op) {
// Create a new block desc to wrap the original op desc // Create a new block desc to wrap the original op desc
auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
int sub_block_idx = 0; int sub_block_idx = 0;
auto sub_block_desc = new cpp::BlockDesc(); auto sub_block_desc = sub_program_desc->AddBlock<cpp::BlockDesc>();
sub_block_desc->ClearOps(); sub_block_desc->ClearOps();
sub_block_desc->ClearVars(); sub_block_desc->ClearVars();
auto sub_block_op_desc = sub_block_desc->AddOp<cpp::OpDesc>(); auto sub_op_desc = sub_block_desc->AddOp<cpp::OpDesc>();
*sub_block_op_desc = *op_desc_; *sub_op_desc = *op_desc_;
// Add the block desc into the subgraph op which used to replace the // Add the block desc into the subgraph op which used to replace the
// original op // original op
op_desc_.reset(new cpp::OpDesc()); op_desc_.reset(new cpp::OpDesc());
op_desc_->SetType("subgraph"); op_desc_->SetType("subgraph");
op_desc_->SetAttr<int32_t>("sub_block", sub_block_idx); op_desc_->SetAttr<int32_t>("sub_block", sub_block_idx);
auto in_names = sub_block_op_desc->input_vars(); auto in_names = sub_op_desc->input_vars();
auto out_names = sub_block_op_desc->output_vars(); auto out_names = sub_op_desc->output_vars();
op_desc_->SetInput("Inputs", in_names); op_desc_->SetInput("Inputs", in_names);
op_desc_->SetOutput("Outputs", out_names); op_desc_->SetOutput("Outputs", out_names);
op_desc_->SetAttr<std::vector<std::string>>("input_data_names", in_names); op_desc_->SetAttr<std::vector<std::string>>("input_data_names", in_names);
op_desc_->SetAttr<std::vector<std::string>>("output_data_names", out_names); op_desc_->SetAttr<std::vector<std::string>>("output_data_names", out_names);
op = LiteOpRegistry::Global().Create(op_desc().Type()); op = LiteOpRegistry::Global().Create(op_desc().Type());
static_cast<operators::SubgraphOp*>(op.get())->SetSubBlock(sub_block_desc); static_cast<operators::SubgraphOp*>(op.get())->SetProgramDesc(
sub_program_desc);
} else { } else {
op = LiteOpRegistry::Global().Create(op_desc().Type()); op = LiteOpRegistry::Global().Create(op_desc().Type());
} }
...@@ -60,7 +62,7 @@ void TestCase::CreateInstruction() { ...@@ -60,7 +62,7 @@ void TestCase::CreateInstruction() {
// filter out the target kernel // filter out the target kernel
CHECK(!kernels.empty()) << "No kernel found for place " CHECK(!kernels.empty()) << "No kernel found for place "
<< place_.DebugString(); << place_.DebugString();
auto it = std::remove_if( auto it = std::find_if(
kernels.begin(), kernels.end(), [&](std::unique_ptr<KernelBase>& k) { kernels.begin(), kernels.end(), [&](std::unique_ptr<KernelBase>& k) {
return k->alias() == alias_; return k->alias() == alias_;
}); });
...@@ -234,19 +236,6 @@ bool TestCase::CheckPrecision(const std::string& var_name, ...@@ -234,19 +236,6 @@ bool TestCase::CheckPrecision(const std::string& var_name,
return success; return success;
} }
TestCase::~TestCase() {
if (op_desc_->Type() == "subgraph") {
// Release the subblock desc of Subgraph op
auto subgraph_op = const_cast<operators::SubgraphOp*>(
static_cast<const operators::SubgraphOp*>(instruction_->op()));
CHECK(subgraph_op);
auto sub_block_desc = subgraph_op->GetSubBlock();
if (sub_block_desc) {
delete sub_block_desc;
}
}
}
} // namespace arena } // namespace arena
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -46,7 +46,7 @@ class TestCase { ...@@ -46,7 +46,7 @@ class TestCase {
base_scope_(new Scope) { base_scope_(new Scope) {
ctx_ = ContextScheduler::Global().NewContext(place_.target); ctx_ = ContextScheduler::Global().NewContext(place_.target);
} }
virtual ~TestCase(); virtual ~TestCase() {}
void Prepare() { void Prepare() {
PrepareData(); PrepareData();
......
...@@ -17,10 +17,6 @@ ...@@ -17,10 +17,6 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
#ifdef LITE_WITH_NPU
std::string Context<TargetType::kNPU>::subgraph_model_cache_dir_{""}; // NOLINT
#endif
#ifdef LITE_WITH_MLU #ifdef LITE_WITH_MLU
int Context<TargetType::kMLU>::next_queue_id_{0}; int Context<TargetType::kMLU>::next_queue_id_{0};
std::map<int, int> Context<TargetType::kMLU>::queue_id_map_; std::map<int, int> Context<TargetType::kMLU>::queue_id_map_;
......
...@@ -39,6 +39,7 @@ ...@@ -39,6 +39,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/device_info.h" #include "lite/core/device_info.h"
#include "lite/core/scope.h"
#include "lite/core/target_wrapper.h" #include "lite/core/target_wrapper.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
#include "lite/utils/all.h" #include "lite/utils/all.h"
...@@ -84,15 +85,19 @@ class Context<TargetType::kNPU> { ...@@ -84,15 +85,19 @@ class Context<TargetType::kNPU> {
NPUContext& operator=(const NPUContext& ctx) {} NPUContext& operator=(const NPUContext& ctx) {}
std::string name() const { return "NPUContext"; } std::string name() const { return "NPUContext"; }
static void SetSubgraphModelCacheDir(std::string subgraph_model_cache_dir) { static void SetSubgraphModelCacheDir(Scope* scope,
subgraph_model_cache_dir_ = subgraph_model_cache_dir; std::string subgraph_model_cache_dir) {
auto var = scope->Var("SUBGRAPH_MODEL_CACHE_DIR");
CHECK(var);
auto data = var->GetMutable<std::string>();
CHECK(data);
*data = subgraph_model_cache_dir;
} }
static std::string SubgraphModelCacheDir() { static std::string SubgraphModelCacheDir(Scope* scope) {
return subgraph_model_cache_dir_; auto var = scope->FindVar("SUBGRAPH_MODEL_CACHE_DIR");
if (!var) return "";
return var->Get<std::string>();
} }
private:
static std::string subgraph_model_cache_dir_;
}; };
#endif #endif
......
...@@ -18,6 +18,7 @@ lite_cc_library(mir_passes ...@@ -18,6 +18,7 @@ lite_cc_library(mir_passes
fusion/conv_activation_fuse_pass.cc fusion/conv_activation_fuse_pass.cc
fusion/var_conv_2d_activation_fuse_pass.cc fusion/var_conv_2d_activation_fuse_pass.cc
fusion/conv_bn_fuse_pass.cc fusion/conv_bn_fuse_pass.cc
fusion/conv_conv_fuse_pass.cc
fusion/elementwise_add_activation_fuse_pass.cc fusion/elementwise_add_activation_fuse_pass.cc
fusion/quant_dequant_fuse_pass.cc fusion/quant_dequant_fuse_pass.cc
fusion/sequence_pool_concat_fuse_pass.cc fusion/sequence_pool_concat_fuse_pass.cc
...@@ -32,6 +33,7 @@ lite_cc_library(mir_passes ...@@ -32,6 +33,7 @@ lite_cc_library(mir_passes
elimination/identity_dropout_eliminate_pass.cc elimination/identity_dropout_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc elimination/elementwise_mul_constant_eliminate_pass.cc
elimination/remove_tf_redundant_ops_pass.cc elimination/remove_tf_redundant_ops_pass.cc
elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.cc
static_kernel_pick_pass.cc static_kernel_pick_pass.cc
variable_place_inference_pass.cc variable_place_inference_pass.cc
type_target_cast_pass.cc type_target_cast_pass.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.h"
#include <algorithm>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
// Remove all of the unused nodes from the contorl flow op and update the inputs
// and outputs of the op info The unused nodes are defined as the nodes which
// are only linked to the control flow op nodes but nerver linked to the other
// op nodes.
//
// For example:
// graph[0]: main block
// in_x
// in_f | in_z(unused node)
// \ | /
// \ | /
// in_w ------- while ------- in_y(unused_node)
// / |
// / |
// (unused node)out_y |
// out_x
//
// graph[1]: sub block
// in_x
// |
// |
// conv2d----in_f
// |
// |
// fc ------in_w
// |
// |
// softmax
// |
// |
// out_x
//
// After the pass is applied:
// in_x
// in_f |
// \ |
// \ |
// in_w ------- while
// |
// |
// |
// out_x
// Remove the var node from var2rm if it is recursively referred to any op in
// the subblock
void CollectUnusedInputOutputNodes(
int block_idx,
std::vector<std::unique_ptr<mir::SSAGraph>>* graphs,
const std::unordered_set<std::string>& control_flow_op_types,
std::unordered_map<std::string, Node*>* in_vars2rm,
std::unordered_map<std::string, Node*>* out_vars2rm) {
auto block_size = graphs->size();
for (auto& op_node : (*graphs)[block_idx]->StmtTopologicalOrder()) {
if (!op_node->IsStmt()) continue;
auto op_info = op_node->AsStmt().op_info();
auto op_type = op_info->Type();
if (control_flow_op_types.count(op_type)) {
int sub_block_idx = op_info->GetAttr<int32_t>("sub_block");
CHECK(block_idx >= 0 && block_idx < block_size);
CollectUnusedInputOutputNodes(sub_block_idx,
graphs,
control_flow_op_types,
in_vars2rm,
out_vars2rm);
} else {
for (auto& var_node : op_node->inlinks) {
auto& var_name = var_node->AsArg().name;
if (in_vars2rm->count(var_name)) {
in_vars2rm->erase(var_name);
}
}
for (auto& var_node : op_node->outlinks) {
auto& var_name = var_node->AsArg().name;
// Tensor array may be only used as the output vars in the sublock
if (in_vars2rm->count(var_name)) {
in_vars2rm->erase(var_name);
}
if (out_vars2rm->count(var_name)) {
out_vars2rm->erase(var_name);
}
}
}
}
}
// Remove the unused var nodes from the graph and update the op_info of the
// control flow op
void RemoveNodesFromGraphAndUpdateOpInfo(
SSAGraph* graph,
Node* op_node,
const std::unordered_map<std::string, Node*>& in_vars2rm,
const std::unordered_map<std::string, Node*>& out_vars2rm) {
auto op_info = op_node->AsStmt().mutable_op_info();
auto op_type = op_info->Type();
// Unlink the in_vars2rm and out_vars2rm from the control flow op node, and
// remove them if nerver used.
for (auto& var_node : in_vars2rm) {
VLOG(3) << "in var node '" << var_node.first << "' is unlinked to "
<< op_type;
RemoveDirectedLink(var_node.second, op_node);
}
for (auto& var_node : out_vars2rm) {
VLOG(3) << "out var node '" << var_node.first << "' is unlinked from "
<< op_type;
RemoveDirectedLink(op_node, var_node.second);
// Unlink from all of the out op nodes.
std::unordered_set<Node*> out_op_nodes;
for (auto* out_op_node : var_node.second->outlinks) {
if (!out_op_nodes.count(out_op_node)) {
out_op_nodes.insert(out_op_node);
}
}
for (auto* out_op_node : out_op_nodes) {
RemoveDirectedLink(var_node.second, out_op_node);
}
}
// Remove the unused nodes from the graph if their inlinks and outlinks are
// empty
std::unordered_set<const Node*> removed_var_nodes;
for (auto& var_node : in_vars2rm) {
if (var_node.second->inlinks.empty() && var_node.second->outlinks.empty() &&
!removed_var_nodes.count(var_node.second)) {
removed_var_nodes.insert(var_node.second);
graph->RemoveNode(var_node.second);
VLOG(3) << "in var node " << var_node.first << " is removed";
}
}
for (auto& var_node : out_vars2rm) {
if (var_node.second->inlinks.empty() && var_node.second->outlinks.empty() &&
!removed_var_nodes.count(var_node.second)) {
removed_var_nodes.insert(var_node.second);
graph->RemoveNode(var_node.second);
VLOG(3) << "out var node " << var_node.first << " is removed";
}
}
// Update the op info of the control flow op
for (auto& input : *op_info->mutable_inputs()) {
for (auto var = input.second.begin(); var != input.second.end();) {
if (in_vars2rm.count(*var)) {
var = input.second.erase(var);
} else {
++var;
}
}
}
for (auto& output : *op_info->mutable_outputs()) {
for (auto var = output.second.begin(); var != output.second.end();) {
if (out_vars2rm.count(*var)) {
var = output.second.erase(var);
} else {
++var;
}
}
}
}
void ControlFlowOpUnusedInputsAndOutputsEliminatePass::SetAllGraphs(
std::vector<std::unique_ptr<mir::SSAGraph>>* graphs) {
CHECK(graphs && !graphs->empty());
graphs_ = graphs;
}
void ControlFlowOpUnusedInputsAndOutputsEliminatePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
// Remove the unused input and output nodes from the control flow op nodes
// Which are only linked to the control flow op nodes but nerver linked to the
// other op nodes
const std::unordered_set<std::string> control_flow_op_types = {
"while", "conditional_block"};
auto block_size = graphs_->size();
for (auto& op_node : graph->StmtTopologicalOrder()) {
if (!op_node->IsStmt()) continue;
auto op_info = op_node->AsStmt().mutable_op_info();
auto op_type = op_info->Type();
if (!control_flow_op_types.count(op_type)) continue;
int sub_block_idx = op_info->GetAttr<int32_t>("sub_block");
CHECK(sub_block_idx >= 0 && sub_block_idx < block_size);
// Initialize the unused nodes with all of the input and output nodes
std::unordered_map<std::string, Node *> in_vars2rm, out_vars2rm;
for (auto* var_node : op_node->inlinks) {
auto& var_name = var_node->AsArg().name;
if (!in_vars2rm.count(var_name)) {
in_vars2rm.insert(std::pair<std::string, Node*>(var_name, var_node));
}
}
for (auto* var_node : op_node->outlinks) {
auto& var_name = var_node->AsArg().name;
if (!out_vars2rm.count(var_name)) {
out_vars2rm.insert(std::pair<std::string, Node*>(var_name, var_node));
}
}
// Remove the nodes which used in subblock recursively, and the remaining
// nodes are the unused one.
CollectUnusedInputOutputNodes(sub_block_idx,
graphs_,
control_flow_op_types,
&in_vars2rm,
&out_vars2rm);
if (in_vars2rm.size() > 0 || out_vars2rm.size() > 0) {
// Remove the unused nodes from graph, and update the op info of the
// control flow op
RemoveNodesFromGraphAndUpdateOpInfo(
graph.get(), op_node, in_vars2rm, out_vars2rm);
}
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(
control_flow_op_unused_inputs_and_outputs_eliminate_pass,
paddle::lite::mir::ControlFlowOpUnusedInputsAndOutputsEliminatePass)
.BindTargets({TARGET(kNPU)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/mir/pass.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace mir {
class ControlFlowOpUnusedInputsAndOutputsEliminatePass : public mir::StmtPass {
public:
void Apply(const std::unique_ptr<SSAGraph> &graph) override;
void SetAllGraphs(std::vector<std::unique_ptr<mir::SSAGraph>> *graphs);
private:
std::vector<std::unique_ptr<mir::SSAGraph>> *graphs_;
};
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -16,6 +16,9 @@ lite_cc_library(fuse_var_conv_activation ...@@ -16,6 +16,9 @@ lite_cc_library(fuse_var_conv_activation
lite_cc_library(fuse_conv_bn lite_cc_library(fuse_conv_bn
SRCS conv_bn_fuser.cc SRCS conv_bn_fuser.cc
DEPS pattern_matcher_high_api) DEPS pattern_matcher_high_api)
lite_cc_library(fuse_conv_conv
SRCS conv_conv_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_elementwise_add_activation lite_cc_library(fuse_elementwise_add_activation
SRCS elementwise_add_activation_fuser.cc SRCS elementwise_add_activation_fuser.cc
DEPS pattern_matcher_high_api) DEPS pattern_matcher_high_api)
...@@ -42,6 +45,7 @@ set(mir_fusers ...@@ -42,6 +45,7 @@ set(mir_fusers
fuse_conv_activation fuse_conv_activation
fuse_var_conv_activation fuse_var_conv_activation
fuse_conv_bn fuse_conv_bn
fuse_conv_conv
fuse_quant_dequant fuse_quant_dequant
fuse_elementwise_add_activation fuse_elementwise_add_activation
fuse_transpose_softmax_transpose fuse_transpose_softmax_transpose
......
...@@ -383,10 +383,10 @@ class XPUSingleEncoderFuser : public FuseBase { ...@@ -383,10 +383,10 @@ class XPUSingleEncoderFuser : public FuseBase {
op_desc.SetAttr<std::string>("act_type", act_type_); op_desc.SetAttr<std::string>("act_type", act_type_);
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
// XXX: memleak? auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
auto sub_block_desc = new cpp::BlockDesc(); sub_program_desc->AddBlock<cpp::BlockDesc>();
static_cast<operators::SubgraphOp*>(fake_subgraph_op.get()) static_cast<operators::SubgraphOp*>(fake_subgraph_op.get())
->SetSubBlock(sub_block_desc); ->SetProgramDesc(sub_program_desc);
auto* single_encoder_stmt = matched.at("q_mul")->stmt(); auto* single_encoder_stmt = matched.at("q_mul")->stmt();
fake_subgraph_op->Attach(op_desc, single_encoder_stmt->op()->scope()); fake_subgraph_op->Attach(op_desc, single_encoder_stmt->op()->scope());
fake_subgraph_op->SetValidPlaces(single_encoder_stmt->op()->valid_places()); fake_subgraph_op->SetValidPlaces(single_encoder_stmt->op()->valid_places());
......
...@@ -373,10 +373,10 @@ class XPUResNetCbamBlock0Fuser : public FuseBase { ...@@ -373,10 +373,10 @@ class XPUResNetCbamBlock0Fuser : public FuseBase {
auto block0_stmt = matched.at("left_conv1")->stmt(); auto block0_stmt = matched.at("left_conv1")->stmt();
// block0_stmt->ResetOp(op_desc, graph->valid_places()); // block0_stmt->ResetOp(op_desc, graph->valid_places());
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
// XXX: memleak? auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
auto sub_block_desc = new cpp::BlockDesc(); sub_program_desc->AddBlock<cpp::BlockDesc>();
static_cast<operators::SubgraphOp*>(fake_subgraph_op.get()) static_cast<operators::SubgraphOp*>(fake_subgraph_op.get())
->SetSubBlock(sub_block_desc); ->SetProgramDesc(sub_program_desc);
fake_subgraph_op->Attach(op_desc, block0_stmt->op()->scope()); fake_subgraph_op->Attach(op_desc, block0_stmt->op()->scope());
fake_subgraph_op->SetValidPlaces(block0_stmt->op()->valid_places()); fake_subgraph_op->SetValidPlaces(block0_stmt->op()->valid_places());
block0_stmt->SetOp(fake_subgraph_op); block0_stmt->SetOp(fake_subgraph_op);
...@@ -693,10 +693,10 @@ class XPUResNetCbamBlock1Fuser : public FuseBase { ...@@ -693,10 +693,10 @@ class XPUResNetCbamBlock1Fuser : public FuseBase {
auto block1_stmt = matched.at("right_conv1")->stmt(); auto block1_stmt = matched.at("right_conv1")->stmt();
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
// XXX: memleak? auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
auto sub_block_desc = new cpp::BlockDesc(); sub_program_desc->AddBlock<cpp::BlockDesc>();
static_cast<operators::SubgraphOp*>(fake_subgraph_op.get()) static_cast<operators::SubgraphOp*>(fake_subgraph_op.get())
->SetSubBlock(sub_block_desc); ->SetProgramDesc(sub_program_desc);
fake_subgraph_op->Attach(op_desc, block1_stmt->op()->scope()); fake_subgraph_op->Attach(op_desc, block1_stmt->op()->scope());
fake_subgraph_op->SetValidPlaces(block1_stmt->op()->valid_places()); fake_subgraph_op->SetValidPlaces(block1_stmt->op()->valid_places());
block1_stmt->SetOp(fake_subgraph_op); block1_stmt->SetOp(fake_subgraph_op);
...@@ -932,10 +932,10 @@ class XPUResNetCbamBlock2Fuser : public FuseBase { ...@@ -932,10 +932,10 @@ class XPUResNetCbamBlock2Fuser : public FuseBase {
<< "Y of last fc must have been transposed"; << "Y of last fc must have been transposed";
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
// XXX: memleak? auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
auto sub_block_desc = new cpp::BlockDesc(); sub_program_desc->AddBlock<cpp::BlockDesc>();
static_cast<operators::SubgraphOp*>(fake_subgraph_op.get()) static_cast<operators::SubgraphOp*>(fake_subgraph_op.get())
->SetSubBlock(sub_block_desc); ->SetProgramDesc(sub_program_desc);
fake_subgraph_op->Attach(op_desc, scope); fake_subgraph_op->Attach(op_desc, scope);
fake_subgraph_op->SetValidPlaces(block2_stmt->op()->valid_places()); fake_subgraph_op->SetValidPlaces(block2_stmt->op()->valid_places());
block2_stmt->SetOp(fake_subgraph_op); block2_stmt->SetOp(fake_subgraph_op);
......
...@@ -315,10 +315,10 @@ class XPUResNetBlock0Fuser : public FuseBase { ...@@ -315,10 +315,10 @@ class XPUResNetBlock0Fuser : public FuseBase {
auto block0_stmt = matched.at("left_conv1")->stmt(); auto block0_stmt = matched.at("left_conv1")->stmt();
// block0_stmt->ResetOp(op_desc, graph->valid_places()); // block0_stmt->ResetOp(op_desc, graph->valid_places());
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
// XXX: memleak? auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
auto sub_block_desc = new cpp::BlockDesc(); sub_program_desc->AddBlock<cpp::BlockDesc>();
static_cast<operators::SubgraphOp*>(fake_subgraph_op.get()) static_cast<operators::SubgraphOp*>(fake_subgraph_op.get())
->SetSubBlock(sub_block_desc); ->SetProgramDesc(sub_program_desc);
fake_subgraph_op->Attach(op_desc, block0_stmt->op()->scope()); fake_subgraph_op->Attach(op_desc, block0_stmt->op()->scope());
fake_subgraph_op->SetValidPlaces(block0_stmt->op()->valid_places()); fake_subgraph_op->SetValidPlaces(block0_stmt->op()->valid_places());
block0_stmt->SetOp(fake_subgraph_op); block0_stmt->SetOp(fake_subgraph_op);
...@@ -577,10 +577,10 @@ class XPUResNetBlock1Fuser : public FuseBase { ...@@ -577,10 +577,10 @@ class XPUResNetBlock1Fuser : public FuseBase {
auto block1_stmt = matched.at("right_conv1")->stmt(); auto block1_stmt = matched.at("right_conv1")->stmt();
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
// XXX: memleak? auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
auto sub_block_desc = new cpp::BlockDesc(); sub_program_desc->AddBlock<cpp::BlockDesc>();
static_cast<operators::SubgraphOp*>(fake_subgraph_op.get()) static_cast<operators::SubgraphOp*>(fake_subgraph_op.get())
->SetSubBlock(sub_block_desc); ->SetProgramDesc(sub_program_desc);
fake_subgraph_op->Attach(op_desc, block1_stmt->op()->scope()); fake_subgraph_op->Attach(op_desc, block1_stmt->op()->scope());
fake_subgraph_op->SetValidPlaces(block1_stmt->op()->valid_places()); fake_subgraph_op->SetValidPlaces(block1_stmt->op()->valid_places());
block1_stmt->SetOp(fake_subgraph_op); block1_stmt->SetOp(fake_subgraph_op);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/conv_conv_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/conv_conv_fuser.h"
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// initialze fuser params
std::vector<bool> conv_has_bias_cases{true, false};
std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
bool has_arm = false;
for (auto& place : graph->valid_places()) {
if (place.target == TARGET(kARM) && place.precision == PRECISION(kFloat)) {
has_arm = true;
break;
}
}
if (!has_arm) {
return;
}
// only support fp32 fusion
for (auto conv_has_bias0 : conv_has_bias_cases) {
for (auto conv_has_bias1 : conv_has_bias_cases) {
for (auto conv_type0 : conv_type_cases) {
for (auto conv_type1 : conv_type_cases) {
VLOG(4) << "conv_has_bias0:" << conv_has_bias0
<< " conv_type0:" << conv_type0;
VLOG(4) << "conv_has_bias1:" << conv_has_bias1
<< " conv_type1:" << conv_type1;
fusion::ConvConvFuser fuser(
conv_type0, conv_type1, conv_has_bias0, conv_has_bias1);
fuser(graph.get());
}
}
}
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_conv_fuse_pass, paddle::lite::mir::ConvConvFusePass)
.BindTargets({TARGET(kARM)});
...@@ -14,18 +14,19 @@ ...@@ -14,18 +14,19 @@
#pragma once #pragma once
#include "lite/backends/xpu/xpu_header_sitter.h" #include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace mir {
namespace xpu {
struct XPUFreeDeleter { class ConvConvFusePass : public ProgramPass {
void operator()(void* p) const { xpu_free(p); } public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
}; };
} // namespace xpu } // namespace mir
} // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/conv_conv_fuser.h"
#include <memory>
#include <set>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ConvConvFuser::BuildPattern() {
auto* conv_input0 = VarNode("conv_input0")
->assert_is_op_input(conv_type0_, "Input")
->AsInput();
auto* conv_weight0 = VarNode("conv_weight0")
->assert_is_op_input(conv_type0_, "Filter")
->AsInput();
auto* conv0 = OpNode("conv2d0", conv_type0_)->assert_is_op(conv_type0_);
auto* conv_out0 = VarNode("conv_out0")
->assert_is_op_output(conv_type0_, "Output")
->assert_is_op_input(conv_type1_, "Input")
->AsIntermediate();
auto* conv_weight1 = VarNode("conv_weight1")
->assert_is_op_input(conv_type1_, "Filter")
->AsIntermediate();
auto* conv1 = OpNode("conv2d1", conv_type1_)
->assert_is_op(conv_type1_)
->assert_op_attr<int>("groups", 1)
->AsIntermediate();
auto* conv_out1 = VarNode("conv_out1")
->assert_is_op_output(conv_type1_, "Output")
->AsOutput();
if (conv_has_bias0_) {
if (conv_has_bias1_) {
auto* conv_bias0 = VarNode("conv_bias0")
->assert_is_op_input(conv_type0_, "Bias")
->AsIntermediate();
auto* conv_bias1 = VarNode("conv_bias1")
->assert_is_op_input(conv_type1_, "Bias")
->AsInput();
conv0->LinksFrom({conv_input0, conv_weight0, conv_bias0})
.LinksTo({conv_out0});
conv1->LinksFrom({conv_out0, conv_weight1, conv_bias1})
.LinksTo({conv_out1});
} else {
auto* conv_bias0 = VarNode("conv_bias0")
->assert_is_op_input(conv_type0_, "Bias")
->AsIntermediate();
conv0->LinksFrom({conv_input0, conv_weight0, conv_bias0})
.LinksTo({conv_out0});
conv1->LinksFrom({conv_out0, conv_weight1}).LinksTo({conv_out1});
}
} else {
conv0->LinksFrom({conv_input0, conv_weight0}).LinksTo({conv_out0});
if (conv_has_bias1_) {
auto* conv_bias1 = VarNode("conv_bias1")
->assert_is_op_input(conv_type1_, "Bias")
->AsInput();
conv1->LinksFrom({conv_out0, conv_weight1, conv_bias1})
.LinksTo({conv_out1});
} else {
conv1->LinksFrom({conv_out0, conv_weight1}).LinksTo({conv_out1});
}
}
}
void ConvConvFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto conv_instruct = matched.at("conv2d0")->stmt();
auto conv_op_desc = conv_instruct->mutable_op_info();
auto conv = conv_instruct->op();
auto* scope = conv->scope();
auto conv_op_desc1 = matched.at("conv2d1")->stmt()->mutable_op_info();
// conv0
auto weight0_t = scope->FindVar(matched.at("conv_weight0")->arg()->name)
->GetMutable<lite::Tensor>();
// conv1
auto weight1_t = scope->FindVar(matched.at("conv_weight1")->arg()->name)
->GetMutable<lite::Tensor>();
// auto groups0 = conv_op_desc->GetAttr<int>("groups");
auto groups1 = conv_op_desc1->GetAttr<int>("groups");
auto strides1 = conv_op_desc1->GetAttr<std::vector<int>>("strides");
auto paddings1 = conv_op_desc1->GetAttr<std::vector<int>>("paddings");
auto dilations1 = conv_op_desc1->GetAttr<std::vector<int>>("dilations");
bool enable0_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false;
bool enable1_int8 = conv_op_desc1->HasAttr("enable_int8") ? true : false;
int kw = weight1_t->dims()[2];
int kh = weight1_t->dims()[3];
if (!(kw == 1 && kh == 1)) {
return;
}
CHECK_EQ(enable0_int8, enable1_int8) << "The Conv compute type must be same";
CHECK_EQ(groups1, 1) << "The groups of weight1_dim must be 1";
CHECK_EQ(weight0_t->dims()[0], weight1_t->dims()[1])
<< "weight0_dims[0] == weight1_dim[1]";
for (int i = 0; i < strides1.size(); i++) {
CHECK_EQ(strides1[i], 1) << "strides[" << i << "]: " << strides1[i]
<< " must be 1";
}
for (int i = 0; i < paddings1.size(); i++) {
CHECK_EQ(paddings1[i], 0) << "paddings1[" << i << "]: " << paddings1[i]
<< " must be 0";
}
for (int i = 0; i < dilations1.size(); i++) {
CHECK_EQ(dilations1[i], 1) << "dilations1[" << i << "]: " << dilations1[i]
<< " must be 1";
}
// comupte new_wight and new bias
///////////////////////////////////////////////////////////////////////////////
// Compute ConvConvFuser
// Before fusion
//
// conv(x) = conv(x) = kx + z = y
// conv(y) = ay + b
//
// After fusion:
//
// conv(conv(x)) = a(kx + z) + b = akx + az + b
//
// new_weights = ak
// new_bias = az + b
///////////////////////////////////////////////////////////////////////////////
if (enable0_int8) {
LOG(FATAL) << "it doesn't support";
return;
} else {
// compute new conv_weight
Tensor weight_tensor;
auto in_dims = weight0_t->dims();
auto weight_dims = weight1_t->dims();
const float* din = weight0_t->data<float>();
const float* weights = weight1_t->data<float>();
int oc0 = in_dims[0];
int ic = in_dims[1];
int ih = in_dims[2];
int iw = in_dims[3];
int oc = weight_dims[0];
weight_tensor.Resize({oc, ic, ih, iw});
float* dout = weight_tensor.mutable_data<float>();
ComputeNewWeight(dout, din, weights, oc0, ic, ih, iw, oc);
weight0_t->CopyDataFrom(weight_tensor);
}
// compute new conv_bias
if (conv_has_bias0_ && conv_op_desc->HasInput("Bias") &&
conv_op_desc->Input("Bias").size() > 0) {
auto bias_t0 = scope->FindVar(matched.at("conv_bias0")->arg()->name)
->GetMutable<lite::Tensor>();
if (conv_has_bias1_ && conv_op_desc1->HasInput("Bias") &&
conv_op_desc1->Input("Bias").size() > 0) {
auto bias_t1 = scope->FindVar(matched.at("conv_bias1")->arg()->name)
->GetMutable<lite::Tensor>();
Tensor bias;
bias.CopyDataFrom(*bias_t1);
auto bias_data = bias.mutable_data<float>();
ComputeNewBias(bias_data, bias_t0, weight1_t, bias_t1);
bias_t1->CopyDataFrom(bias);
conv_op_desc->SetInput(
"Bias", {matched.at("conv_bias1")->arg()->name}); // conv_bias
IR_NODE_LINK_TO(matched.at("conv_bias1"), matched.at("conv2d0"));
} else {
Tensor bias;
auto weight_dims = weight1_t->dims();
bias.Resize({weight_dims[0]});
auto bias_d = bias.mutable_data<float>();
ComputeNewBias(bias_d, bias_t0, weight1_t, nullptr);
bias_t0->CopyDataFrom(bias);
conv_op_desc->SetInput(
"Bias", {matched.at("conv_bias0")->arg()->name}); // conv_bias
}
} else {
if (conv_has_bias1_ && conv_op_desc1->HasInput("Bias") &&
conv_op_desc1->Input("Bias").size() > 0) {
conv_op_desc->SetInput(
"Bias", {matched.at("conv_bias1")->arg()->name}); // conv_bias
IR_NODE_LINK_TO(matched.at("conv_bias1"), matched.at("conv2d0"));
}
}
conv_op_desc->SetType(conv_type0_);
conv_op_desc->SetInput("Input", {matched.at("conv_input0")->arg()->name});
conv_op_desc->SetInput("Filter", {matched.at("conv_weight0")->arg()->name});
conv_op_desc->SetOutput("Output", {matched.at("conv_out1")->arg()->name});
auto update_conv_desc = *conv_instruct->mutable_op_info();
conv_instruct->ResetOp(update_conv_desc, graph->valid_places());
IR_OP_VAR_LINK(matched.at("conv2d0"), matched.at("conv_out1"));
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ConvConvFuser : public FuseBase {
public:
explicit ConvConvFuser(const std::string& conv_type0,
const std::string& conv_type1,
const bool conv_has_bias0,
const bool conv_has_bias1)
: conv_type0_(conv_type0),
conv_type1_(conv_type1),
conv_has_bias0_(conv_has_bias0),
conv_has_bias1_(conv_has_bias1) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
void ComputeNewWeight(float* dout,
const float* din,
const float* weights,
int oc0,
int ic,
int ih,
int iw,
int oc1) {
// input conv_weight0_t weights conv_weight1_t
// output weight_tensor
// ksize = 1
int in_size = ih * iw;
int in_channel_size = ic * in_size;
// out = w1[j, i, ih, iw] * w2[k, j, kw, kh]
// out_dim = [oc1, ic, kh, kw], din_dim = [oc0, ic, kh, kw]
// weight_dim = [oc1, oc0, kh, kw]
for (int k = 0; k < oc1; k++) {
const float* weights_ptr = weights + k * oc0;
float* out_ptr = dout + k * in_channel_size;
for (int c = 0; c < ic; c++) {
float* out_ptr_channel = out_ptr + c * in_size;
const float* din_ptr = din + c * in_size;
for (int i = 0; i < in_size; i++) {
float sum = 0.f;
for (int j = 0; j < oc0; j++) {
sum += din_ptr[j * in_channel_size] * weights_ptr[j];
}
*out_ptr_channel++ = sum;
}
}
}
}
void ComputeNewBias(float* dout,
Tensor* bias0_tensor,
Tensor* weight_tensor,
Tensor* bias1_tensor) {
// input bias0_tensor weight_tensor bias1_tensor
// output bias_tensor
auto in_dims = bias0_tensor->dims();
auto weight_dims = weight_tensor->dims();
const float* din = bias0_tensor->data<float>();
const float* weights = weight_tensor->data<float>();
int ic = in_dims[0];
int oc = weight_dims[0];
// out_k = b0[num, j, 1, 1] * w2[k, j, 1, 1]
if (bias1_tensor) {
const float* din2 = bias1_tensor->data<float>();
for (int k = 0; k < oc; k++) {
const float* weights_ptr = weights + k * ic;
float sum = 0.f;
for (int j = 0; j < ic; j++) {
sum += din[j] * weights_ptr[j];
}
dout[k] = sum + din2[k];
}
} else {
for (int k = 0; k < oc; k++) {
const float* weights_ptr = weights + k * ic;
float sum = 0.f;
for (int j = 0; j < ic; j++) {
sum += din[j] * weights_ptr[j];
}
dout[k] = sum;
}
}
}
private:
std::string conv_type0_{"conv2d"};
std::string conv_type1_{"conv2d"};
bool conv_has_bias0_{false};
bool conv_has_bias1_{false};
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -175,7 +175,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -175,7 +175,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
for (int i = 0; i < weight_scale_size; i++) { for (int i = 0; i < weight_scale_size; i++) {
weight_scale.push_back(whole_weight_scale); weight_scale.push_back(whole_weight_scale);
} }
// Arm CPU does not support conv2d_transpose
if (quantized_op_type_ != "conv2d_transpose") {
op_desc.SetAttr("enable_int8", true); op_desc.SetAttr("enable_int8", true);
}
op_desc.SetInputScale(weight_name, weight_scale); op_desc.SetInputScale(weight_name, weight_scale);
// change the weight from the float type to int8 type. // change the weight from the float type to int8 type.
...@@ -280,6 +284,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -280,6 +284,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
op_desc.SetInput("X", {quantized_op_input->arg()->name}); op_desc.SetInput("X", {quantized_op_input->arg()->name});
op_desc.SetOutput("Out", {dequant_op_out->arg()->name}); op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
} }
// Arm CPU does not support conv2d_transpose
if (quantized_op_type_ != "conv2d_transpose") { if (quantized_op_type_ != "conv2d_transpose") {
op_desc.SetAttr("enable_int8", true); op_desc.SetAttr("enable_int8", true);
} }
......
...@@ -39,6 +39,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -39,6 +39,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
nodes_in_order = graph->StmtTopologicalOrder(); nodes_in_order = graph->StmtTopologicalOrder();
} }
insts_.emplace_back();
for (auto& item : nodes_in_order) { for (auto& item : nodes_in_order) {
if (item->IsStmt()) { if (item->IsStmt()) {
auto& stmt = item->AsStmt(); auto& stmt = item->AsStmt();
...@@ -57,7 +58,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -57,7 +58,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
.SetSyncStreams(stmt.sync_streams_); .SetSyncStreams(stmt.sync_streams_);
} }
#endif #endif
insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front())); insts_.back().emplace_back(stmt.op(), std::move(stmt.kernels().front()));
} }
} }
} }
......
...@@ -42,7 +42,7 @@ class GenerateProgramPass : public ProgramPass { ...@@ -42,7 +42,7 @@ class GenerateProgramPass : public ProgramPass {
} }
private: private:
std::vector<Instruction> insts_; std::vector<std::vector<Instruction>> insts_;
}; };
} // namespace mir } // namespace mir
......
...@@ -284,13 +284,19 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph, ...@@ -284,13 +284,19 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
head_node->AsArg().name, head_node->AsArg().name,
cur_node->AsArg().name); cur_node->AsArg().name);
// for subgraph op, modify the BlockDesc // for subgraph op, modify the BlockDesc
auto* sub_block_desc = dynamic_cast<paddle::lite::operators::SubgraphOp*>( auto sub_program_desc = dynamic_cast<paddle::lite::operators::SubgraphOp*>(
inst_node->AsStmt().op().get()) inst_node->AsStmt().op().get())
->GetSubBlock(); ->GetProgramDesc();
for (size_t i = 0; i < sub_block_desc->OpsSize(); ++i) { CHECK(sub_program_desc);
auto* sub_block_op_desc = sub_block_desc->GetOp<cpp::OpDesc>(i); int sub_block_idx =
UpdateInputTo( inst_node->AsStmt().op()->op_info()->GetAttr<int32_t>("sub_block");
sub_block_op_desc, head_node->AsArg().name, cur_node->AsArg().name); auto* sub_block_desc =
sub_program_desc->GetBlock<cpp::BlockDesc>(sub_block_idx);
for (size_t sub_op_idx = 0; sub_op_idx < sub_block_desc->OpsSize();
++sub_op_idx) {
auto* sub_op_desc = const_cast<cpp::OpDesc*>(
sub_block_desc->GetOp<cpp::OpDesc>(sub_op_idx));
UpdateInputTo(sub_op_desc, head_node->AsArg().name, cur_node->AsArg().name);
} }
// recreate the op // recreate the op
...@@ -444,21 +450,27 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph, ...@@ -444,21 +450,27 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph,
tail_node->AsArg().name, tail_node->AsArg().name,
cur_node->AsArg().name); cur_node->AsArg().name);
// for subgraph op, modify the BlockDesc // for subgraph op, modify the BlockDesc
auto* sub_block_desc = dynamic_cast<paddle::lite::operators::SubgraphOp*>( auto sub_program_desc = dynamic_cast<paddle::lite::operators::SubgraphOp*>(
inst_node->AsStmt().op().get()) inst_node->AsStmt().op().get())
->GetSubBlock(); ->GetProgramDesc();
for (size_t i = 0; i < sub_block_desc->OpsSize(); ++i) { CHECK(sub_program_desc);
auto* sub_block_op_desc = sub_block_desc->GetOp<cpp::OpDesc>(i); int sub_block_idx =
inst_node->AsStmt().op()->op_info()->GetAttr<int32_t>("sub_block");
auto* sub_block_desc =
sub_program_desc->GetBlock<cpp::BlockDesc>(sub_block_idx);
for (size_t sub_op_idx = 0; sub_op_idx < sub_block_desc->OpsSize();
++sub_op_idx) {
auto* sub_op_desc = const_cast<cpp::OpDesc*>(
sub_block_desc->GetOp<cpp::OpDesc>(sub_op_idx));
UpdateOutputTo( UpdateOutputTo(
sub_block_op_desc, tail_node->AsArg().name, cur_node->AsArg().name); sub_op_desc, tail_node->AsArg().name, cur_node->AsArg().name);
/* graph like this /* graph like this
* subgraph_op_0 * subgraph_op_0
* / \ * / \
* / \ * / \
* subgraph_op_1 host_op * subgraph_op_1 host_op
*/ */
UpdateInputTo( UpdateInputTo(sub_op_desc, tail_node->AsArg().name, cur_node->AsArg().name);
sub_block_op_desc, tail_node->AsArg().name, cur_node->AsArg().name);
} }
// recreate the op // recreate the op
...@@ -482,15 +494,22 @@ void MLUPostprocessPass::RecreateOp(Node* inst_node, SSAGraph* graph) { ...@@ -482,15 +494,22 @@ void MLUPostprocessPass::RecreateOp(Node* inst_node, SSAGraph* graph) {
} }
} }
bool MLUPostprocessPass::IsFirstConvInSubgraph(Node* arg_node, Node* inst) { bool MLUPostprocessPass::IsFirstConvInSubgraph(Node* arg_node,
auto* block_desc = Node* inst_node) {
static_cast<operators::SubgraphOp*>(inst->AsStmt().op().get()) auto sub_program_desc = dynamic_cast<paddle::lite::operators::SubgraphOp*>(
->GetSubBlock(); inst_node->AsStmt().op().get())
for (size_t op_idx = 0; op_idx < block_desc->OpsSize(); op_idx++) { ->GetProgramDesc();
auto op_desc = block_desc->GetOp<cpp::OpDesc>(op_idx); CHECK(sub_program_desc);
CHECK(op_desc); int sub_block_idx =
if (op_desc->Type() == "conv2d") { inst_node->AsStmt().op()->op_info()->GetAttr<int32_t>("sub_block");
for (auto& names : op_desc->inputs()) { auto* sub_block_desc =
sub_program_desc->GetBlock<cpp::BlockDesc>(sub_block_idx);
for (size_t sub_op_idx = 0; sub_op_idx < sub_block_desc->OpsSize();
sub_op_idx++) {
auto sub_op_desc = sub_block_desc->GetOp<cpp::OpDesc>(sub_op_idx);
CHECK(sub_op_desc);
if (sub_op_desc->Type() == "conv2d") {
for (auto& names : sub_op_desc->inputs()) {
if (std::find(names.second.begin(), if (std::find(names.second.begin(),
names.second.end(), names.second.end(),
arg_node->AsArg().name) != names.second.end()) { arg_node->AsArg().name) != names.second.end()) {
...@@ -746,19 +765,23 @@ std::pair<bool, std::string> CheckOutputAndInsert( ...@@ -746,19 +765,23 @@ std::pair<bool, std::string> CheckOutputAndInsert(
// insert cast op on mlu, to avoid cast on cpu // insert cast op on mlu, to avoid cast on cpu
void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node, void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
const Type* subgraph_type) { const Type* subgraph_type) {
auto subgraph_op = subgraph_node->AsStmt().op(); CHECK_EQ(subgraph_node->AsStmt().op()->Type(), "subgraph");
CHECK_EQ(subgraph_op->Type(), "subgraph"); auto subgraph_op =
auto op = dynamic_cast<operators::SubgraphOp*>(subgraph_op.get()); dynamic_cast<operators::SubgraphOp*>(subgraph_node->AsStmt().op().get());
CHECK(op); CHECK(subgraph_op);
auto block_desc = op->GetSubBlock(); auto sub_program_desc = subgraph_op->GetProgramDesc();
CHECK(sub_program_desc);
int sub_block_idx = subgraph_op->op_info()->GetAttr<int32_t>("sub_block");
auto* sub_block_desc = const_cast<cpp::BlockDesc*>(
sub_program_desc->GetBlock<cpp::BlockDesc>(sub_block_idx));
// create a new block desc to keep op sequence correct // create a new block desc to keep op sequence correct
cpp::BlockDesc* new_block_desc = new cpp::BlockDesc(); cpp::BlockDesc new_block_desc;
new_block_desc->ClearOps(); new_block_desc.ClearOps();
new_block_desc->ClearVars(); new_block_desc.ClearVars();
new_block_desc->SetIdx(block_desc->Idx()); new_block_desc.SetIdx(sub_block_desc->Idx());
new_block_desc->SetParentIdx(block_desc->ParentIdx()); new_block_desc.SetParentIdx(sub_block_desc->ParentIdx());
new_block_desc->SetForwardBlockIdx(block_desc->ForwardBlockIdx()); new_block_desc.SetForwardBlockIdx(sub_block_desc->ForwardBlockIdx());
// find all IO that is not weight or persist // find all IO that is not weight or persist
std::list<std::string> i_names, o_names; std::list<std::string> i_names, o_names;
...@@ -769,8 +792,8 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node, ...@@ -769,8 +792,8 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
auto input_name = input->AsArg().name; auto input_name = input->AsArg().name;
if (!(input->AsArg().is_weight || input->AsArg().is_persist)) { if (!(input->AsArg().is_weight || input->AsArg().is_persist)) {
i_names.emplace_back(input_name); i_names.emplace_back(input_name);
auto ret = CheckInputAndInsert(op->scope(), auto ret = CheckInputAndInsert(subgraph_op->scope(),
new_block_desc, &new_block_desc,
input_name, input_name,
input->AsArg().type, input->AsArg().type,
subgraph_type); subgraph_type);
...@@ -783,8 +806,8 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node, ...@@ -783,8 +806,8 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
auto output_name = output->AsArg().name; auto output_name = output->AsArg().name;
if (!(output->AsArg().is_weight || output->AsArg().is_persist)) { if (!(output->AsArg().is_weight || output->AsArg().is_persist)) {
o_names.emplace_back(output_name); o_names.emplace_back(output_name);
auto ret = CheckOutputAndInsert(op->scope(), auto ret = CheckOutputAndInsert(subgraph_op->scope(),
block_desc, sub_block_desc,
output_name, output_name,
output->AsArg().type, output->AsArg().type,
subgraph_type); subgraph_type);
...@@ -795,46 +818,48 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node, ...@@ -795,46 +818,48 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
} }
// update input and output // update input and output
for (size_t op_idx = 0; op_idx < block_desc->OpsSize(); ++op_idx) { for (size_t sub_op_idx = 0; sub_op_idx < sub_block_desc->OpsSize();
auto desc = block_desc->GetOp<cpp::OpDesc>(op_idx); ++sub_op_idx) {
auto new_desc = new_block_desc->AddOp<cpp::OpDesc>(); auto sub_op_desc = sub_block_desc->GetOp<cpp::OpDesc>(sub_op_idx);
*new_desc = *desc; auto new_op_desc = new_block_desc.AddOp<cpp::OpDesc>();
*new_op_desc = *sub_op_desc;
if (desc->Type() != "layout" && desc->Type() != "cast") {
auto op_input_args = new_desc->InputArgumentNames(); if (sub_op_desc->Type() != "layout" && sub_op_desc->Type() != "cast") {
auto op_input_args = new_op_desc->InputArgumentNames();
for (auto& input_arg : op_input_args) { for (auto& input_arg : op_input_args) {
auto op_input = new_desc->Input(input_arg); auto op_input = new_op_desc->Input(input_arg);
for (auto& it : i_names) { for (auto& it : i_names) {
auto index = std::find(op_input.begin(), op_input.end(), it); auto index = std::find(op_input.begin(), op_input.end(), it);
if (index != op_input.end() && if (index != op_input.end() &&
node_replace.find(it) != node_replace.end()) { node_replace.find(it) != node_replace.end()) {
index = op_input.erase(index); index = op_input.erase(index);
op_input.emplace(index, node_replace.at(it)); op_input.emplace(index, node_replace.at(it));
VLOG(4) << new_desc->Type() << "] change input from " << it VLOG(4) << new_op_desc->Type() << "] change input from " << it
<< " to " << node_replace.at(it); << " to " << node_replace.at(it);
} }
} }
new_desc->SetInput(input_arg, op_input); new_op_desc->SetInput(input_arg, op_input);
} }
auto op_output_args = new_desc->OutputArgumentNames(); auto op_output_args = new_op_desc->OutputArgumentNames();
for (auto& output_arg : op_output_args) { for (auto& output_arg : op_output_args) {
auto op_output = new_desc->Output(output_arg); auto op_output = new_op_desc->Output(output_arg);
for (auto& it : o_names) { for (auto& it : o_names) {
auto index = std::find(op_output.begin(), op_output.end(), it); auto index = std::find(op_output.begin(), op_output.end(), it);
if (index != op_output.end() && if (index != op_output.end() &&
node_replace.find(it) != node_replace.end()) { node_replace.find(it) != node_replace.end()) {
index = op_output.erase(index); index = op_output.erase(index);
op_output.emplace(index, node_replace.at(it)); op_output.emplace(index, node_replace.at(it));
VLOG(4) << new_desc->Type() << "] change output from " << it VLOG(4) << new_op_desc->Type() << "] change output from " << it
<< " to " << node_replace.at(it); << " to " << node_replace.at(it);
} }
} }
new_desc->SetOutput(output_arg, op_output); new_op_desc->SetOutput(output_arg, op_output);
} }
} }
} }
op->SetSubBlock(new_block_desc);
*sub_block_desc = new_block_desc;
} }
void ModifyValidPlaces(SSAGraph* graph, bool use_mlu_cast) { void ModifyValidPlaces(SSAGraph* graph, bool use_mlu_cast) {
......
...@@ -153,60 +153,61 @@ Node *SSAGraph::GraphCreateInstructNode( ...@@ -153,60 +153,61 @@ Node *SSAGraph::GraphCreateInstructNode(
} }
void SSAGraph::Build(const Program &program, void SSAGraph::Build(const Program &program,
const std::vector<Place> &valid_places) { const std::vector<Place> &valid_places,
int block_idx) {
CHECK(node_storage_.empty()); CHECK(node_storage_.empty());
auto weights_name = program.weights(); auto weights = program.weights();
auto is_weights = [&](const std::string &name) -> bool { auto is_weight = [&](const std::string &name) -> bool {
auto it = std::find(weights_name.begin(), weights_name.end(), name); auto it = std::find(weights.begin(), weights.end(), name);
if (it == weights_name.end()) return false; if (it == weights.end()) return false;
return true; return true;
}; };
std::map<std::string, PrecisionType> var_types = program.var_data_type(); auto var_type_map = program.var_type_map();
std::map<std::string, mir::Node *> arg_update_node_map;
std::map<std::string, mir::Node *> arg_update_node_map_; for (auto &op : program.ops(block_idx)) {
for (auto &op : program.ops()) {
VLOG(3) << op->op_info()->Type(); VLOG(3) << op->op_info()->Type();
auto *op_node = GraphCreateInstructNode(op, valid_places); auto *op_node = GraphCreateInstructNode(op, valid_places);
for (const std::string &name : op->op_info()->input_names()) { auto *op_info = op->op_info();
const auto &op_type = op_info->Type();
for (const auto &var_name : op_info->input_names()) {
mir::Node *arg_node = nullptr; mir::Node *arg_node = nullptr;
if (arg_update_node_map_.count(name)) { if (arg_update_node_map.count(var_name)) {
arg_node = arg_update_node_map_.at(name); arg_node = arg_update_node_map.at(var_name);
} else { } else {
node_storage_.emplace_back(); node_storage_.emplace_back();
arg_node = &node_storage_.back(); arg_node = &node_storage_.back();
arg_node->AsArg(name, node_storage_.size() - 1); arg_node->AsArg(var_name, node_storage_.size() - 1);
arg_update_node_map_[name] = arg_node; arg_update_node_map[var_name] = arg_node;
} }
if (var_types.count(name)) { if (var_type_map.count(var_name)) {
if (!arg_node->arg()->type) { if (!arg_node->arg()->type) {
arg_node->arg()->type = LiteType::GetTensorTy( arg_node->arg()->type = var_type_map[var_name];
TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
} }
// Store the original data type of the output tensors for // Store the original data type of the output tensors for
// type_precision_cast_pass, to keep the consistency between the // type_precision_cast_pass, to keep the consistency between the
// output types of original graph and optimized graph's // output types of original graph and optimized graph's
if (op->op_info()->Type() == "fetch") { if (op_type == "fetch") {
op->mutable_op_info()->SetAttr<int>( op->mutable_op_info()->SetAttr<int>(
"data_type", static_cast<int>(var_types[name])); "data_type",
static_cast<int>(var_type_map[var_name]->precision()));
} }
} }
if (is_weights(name)) arg_node->AsArg().is_weight = true; if (is_weight(var_name)) arg_node->AsArg().is_weight = true;
CHECK(arg_node->IsRoleSet()); CHECK(arg_node->IsRoleSet());
DirectedLink(arg_node, op_node); DirectedLink(arg_node, op_node);
} }
for (const std::string &name : op->op_info()->output_names()) { for (const auto &var_name : op->op_info()->output_names()) {
node_storage_.emplace_back(); node_storage_.emplace_back();
auto *arg_node = &node_storage_.back(); auto *arg_node = &node_storage_.back();
arg_node->AsArg(name, node_storage_.size() - 1); arg_node->AsArg(var_name, node_storage_.size() - 1);
arg_update_node_map_[name] = arg_node; arg_update_node_map[var_name] = arg_node;
if (var_types.count(name) && !arg_node->arg()->type) { if (var_type_map.count(var_name) && !arg_node->arg()->type) {
arg_node->arg()->type = LiteType::GetTensorTy( arg_node->arg()->type = var_type_map[var_name];
TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
} }
if (is_weights(name)) arg_node->AsArg().is_weight = true; if (is_weight(var_name)) arg_node->AsArg().is_weight = true;
CHECK(arg_node->IsRoleSet()); CHECK(arg_node->IsRoleSet());
DirectedLink(op_node, arg_node); DirectedLink(op_node, arg_node);
} }
......
...@@ -35,9 +35,13 @@ class GraphBase {}; ...@@ -35,9 +35,13 @@ class GraphBase {};
class SSAGraph : GraphBase { class SSAGraph : GraphBase {
public: public:
// @param program: the op program // @param program: the target program with vars and ops
// @param valid_places: the valid places user set for the system. // @param valid_places: the valid places user set for the system.
void Build(const Program &program, const std::vector<Place> &valid_places); // @param block_idx: the block index in the target program, default is 0(main
// block)
void Build(const Program &program,
const std::vector<Place> &valid_places,
int block_idx = kRootBlockIdx);
void RemoveNode(const mir::Node *node); void RemoveNode(const mir::Node *node);
std::vector<mir::Node *> StmtTopologicalOrder(); std::vector<mir::Node *> StmtTopologicalOrder();
......
...@@ -411,16 +411,17 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, ...@@ -411,16 +411,17 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
cpp::OpDesc subgraph_op_desc; cpp::OpDesc subgraph_op_desc;
subgraph_op_desc.SetType("subgraph"); subgraph_op_desc.SetType("subgraph");
// Create a new sub block desc for storing all of Ops and Vars of the target // Create a program desc and a block desc for storing all of Ops and Vars of
// subgraph and sub_block_idx is set as a attribute of subgraph op, // the target subgraph and sub_block_idx is set as a attribute of subgraph op,
// sub_block_idx < 0 means it's a new subgraph op // sub_block_idx = 0 means it's a new subgraph op
int sub_block_idx = -(subgraph_idx + 1); auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
auto sub_block_desc = new cpp::BlockDesc(); int sub_block_idx = 0;
auto sub_block_desc = sub_program_desc->AddBlock<cpp::BlockDesc>();
sub_block_desc->ClearOps(); sub_block_desc->ClearOps();
sub_block_desc->ClearVars(); sub_block_desc->ClearVars();
for (auto &op_node : subgraph_nodes) { for (auto &op_node : subgraph_nodes) {
auto sub_block_op_desc = sub_block_desc->AddOp<cpp::OpDesc>(); auto sub_op_desc = sub_block_desc->AddOp<cpp::OpDesc>();
*sub_block_op_desc = *op_node->AsStmt().op_info(); *sub_op_desc = *op_node->AsStmt().op_info();
} }
subgraph_op_desc.SetAttr<int32_t>("sub_block", sub_block_idx); subgraph_op_desc.SetAttr<int32_t>("sub_block", sub_block_idx);
...@@ -437,13 +438,13 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, ...@@ -437,13 +438,13 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
&local_var_nodes, &local_var_nodes,
&unused_var_nodes); &unused_var_nodes);
// A simplified model without the original weight/local/unused nodes on the // A simplified model without the original weight/local/unused nodes on the
// subgraph ops will be saved only if 'SUBGRAPH_DISABLE_ONLINE_MODE' is set to // subgraph ops will be saved only if 'SUBGRAPH_ONLINE_MODE' is set to
// true and Predictor->Run(...), Predictor->Save(...) is called. // true(default) and Predictor->Run(...), Predictor->Save(...) is called.
std::set<Node *> input_var_nodes(idata_var_nodes.begin(), std::set<Node *> input_var_nodes(idata_var_nodes.begin(),
idata_var_nodes.end()); idata_var_nodes.end());
std::set<Node *> output_var_nodes(odata_var_nodes.begin(), std::set<Node *> output_var_nodes(odata_var_nodes.begin(),
odata_var_nodes.end()); odata_var_nodes.end());
if (!GetBoolFromEnv(SUBGRAPH_DISABLE_ONLINE_MODE)) { if (GetBoolFromEnv(SUBGRAPH_ONLINE_MODE, true)) {
input_var_nodes.insert(weight_var_nodes.begin(), weight_var_nodes.end()); input_var_nodes.insert(weight_var_nodes.begin(), weight_var_nodes.end());
output_var_nodes.insert(local_var_nodes.begin(), local_var_nodes.end()); output_var_nodes.insert(local_var_nodes.begin(), local_var_nodes.end());
output_var_nodes.insert(unused_var_nodes.begin(), unused_var_nodes.end()); output_var_nodes.insert(unused_var_nodes.begin(), unused_var_nodes.end());
...@@ -476,7 +477,7 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, ...@@ -476,7 +477,7 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
subgraph_op_desc.SetOutput("Outputs", output_var_names); subgraph_op_desc.SetOutput("Outputs", output_var_names);
auto subgraph_op = LiteOpRegistry::Global().Create("subgraph"); auto subgraph_op = LiteOpRegistry::Global().Create("subgraph");
static_cast<operators::SubgraphOp *>(subgraph_op.get()) static_cast<operators::SubgraphOp *>(subgraph_op.get())
->SetSubBlock(sub_block_desc); ->SetProgramDesc(sub_program_desc);
auto any_op = (*subgraph_nodes.begin())->AsStmt().op(); auto any_op = (*subgraph_nodes.begin())->AsStmt().op();
subgraph_op->Attach(subgraph_op_desc, any_op->scope()); subgraph_op->Attach(subgraph_op_desc, any_op->scope());
......
...@@ -141,12 +141,11 @@ std::vector<std::string> AddFetchDesc( ...@@ -141,12 +141,11 @@ std::vector<std::string> AddFetchDesc(
} }
TEST(Subgraph, detect_simple_model) { TEST(Subgraph, detect_simple_model) {
cpp::ProgramDesc program_desc; auto program_desc = std::make_shared<cpp::ProgramDesc>();
std::vector<Place> valid_places{{TARGET(kHost), PRECISION(kFloat)}}; std::vector<Place> valid_places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
// Build a simple network // Build a simple network
program_desc.ClearBlocks(); auto* block_desc = program_desc->AddBlock<cpp::BlockDesc>();
auto* block_desc = program_desc.AddBlock<cpp::BlockDesc>();
block_desc->ClearOps(); block_desc->ClearOps();
block_desc->ClearVars(); block_desc->ClearVars();
auto* var_desc = block_desc->AddVar<cpp::VarDesc>(); auto* var_desc = block_desc->AddVar<cpp::VarDesc>();
...@@ -181,13 +180,13 @@ TEST(Subgraph, detect_custom_model) { ...@@ -181,13 +180,13 @@ TEST(Subgraph, detect_custom_model) {
"the path of model files."; "the path of model files.";
return; return;
} }
cpp::ProgramDesc program_desc; auto program_desc = std::make_shared<cpp::ProgramDesc>();
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
LoadModelPb(FLAGS_model_dir, LoadModelPb(FLAGS_model_dir,
FLAGS_model_file, FLAGS_model_file,
FLAGS_params_file, FLAGS_params_file,
scope.get(), scope.get(),
&program_desc, program_desc.get(),
!FLAGS_model_file.empty() && !FLAGS_params_file.empty(), !FLAGS_model_file.empty() && !FLAGS_params_file.empty(),
false); false);
std::vector<Place> valid_places({ std::vector<Place> valid_places({
......
...@@ -36,14 +36,20 @@ void UpdateInputsForSubgraph(OpLite* op, ...@@ -36,14 +36,20 @@ void UpdateInputsForSubgraph(OpLite* op,
op_desc->GetAttr<std::vector<std::string>>("input_data_names"); op_desc->GetAttr<std::vector<std::string>>("input_data_names");
std::replace(input_data_names.begin(), input_data_names.end(), from, to); std::replace(input_data_names.begin(), input_data_names.end(), from, to);
op_desc->SetAttr("input_data_names", input_data_names); op_desc->SetAttr("input_data_names", input_data_names);
auto* subblock_desc = static_cast<operators::SubgraphOp*>(op)->GetSubBlock(); auto sub_program_desc =
CHECK(subblock_desc); static_cast<operators::SubgraphOp*>(op)->GetProgramDesc();
for (size_t i = 0; i < subblock_desc->OpsSize(); i++) { CHECK(sub_program_desc);
auto* subblock_op_desc = subblock_desc->GetOp<cpp::OpDesc>(i); int sub_block_idx = op_desc->GetAttr<int32_t>("sub_block");
for (auto& subblock_op_input : *subblock_op_desc->mutable_inputs()) { auto sub_block_desc =
for (auto& subblock_var_name : subblock_op_input.second) { sub_program_desc->GetBlock<cpp::BlockDesc>(sub_block_idx);
if (subblock_var_name == from) { for (size_t sub_op_idx = 0; sub_op_idx < sub_block_desc->OpsSize();
subblock_var_name = to; sub_op_idx++) {
auto sub_op_desc = const_cast<cpp::OpDesc*>(
sub_block_desc->GetOp<cpp::OpDesc>(sub_op_idx));
for (auto& sub_op_input : *sub_op_desc->mutable_inputs()) {
for (auto& sub_var_name : sub_op_input.second) {
if (sub_var_name == from) {
sub_var_name = to;
} }
} }
} }
......
...@@ -59,25 +59,46 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -59,25 +59,46 @@ class VariablePlaceInferencePass : public DebugPass {
} }
// Set the type of the weight // Set the type of the weight
void SetWeightType(Node* w, void SetWeightType(Node* weight_node,
const LiteType& type, const LiteType& type,
const std::map<std::string, bool>& lite_with_targets) { const std::map<std::string, bool>& with_targets) {
VLOG(4) << "type.precision():" << PrecisionRepr(type.precision()); VLOG(4) << "type.precision():" << PrecisionRepr(type.precision());
if (lite_with_targets.at("kFPGA")) { if (with_targets.at("kFPGA")) {
w->AsArg().type = LiteType::GetTensorTy( weight_node->AsArg().type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
} else if (lite_with_targets.at("kOpenCL")) { } else if (with_targets.at("kOpenCL")) {
w->AsArg().type = LiteType::GetTensorTy( weight_node->AsArg().type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
} else if (lite_with_targets.at("kCUDA")) { } else if (with_targets.at("kCUDA")) {
w->AsArg().type = LiteType::GetTensorTy( weight_node->AsArg().type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
} else { } else {
w->AsArg().type = LiteType::GetTensorTy( weight_node->AsArg().type = LiteType::GetTensorTy(
TARGET(kHost), type.precision(), DATALAYOUT(kNCHW)); TARGET(kHost), type.precision(), DATALAYOUT(kNCHW));
} }
} }
// Update a's kUnk fields from b's fields.
void UpdateTypeFrom(const Type** a, const Type* b) {
auto target = (*a)->target();
auto precision = (*a)->precision();
auto layout = (*a)->layout();
if (target == TARGET(kUnk)) {
target = b->target();
}
if (precision == PRECISION(kUnk)) {
precision = b->precision();
}
if (layout == DATALAYOUT(kUnk)) {
layout = b->layout();
}
if ((*a)->IsTensor() && b->IsTensor()) {
*a = LiteType::GetTensorTy(target, precision, layout);
} else if ((*a)->IsTensorList() && b->IsTensorList()) {
*a = LiteType::GetTensorListTy(target, precision, layout);
}
}
void InferenceArgumentPlace(SSAGraph* graph) { void InferenceArgumentPlace(SSAGraph* graph) {
auto& valid_places = graph->valid_places(); auto& valid_places = graph->valid_places();
auto valid_places_has_target = [&](TargetType t) -> bool { auto valid_places_has_target = [&](TargetType t) -> bool {
...@@ -88,122 +109,90 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -88,122 +109,90 @@ class VariablePlaceInferencePass : public DebugPass {
} }
return false; return false;
}; };
std::map<std::string, bool> lite_with_targets{ std::map<std::string, bool> with_targets{
{"kOpenCL", valid_places_has_target(TARGET(kOpenCL))}, {"kOpenCL", valid_places_has_target(TARGET(kOpenCL))},
{"kCUDA", valid_places_has_target(TARGET(kCUDA))}, {"kCUDA", valid_places_has_target(TARGET(kCUDA))},
{"kFPGA", valid_places_has_target(TARGET(kFPGA))}}; {"kFPGA", valid_places_has_target(TARGET(kFPGA))}};
VLOG(4) << "lite_with_targets['kOpenCL']:" << lite_with_targets["kOpenCL"]; VLOG(4) << "with_targets['kOpenCL']:" << with_targets["kOpenCL"];
VLOG(4) << "lite_with_targets['kFPGA']:" << lite_with_targets["kFPGA"]; VLOG(4) << "with_targets['kFPGA']:" << with_targets["kFPGA"];
VLOG(3) << "param-type-registry:\n" << ParamTypeRegistry::Global(); VLOG(3) << "param-type-registry:\n" << ParamTypeRegistry::Global();
for (auto& x : graph->StmtTopologicalOrder()) { for (auto& node : graph->StmtTopologicalOrder()) {
auto& inst = x->AsStmt(); auto& inst = node->AsStmt();
const auto* op_info = inst.op_info();
const auto& op_type = op_info->Type();
auto& kernel = inst.picked_kernel();
// The IoCopyOp is a tool operator, it won't support the type inference. // The IoCopyOp is a tool operator, it won't support the type inference.
// in fpga, we has io_copy+cali+layout tool ops, so we need type inference // in fpga, we has io_copy+cali+layout tool ops, so we need type inference
// for // for tool operator
// tool operator if ((!with_targets["kFPGA"]) && (!with_targets["kOpenCL"])) {
if ((!lite_with_targets["kFPGA"]) && (!lite_with_targets["kOpenCL"])) { VLOG(3) << "skip 'io_copy' if target is FPGA and OpenCL";
VLOG(3) << "inst.op_type() == 'io_copy', continue"; if (op_type == "io_copy") continue;
if (inst.op_type() == "io_copy") continue; }
}
// deal with inputs
VLOG(4) << "Infering op " << inst.op_info()->Repr();
// TODO(zhaolong): Add check if the node's name in op's arguments.
auto get_argname = [&](
const std::string& node_name,
const std::map<std::string, std::vector<std::string>>& argname_map)
-> std::string {
for (auto& ele : argname_map) {
auto it =
std::find(ele.second.begin(), ele.second.end(), node_name);
if (it != ele.second.end()) return ele.first;
}
return "";
};
for (auto* x_in : x->inlinks) { // Infering the input and output variable's place according to the
std::string node_name = x_in->AsArg().name; // declaration of I/O arguments of the picked kernel of the op
std::string arg_name = get_argname(node_name, inst.op_info()->inputs()); VLOG(4) << "Op " << op_info->Repr();
CHECK(arg_name.size() > 0) << "can not found op arguments for node " for (auto* in_node : node->inlinks) {
<< node_name; auto& var = in_node->AsArg();
VLOG(4) << "-- input arg_name:" << arg_name << " " const auto& var_name = var.name;
<< "-- node name:" << node_name; auto* var_type = &var.type;
auto type = inst.picked_kernel().GetInputDeclType(arg_name); std::string arg_name;
if (!x_in->AsArg().type) { CHECK(op_info->GetInputArgname(var_name, &arg_name))
VLOG(4) << "set type " << *type << " " << x_in->AsArg().name; << "Can not find the input argument for var " << var_name;
if (x_in->AsArg().is_weight) { VLOG(4) << " - input arg name:" << arg_name << " var name:" << var_name;
SetWeightType(x_in, *type, lite_with_targets); const auto* decl_type = kernel.GetInputDeclType(arg_name);
if (!(*var_type)) {
VLOG(4) << "set type " << *decl_type << " " << var_name;
if (var.is_weight) {
SetWeightType(in_node, *decl_type, with_targets);
} else { } else {
x_in->AsArg().type = type; *var_type = decl_type;
} }
} else if (x_in->AsArg().type->target() == TARGET(kUnk) && } else if (!(*var_type)->place().is_valid()) {
x_in->AsArg().type->precision() != PRECISION(kUnk) &&
x_in->AsArg().type->layout() == DATALAYOUT(kUnk)) {
// If is quantization, infer the Int8 type. // If is quantization, infer the Int8 type.
if (type->precision() == PRECISION(kInt8)) { if (decl_type->precision() == PRECISION(kInt8)) {
x_in->AsArg().type = type; *var_type = decl_type;
} else { } else {
PrecisionType tmp_ptype = x_in->AsArg().type->precision(); UpdateTypeFrom(var_type, decl_type);
x_in->AsArg().type = LiteType::GetTensorTy( }
type->target(), tmp_ptype, type->layout()); }
} }
} for (auto* out_node : node->outlinks) {
} auto& var = out_node->AsArg();
const auto& var_name = var.name;
VLOG(4) << "inst " << inst.op_info()->Repr(); auto* var_type = &var.type;
for (auto* x_out : x->outlinks) { std::string arg_name;
std::string node_name = x_out->AsArg().name; CHECK(op_info->GetOutputArgname(var_name, &arg_name))
std::string arg_name = << "Can not find the output argument for var " << var_name;
get_argname(node_name, inst.op_info()->outputs()); VLOG(4) << " - output arg name:" << arg_name
CHECK(arg_name.size() > 0) << "can not found op arguments for node " << " var name:" << var_name;
<< node_name << " in Inst " const auto* decl_type = kernel.GetOutputDeclType(arg_name);
<< inst.op_type(); if (!(*var_type)) {
VLOG(4) << "-- output arg_name " << arg_name; VLOG(4) << "set type " << *decl_type << " " << var_name;
auto type = inst.picked_kernel().GetOutputDeclType(arg_name); if (var.is_weight) {
if (!x_out->AsArg().type) { SetWeightType(out_node, *decl_type, with_targets);
VLOG(4) << "set type " << *type << " " << x_out->AsArg().name;
if (x_out->AsArg().is_weight) {
SetWeightType(x_out, *type, lite_with_targets);
} else { } else {
x_out->AsArg().type = type; *var_type = decl_type;
} }
} else if (x_out->AsArg().type->target() == TARGET(kUnk) && } else if (!(*var_type)->place().is_valid()) {
x_out->AsArg().type->precision() != PRECISION(kUnk) &&
x_out->AsArg().type->layout() == DATALAYOUT(kUnk)) {
// If is quantization, infer the Int8 type. // If is quantization, infer the Int8 type.
if (type->precision() == PRECISION(kInt8)) { if (decl_type->precision() == PRECISION(kInt8) ||
x_out->AsArg().type = type; (decl_type->precision() == PRECISION(kFP16) &&
} else if (type->precision() == PRECISION(kFP16) && decl_type->target() != TARGET(kOpenCL))) {
type->target() != TARGET(kOpenCL)) { *var_type = decl_type;
x_out->AsArg().type = type;
} else { } else {
PrecisionType tmp_ptype = x_out->AsArg().type->precision(); UpdateTypeFrom(var_type, decl_type);
x_out->AsArg().type = LiteType::GetTensorTy(
type->target(), tmp_ptype, type->layout());
}
} }
} }
} }
} }
// Update me's kUnk fields by other's fields.
void UpdatePlace(Place* me, const Place& other) {
CHECK(other.is_valid());
if (me->target == TARGET(kUnk)) {
me->target = other.target;
}
if (me->precision == PRECISION(kUnk)) {
me->precision = other.precision;
}
if (me->layout == DATALAYOUT(kUnk)) {
me->layout = other.layout;
}
} }
private: private:
// The default target for arguments, e.g. load weights to CPU memory for CUDA // The default target for arguments, e.g. load weights to CPU memory for
// computation by default. // CUDA computation by default.
TargetType argument_default_target_{TARGET(kHost)}; TargetType argument_default_target_{TARGET(kHost)};
}; };
......
...@@ -99,7 +99,7 @@ class OpLite : public Registry { ...@@ -99,7 +99,7 @@ class OpLite : public Registry {
std::vector<std::unique_ptr<KernelBase>> CreateKernels( std::vector<std::unique_ptr<KernelBase>> CreateKernels(
const std::vector<Place> &places, const std::string &kernel_type = ""); const std::vector<Place> &places, const std::string &kernel_type = "");
lite::Scope *scope() { return scope_; } Scope *scope() { return scope_; }
// Assign op param to kernel. // Assign op param to kernel.
virtual void AttachKernel(KernelBase *kernel) = 0; virtual void AttachKernel(KernelBase *kernel) = 0;
...@@ -169,7 +169,7 @@ class OpLite : public Registry { ...@@ -169,7 +169,7 @@ class OpLite : public Registry {
} }
protected: protected:
lite::Scope *scope_{nullptr}; Scope *scope_{nullptr};
std::unique_ptr<KernelBase> kernel_; std::unique_ptr<KernelBase> kernel_;
std::string op_type_; std::string op_type_;
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/mir/elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.h"
#include "lite/core/mir/generate_program_pass.h" #include "lite/core/mir/generate_program_pass.h"
#include "lite/core/mir/pass_manager.h" #include "lite/core/mir/pass_manager.h"
#include "lite/core/mir/pass_utils.h" #include "lite/core/mir/pass_utils.h"
...@@ -36,6 +37,9 @@ namespace lite { ...@@ -36,6 +37,9 @@ namespace lite {
* lite::Optimizer optimize a program. It utilize the mir passes to analysis the * lite::Optimizer optimize a program. It utilize the mir passes to analysis the
* program and export an optimized program. * program and export an optimized program.
*/ */
// TODO(hong1986032) Support the following passes for the subblocks
const std::set<std::string> kSubblockUnsupportedPasses(
{"memory_optimize_pass"});
class Optimizer { class Optimizer {
public: public:
Optimizer() {} Optimizer() {}
...@@ -60,14 +64,20 @@ class Optimizer { ...@@ -60,14 +64,20 @@ class Optimizer {
program_ = &program; program_ = &program;
valid_places_ = valid_places; valid_places_ = valid_places;
CHECK(!valid_places.empty()) << "At least one valid_place should be set"; CHECK(!valid_places.empty()) << "At least one valid_place should be set";
CHECK(!graph_) << "duplicate optimize found"; CHECK(graphs_.empty()) << "duplicate optimize found";
graph_.reset(new mir::SSAGraph); auto block_size = program.block_size();
graph_->Build(program, valid_places); for (size_t block_idx = 0; block_idx < block_size; ++block_idx) {
graph_->SetValidPlaces(valid_places); std::unique_ptr<mir::SSAGraph> graph;
graph.reset(new mir::SSAGraph);
graph->Build(program, valid_places, block_idx);
graph->SetValidPlaces(valid_places);
graphs_.emplace_back(std::move(graph));
}
SpecifyKernelPickTactic(kernel_pick_factor); SpecifyKernelPickTactic(kernel_pick_factor);
InitTargetTypeTransformPass(); InitTargetTypeTransformPass();
InitControlFlowOpUnusedInputsAndOutputsEliminatePass();
if (passes.empty() || passes.size() == 1) { if (passes.empty() || passes.size() == 1) {
std::vector<std::string> passes_local{ std::vector<std::string> passes_local{
...@@ -76,6 +86,7 @@ class Optimizer { ...@@ -76,6 +86,7 @@ class Optimizer {
"lite_conv_elementwise_fuse_pass", // conv-elemwise-bn "lite_conv_elementwise_fuse_pass", // conv-elemwise-bn
"lite_conv_bn_fuse_pass", // "lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_fuse_pass", // conv-bn-elemwise "lite_conv_elementwise_fuse_pass", // conv-bn-elemwise
"lite_conv_conv_fuse_pass", //
// TODO(Superjomn) Refine the fusion related design to select fusion // TODO(Superjomn) Refine the fusion related design to select fusion
// kernels for devices automatically. // kernels for devices automatically.
"lite_conv_activation_fuse_pass", // "lite_conv_activation_fuse_pass", //
...@@ -111,6 +122,7 @@ class Optimizer { ...@@ -111,6 +122,7 @@ class Optimizer {
"apu_subgraph_pass", "apu_subgraph_pass",
"rknpu_subgraph_pass", "rknpu_subgraph_pass",
"mlu_subgraph_pass", "mlu_subgraph_pass",
"control_flow_op_unused_inputs_and_outputs_eliminate_pass",
"static_kernel_pick_pass", // pick original kernel from graph "static_kernel_pick_pass", // pick original kernel from graph
"remove_tf_redundant_ops_pass", "remove_tf_redundant_ops_pass",
...@@ -175,62 +187,15 @@ class Optimizer { ...@@ -175,62 +187,15 @@ class Optimizer {
exec_scope_ = program.exec_scope(); exec_scope_ = program.exec_scope();
} }
const lite::Scope* exec_scope() const { return exec_scope_; } const Scope* exec_scope() const { return exec_scope_; }
// Set shape(dims) infos of var descs to scope var.
// developer can write pass using input / output tensor dims of op.
//
// Example: If you have node `Node* softmax_node`,
// you can get dims of output tensor in passes:
//
// auto* scope = softmax_node->AsStmt().op()->scope();
// auto softmax_out_arg_name =
// softmax_node->outlinks.front()->AsArg().name;
// auto softmax_out_tensor =
// scope->FindVar(softmax_out_arg_name)->Get<lite::Tensor>();
// softmax_out_dims = softmax_out_tensor.dims();
void SetVarDescShapeToScopeVar() {
auto dims_to_str_func = [](std::vector<int64_t> shape) -> std::string {
std::string str_res;
for (size_t i = 0; i < shape.size(); ++i) {
str_res += std::to_string(shape[i]);
if (i != shape.size() - 1) {
str_res += "x";
}
}
return str_res;
};
auto* program_desc = program_->program_desc();
VLOG(5) << "program_desc->BlocksSize():" << program_desc->BlocksSize();
auto blocks_desc = program_desc->GetBlocks();
for (size_t bidx = 0; bidx < blocks_desc.size(); ++bidx) {
auto block_desc = blocks_desc[bidx];
auto vars_desc = block_desc.GetVars();
for (size_t vidx = 0; vidx < vars_desc.size(); ++vidx) {
auto var_desc = vars_desc[vidx];
VLOG(5) << var_desc.Name() << " "
<< dims_to_str_func(var_desc.GetShape());
if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue;
auto* var = program_->exec_scope()->FindVar(var_desc.Name());
auto tensor = var->GetMutable<lite::Tensor>();
if (tensor->dims().size() == 0 && var_desc.GetShape().size() != 0) {
VLOG(5) << "var_desc.Name():" << var_desc.Name()
<< " shape:" << dims_to_str_func(var_desc.GetShape());
tensor->Resize(var_desc.GetShape());
}
VLOG(5) << "var_desc.Name():" << var_desc.Name()
<< " shape:" << dims_to_str_func(var_desc.GetShape())
<< " tensor:" << tensor->dims();
}
}
}
// Generate a new program based on the mir graph. // Generate a new program based on the mir graph.
std::unique_ptr<RuntimeProgram> GenRuntimeProgram() { std::unique_ptr<RuntimeProgram> GenRuntimeProgram() {
auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>( auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
"generate_program_pass"); "generate_program_pass");
pass->Apply(graph_); for (auto& graph : graphs_) {
pass->Apply(graph);
}
auto program = pass->GenProgram(); auto program = pass->GenProgram();
CHECK(exec_scope_); CHECK(exec_scope_);
program->set_exec_scope(exec_scope_); program->set_exec_scope(exec_scope_);
...@@ -246,27 +211,38 @@ class Optimizer { ...@@ -246,27 +211,38 @@ class Optimizer {
pass->SetValidPlaces(valid_places_); pass->SetValidPlaces(valid_places_);
} }
void InitControlFlowOpUnusedInputsAndOutputsEliminatePass() {
auto* pass =
mir::PassManager::Global()
.LookUp<mir::ControlFlowOpUnusedInputsAndOutputsEliminatePass>(
"control_flow_op_unused_inputs_and_outputs_eliminate_pass");
CHECK(pass);
CHECK(!graphs_.empty());
pass->SetAllGraphs(&graphs_);
}
// Generate C++ code which combines the inference program, model and weights. // Generate C++ code which combines the inference program, model and weights.
void GenCode(const std::string& code_dir); void GenCode(const std::string& code_dir);
const mir::SSAGraph& ssa_graph() const { const mir::SSAGraph& ssa_graph(int block_idx = kRootBlockIdx) const {
CHECK(graph_); CHECK(!graphs_.empty());
return *graph_; CHECK(graphs_[block_idx]);
return *graphs_[block_idx];
} }
mir::SSAGraph* mutable_ssa_graph() { mir::SSAGraph* mutable_ssa_graph(int block_idx = kRootBlockIdx) {
CHECK(graph_); CHECK(!graphs_.empty());
return graph_.get(); CHECK(graphs_[block_idx]);
return graphs_[block_idx].get();
} }
lite::Scope* exec_scope() { return exec_scope_; } Scope* exec_scope() { return exec_scope_; }
protected: protected:
void SpecifyKernelPickTactic(core::KernelPickFactor factor); void SpecifyKernelPickTactic(core::KernelPickFactor factor);
// Specify the passes and run them. // Specify the passes and run them.
void RunPasses(const std::vector<std::string>& passes) { void RunPasses(const std::vector<std::string>& passes) {
SetVarDescShapeToScopeVar();
for (auto& x : passes) { for (auto& x : passes) {
LOG(INFO) << "== Running pass: " << x; LOG(INFO) << "== Running pass: " << x;
mir::Pass* pass = mir::PassManager::Global().LookUp(x); mir::Pass* pass = mir::PassManager::Global().LookUp(x);
...@@ -284,16 +260,23 @@ class Optimizer { ...@@ -284,16 +260,23 @@ class Optimizer {
LOG(INFO) << " - Skip " << x LOG(INFO) << " - Skip " << x
<< " because the target or kernel does not match."; << " because the target or kernel does not match.";
} else { } else {
pass->Apply(graph_); // Check the pass whether it is supported for processing subblocks
if (kSubblockUnsupportedPasses.count(x)) {
pass->Apply(graphs_[kRootBlockIdx]);
} else {
for (auto& graph : graphs_) {
pass->Apply(graph);
}
}
LOG(INFO) << "== Finished running: " << x; LOG(INFO) << "== Finished running: " << x;
} }
} }
} }
private: private:
std::unique_ptr<mir::SSAGraph> graph_; std::vector<std::unique_ptr<mir::SSAGraph>> graphs_;
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
lite::Scope* exec_scope_{}; Scope* exec_scope_{};
Program* program_{}; Program* program_{};
}; };
......
此差异已折叠。
...@@ -41,61 +41,72 @@ static const char kKernelTypeAttr[] = "__@kernel_type_attr@__"; ...@@ -41,61 +41,72 @@ static const char kKernelTypeAttr[] = "__@kernel_type_attr@__";
// - scope: which contains all the weights // - scope: which contains all the weights
struct Program { struct Program {
public: public:
explicit Program(const std::shared_ptr<Scope>& root) { scope_ = root; } explicit Program(const std::shared_ptr<Scope>& root_scope) {
Program(const cpp::ProgramDesc& desc, scope_ = root_scope;
const std::shared_ptr<Scope>& root, }
Program(const std::shared_ptr<cpp::ProgramDesc>& program_desc,
const std::shared_ptr<Scope>& root_scope,
const std::vector<Place>& valid_places, const std::vector<Place>& valid_places,
const std::vector<std::string>& var_names = {}) const std::vector<std::string>& vars_to_clone = {})
: scope_(root), valid_places_(valid_places) { : scope_(root_scope),
desc_.CopyFrom(desc); valid_places_(valid_places),
program_desc_(program_desc) {
CHECK(scope_) << "scope should be init first"; CHECK(scope_) << "scope should be init first";
VLOG(4) << "prepare work"; VLOG(4) << "prepare work";
PrepareWorkspace(desc, var_names); PrepareWorkspace(program_desc_, vars_to_clone);
VLOG(4) << "build desc"; VLOG(4) << "build desc";
Build(desc); Build(program_desc_);
VLOG(4) << "build desc finished"; VLOG(4) << "build desc finished";
} }
std::unique_ptr<Program> Clone() const { std::unique_ptr<Program> Clone() const {
std::unique_ptr<Program> res(new Program(desc_, scope_, valid_places_)); return std::unique_ptr<Program>(
return res; new Program(program_desc_, scope_, valid_places_));
} }
const std::list<std::string>& weights() const { return weights_; } const std::list<std::string>& weights() const { return weights_; }
const std::list<std::string>& tmp_vars() const { return tmp_vars_; } const std::list<std::string>& vars() const { return vars_; }
std::list<std::string>* mutable_weights() { return &weights_; } std::list<std::string>* mutable_weights() { return &weights_; }
std::list<std::string>* mutable_tmp_vars() { return &tmp_vars_; } std::list<std::string>* mutable_vars() { return &vars_; }
const std::list<std::shared_ptr<OpLite>>& ops() const { return ops_; } const std::list<std::shared_ptr<OpLite>>& ops(
std::list<std::shared_ptr<OpLite>>* mutable_ops() { return &ops_; } int block_idx = kRootBlockIdx) const {
return ops_[block_idx];
}
std::list<std::shared_ptr<OpLite>>* mutable_ops(
int block_idx = kRootBlockIdx) {
return &ops_[block_idx];
}
lite::Scope* exec_scope() { return exec_scope_; } size_t block_size() { return ops_.size(); }
lite::Scope* scope() { return scope_.get(); }
cpp::ProgramDesc* program_desc() { return &desc_; } Scope* exec_scope() { return exec_scope_; }
Scope* scope() { return scope_.get(); }
const std::map<std::string, PrecisionType>& var_data_type() const { cpp::ProgramDesc* program_desc() { return program_desc_.get(); }
return var_data_type_;
const std::map<std::string, const Type*>& var_type_map() const {
return var_type_map_;
} }
private: private:
// Build from a program and scope. // Build from a program and scope.
void Build(const cpp::ProgramDesc& program); void Build(const std::shared_ptr<cpp::ProgramDesc>& program_desc);
// Create temporary variables. // Create temporary variables.
void PrepareWorkspace(const cpp::ProgramDesc& program, void PrepareWorkspace(const std::shared_ptr<cpp::ProgramDesc>& program_desc,
const std::vector<std::string>& var_names = {}); const std::vector<std::string>& vars_to_clone = {});
private: private:
std::map<std::string, PrecisionType> var_data_type_; std::map<std::string, const Type*> var_type_map_;
std::list<std::string> tmp_vars_; std::list<std::string> vars_;
std::list<std::string> weights_; std::list<std::string> weights_;
std::list<std::shared_ptr<OpLite>> ops_; std::vector<std::list<std::shared_ptr<OpLite>>> ops_;
// the scope to run the kernels, NOTE this is the execution scope. // the scope to run the kernels, NOTE this is the execution scope.
std::shared_ptr<lite::Scope> scope_; std::shared_ptr<Scope> scope_;
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
// Runtime scope. // Runtime scope.
lite::Scope* exec_scope_{}; Scope* exec_scope_{};
cpp::ProgramDesc desc_; std::shared_ptr<cpp::ProgramDesc> program_desc_;
}; };
struct Instruction { struct Instruction {
...@@ -179,8 +190,22 @@ struct Instruction { ...@@ -179,8 +190,22 @@ struct Instruction {
*/ */
class LITE_API RuntimeProgram { class LITE_API RuntimeProgram {
public: public:
explicit RuntimeProgram(std::vector<Instruction>&& insts) explicit RuntimeProgram(std::vector<std::vector<Instruction>>&& insts)
: instructions_(std::move(insts)) { : instructions_(std::move(insts)) {
Init();
}
explicit RuntimeProgram(
const std::shared_ptr<const cpp::ProgramDesc>& program_desc,
Scope* exec_scope,
int block_idx = kRootBlockIdx);
~RuntimeProgram() {
#ifdef LITE_WITH_PROFILE
LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kCreate);
LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch);
#endif // LITE_WITH_PROFILE
}
void Init() {
if (instructions_.empty()) { if (instructions_.empty()) {
LOG(FATAL) << "no instructions"; LOG(FATAL) << "no instructions";
} }
...@@ -189,7 +214,7 @@ class LITE_API RuntimeProgram { ...@@ -189,7 +214,7 @@ class LITE_API RuntimeProgram {
#endif #endif
#ifdef LITE_WITH_NVTX #ifdef LITE_WITH_NVTX
const NVTXAnnotator& annotator = NVTXAnnotator::Global(); const NVTXAnnotator& annotator = NVTXAnnotator::Global();
for (auto& inst : instructions_) { for (auto& inst : instructions_[kRootBlockIdx]) {
NVTXRangeAnnotation annotation = annotator.AnnotateBlock(); NVTXRangeAnnotation annotation = annotator.AnnotateBlock();
register_layer_names_.push_back(annotator.RegisterString( register_layer_names_.push_back(annotator.RegisterString(
const_cast<paddle::lite::OpLite*>(inst.op())->Type().c_str())); const_cast<paddle::lite::OpLite*>(inst.op())->Type().c_str()));
...@@ -197,30 +222,27 @@ class LITE_API RuntimeProgram { ...@@ -197,30 +222,27 @@ class LITE_API RuntimeProgram {
register_layer_names_.push_back(annotator.RegisterString("one_loop")); register_layer_names_.push_back(annotator.RegisterString("one_loop"));
#endif #endif
} }
~RuntimeProgram() {
#ifdef LITE_WITH_PROFILE
LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kCreate);
LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch);
#endif // LITE_WITH_PROFILE
}
void Run(); void Run();
void set_exec_scope(lite::Scope* x) { exec_scope_ = x; } void set_exec_scope(Scope* x) { exec_scope_ = x; }
lite::Scope* exec_scope() { return exec_scope_; } Scope* exec_scope() { return exec_scope_; }
size_t num_instructions() const { return instructions_.size(); } const std::vector<Instruction>& instructions(
int block_idx = kRootBlockIdx) const {
return instructions_[block_idx];
}
const std::vector<Instruction>& instructions() const { return instructions_; } std::vector<Instruction>* mutable_instructions(
int block_idx = kRootBlockIdx) {
return &instructions_[block_idx];
}
// `SaveOpInfosToProgram` will update the op list(ops_) of the block 0 size_t block_size() { return instructions_.size(); }
// in ProgramDesc.
void SaveOpInfosToProgram(cpp::ProgramDesc* desc);
// `UpdateVarsOfProgram` will update the var list(vars_) of the block 0 in // Update the ops and vars of all of blocks to the given program_desc
// ProgramDesc. Namely, if a new var created in some passes, its var_desc will // according to the instructions
// be added in vars_. void SaveToProgram(std::shared_ptr<cpp::ProgramDesc> program_desc);
void UpdateVarsOfProgram(cpp::ProgramDesc* desc);
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
// UpdateCudaContext will update the exec stream and io stream of all kernels // UpdateCudaContext will update the exec stream and io stream of all kernels
...@@ -230,14 +252,14 @@ class LITE_API RuntimeProgram { ...@@ -230,14 +252,14 @@ class LITE_API RuntimeProgram {
private: private:
RuntimeProgram(const RuntimeProgram&) = delete; RuntimeProgram(const RuntimeProgram&) = delete;
std::vector<Instruction> instructions_; std::vector<std::vector<Instruction>> instructions_;
lite::Scope* exec_scope_{}; Scope* exec_scope_{};
#ifdef LITE_WITH_PROFILE #ifdef LITE_WITH_PROFILE
profile::Profiler profiler_; profile::Profiler profiler_;
void set_profiler() { void set_profiler() {
for (auto i = instructions_.begin(); i != instructions_.end(); ++i) { for (auto& inst : instructions_[kRootBlockIdx]) {
i->set_profiler(&profiler_); inst.set_profiler(&profiler_);
} }
} }
#endif #endif
......
...@@ -37,7 +37,7 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -37,7 +37,7 @@ bool SubgraphEngine::BuildDeviceProgram() {
subgraph::apu::Graph graph; subgraph::apu::Graph graph;
int neuron_errCode = NeuronModel_create(&model_); int neuron_errCode = NeuronModel_create(&model_);
if (NEURON_NO_ERROR != neuron_errCode) { if (NEURON_NO_ERROR != neuron_errCode) {
LOG(WARNING) << "Fail to create model"; LOG(WARNING) << "[APU] Failed to create the neuron model!";
return false; return false;
} }
graph.set_model(model_); graph.set_model(model_);
...@@ -46,11 +46,12 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -46,11 +46,12 @@ bool SubgraphEngine::BuildDeviceProgram() {
// Convert all of ops and their input vars and weights and added into the APU // Convert all of ops and their input vars and weights and added into the APU
// NIR graph // NIR graph
if (origin_program_.empty()) { if (!origin_program_) {
BuildOriginProgram(); BuildOriginProgram();
} }
const auto& bridges = subgraph::Registry::Instance(); const auto& bridges = subgraph::Registry::Instance();
for (auto& inst : origin_program_) { const auto& insts = origin_program_->instructions(kRootBlockIdx);
for (auto& inst : insts) {
auto op = const_cast<OpLite*>(inst.op()); auto op = const_cast<OpLite*>(inst.op());
CHECK(op); CHECK(op);
op->CheckShape(); op->CheckShape();
...@@ -70,55 +71,38 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -70,55 +71,38 @@ bool SubgraphEngine::BuildDeviceProgram() {
} }
} }
// Get input tensor // Get the index of input tensors
std::vector<uint32_t> ins; std::vector<uint32_t> input_indices;
origin_itensors_.resize(input_names_.size());
origin_idims_.resize(input_names_.size());
for (int i = 0; i < input_names_.size(); i++) { for (int i = 0; i < input_names_.size(); i++) {
origin_itensors_[i] = scope_->FindMutableTensor(input_names_[i]); CHECK(graph.Has(input_names_[i])) << "[APU] Failed to find input node "
CHECK(origin_itensors_[i]); << input_names_[i];
origin_idims_[i] = origin_itensors_[i]->dims(); auto index = graph.Get(input_names_[i])->index();
VLOG(3) << "subgraph input name: " << i << ", " << input_names_[i] << ":" input_indices.push_back(index);
<< origin_idims_[i].production(); VLOG(3) << "[APU] Input[" << i << "] name " << input_names_[i] << " dims "
// Get input index << origin_itensors_[i]->dims() << " index " << index;
int idx;
if (graph.Has(input_names_[i])) {
ins.push_back(graph.Get(input_names_[i])->index());
VLOG(3) << "input idx: " << graph.Get(input_names_[i])->index();
} else {
LOG(WARNING) << "Fail to find input: " << input_names_[i];
return false;
}
} }
// Get output tensor // Get the index of output tensors
std::vector<uint32_t> outs; std::vector<uint32_t> output_indices;
origin_otensors_.resize(output_names_.size());
origin_odims_.resize(output_names_.size());
for (int i = 0; i < output_names_.size(); i++) { for (int i = 0; i < output_names_.size(); i++) {
origin_otensors_[i] = scope_->FindMutableTensor(output_names_[i]); CHECK(graph.Has(output_names_[i])) << "[APU] Failed to find output node "
CHECK(origin_otensors_[i]); << output_names_[i];
origin_odims_[i] = origin_otensors_[i]->dims();
VLOG(3) << "subgraph output name: " << i << ", " << output_names_[i] << ":"
<< origin_odims_[i].production();
origin_otensors_[i]->mutable_data<int8_t>(); origin_otensors_[i]->mutable_data<int8_t>();
// Get input index auto index = graph.Get(output_names_[i])->index();
if (graph.Has(output_names_[i])) { output_indices.push_back(index);
outs.push_back(graph.Get(output_names_[i])->index()); VLOG(3) << "[APU] Output[" << i << "] name " << output_names_[i] << " dims "
VLOG(3) << "output idx: " << graph.Get(output_names_[i])->index(); << origin_otensors_[i]->dims() << " index " << index;
} else { }
LOG(WARNING) << "Fail to find output: " << output_names_[i];
return false; // Indentify the input and output tensors of the neuron model
} NeuronModel_identifyInputsAndOutputs(model_,
} input_indices.size(),
&input_indices[0],
VLOG(3) << "ins size: " << ins.size() << " outs size:" << outs.size(); output_indices.size(),
// Set subgraph input/output &output_indices[0]);
NeuronModel_identifyInputsAndOutputs(
model_, ins.size(), &ins[0], outs.size(), &outs[0]);
neuron_errCode = NeuronModel_finish(model_); neuron_errCode = NeuronModel_finish(model_);
if (NEURON_NO_ERROR != neuron_errCode) { if (NEURON_NO_ERROR != neuron_errCode) {
LOG(WARNING) << "Fail to create NIR model:" << neuron_errCode; LOG(WARNING) << "[APU] Fail to create NIR model:" << neuron_errCode;
return false; return false;
} }
VLOG(3) << "[APU] APU NIR model created!"; VLOG(3) << "[APU] APU NIR model created!";
...@@ -207,11 +191,11 @@ SubgraphEngine::~SubgraphEngine() { ...@@ -207,11 +191,11 @@ SubgraphEngine::~SubgraphEngine() {
void SubgraphCompute::PrepareForRun() { void SubgraphCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
engine_.reset(new SubgraphEngine(ctx_.get(), engine_.reset(new SubgraphEngine(ctx_.get(),
param.sub_block_idx, param.block_idx,
param.sub_block_desc, param.program_desc,
param.exec_scope,
param.input_data_names, param.input_data_names,
param.output_data_names, param.output_data_names));
param.scope));
CHECK(engine_); CHECK(engine_);
} }
......
...@@ -31,12 +31,16 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -31,12 +31,16 @@ class SubgraphEngine : public subgraph::Engine {
public: public:
SubgraphEngine(KernelContext *ctx, SubgraphEngine(KernelContext *ctx,
int block_idx, int block_idx,
cpp::BlockDesc *block_desc, const std::shared_ptr<const cpp::ProgramDesc> &program_desc,
Scope *exec_scope,
const std::vector<std::string> &input_names, const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names, const std::vector<std::string> &output_names)
Scope *scope) : subgraph::Engine(ctx,
: subgraph::Engine( block_idx,
ctx, block_idx, block_desc, input_names, output_names, scope) {} program_desc,
exec_scope,
input_names,
output_names) {}
~SubgraphEngine(); ~SubgraphEngine();
......
...@@ -75,7 +75,6 @@ add_kernel(generate_proposals_compute_arm ARM extra SRCS generate_proposals_comp ...@@ -75,7 +75,6 @@ add_kernel(generate_proposals_compute_arm ARM extra SRCS generate_proposals_comp
add_kernel(roi_align_compute_arm ARM extra SRCS roi_align_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(roi_align_compute_arm ARM extra SRCS roi_align_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(box_clip_compute_arm ARM extra SRCS box_clip_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(box_clip_compute_arm ARM extra SRCS box_clip_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(assign_value_compute_arm ARM basic SRCS assign_value_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(assign_value_compute_arm ARM basic SRCS assign_value_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(conditional_block_compute_arm ARM extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(distribute_fpn_proposals_compute_arm ARM extra SRCS distribute_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(distribute_fpn_proposals_compute_arm ARM extra SRCS distribute_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(clip_compute_arm ARM extra SRCS clip_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(clip_compute_arm ARM extra SRCS clip_compute.cc DEPS ${lite_kernel_deps} math_arm)
...@@ -87,7 +86,6 @@ add_kernel(beam_search_decode_compute_arm ARM extra SRCS beam_search_decode_comp ...@@ -87,7 +86,6 @@ add_kernel(beam_search_decode_compute_arm ARM extra SRCS beam_search_decode_comp
add_kernel(lookup_table_compute_arm ARM extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(lookup_table_compute_arm ARM extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(lookup_table_dequant_compute_arm ARM extra SRCS lookup_table_dequant_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(lookup_table_dequant_compute_arm ARM extra SRCS lookup_table_dequant_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_softmax_compute_arm ARM extra SRCS sequence_softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(sequence_softmax_compute_arm ARM extra SRCS sequence_softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(while_compute_arm ARM extra SRCS while_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(topk_compute_arm ARM extra SRCS topk_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(topk_compute_arm ARM extra SRCS topk_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(increment_compute_arm ARM extra SRCS increment_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(increment_compute_arm ARM extra SRCS increment_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(beam_search_compute_arm ARM extra SRCS beam_search_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(beam_search_compute_arm ARM extra SRCS beam_search_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/program.h"
#include "lite/operators/conditional_block_op.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/basic_profiler.h"
#include "lite/core/profile/precision_profiler.h"
#include "lite/core/profile/profiler.h"
#endif
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class CondExecutor {
typedef std::shared_ptr<OpLite> OpPtr;
public:
CondExecutor(cpp::BlockDesc *block, Scope *scope, Place place)
: scope_(scope), place_(place) {
int32_t op_size = block->OpsSize();
for (int32_t i = 0; i < op_size; ++i) {
auto &op_desc = *block->template GetOp<cpp::OpDesc>(i);
auto op_type = op_desc.Type();
auto op_handler = lite::LiteOpRegistry::Global().Create(op_desc.Type());
op_handler->Attach(op_desc, scope);
auto hostplace = place_;
hostplace.target = TARGET(kHost);
auto kernels = op_handler->CreateKernels({place_, hostplace});
CHECK_GT(kernels.size(), 0) << "cannot create kernel";
op_handler->AttachKernel(kernels[0].get());
op_handler->SetKernel(kernels);
ops_of_block_.push_back(op_handler);
}
}
void Run() {
#ifdef LITE_WITH_PROFILE
#ifdef LITE_WITH_PRECISION_PROFILE
lite::profile::Profiler profiler;
#endif // LITE_WITH_PRECISION_PROFILE
#endif // LITE_WITH_PROFILE
for (auto &op_handler : ops_of_block_) {
op_handler->CheckShape();
op_handler->InferShape();
#ifdef LITE_WITH_PROFILE
#ifdef LITE_WITH_PRECISION_PROFILE
std::unique_ptr<KernelBase> kernel(op_handler->GetKernel());
Instruction inst(op_handler, std::move(kernel));
inst.set_profiler(&profiler);
#endif // LITE_WITH_PRECISION_PROFILE
#endif // LITE_WITH_PROFILE
op_handler->Run();
#ifdef LITE_WITH_PROFILE
#ifdef LITE_WITH_PRECISION_PROFILE
LITE_PRECISION_PROFILE(inst)
#endif // LITE_WITH_PRECISION_PROFILE
#endif // LITE_WITH_PROFILE
}
}
private:
Scope *scope_;
Place place_;
std::vector<OpPtr> ops_of_block_;
};
class ConditionalBlockCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::ConditionalBlockParam;
void PrepareForRun() override;
void Run() override;
virtual ~ConditionalBlockCompute() = default;
private:
std::shared_ptr<CondExecutor> executor_;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -202,17 +202,13 @@ void ElementwiseMulCompute<T, PType>::Run() { ...@@ -202,17 +202,13 @@ void ElementwiseMulCompute<T, PType>::Run() {
} }
} }
template <> template <typename T, PrecisionType PType>
void ElementwiseMulCompute<int64_t, PRECISION(kInt64)>::Run() { void ElementwiseMulActivationCompute<T, PType>::Run() {
auto& param = this->template Param<operators::ElementwiseParam>(); auto& param =
lite::arm::math::elementwise_compute_basic<int64_t>(param, "mul", ""); this->template Param<operators::FusionElementwiseActivationParam>();
} auto* x_data = param.X->template data<T>();
auto* y_data = param.Y->template data<T>();
void ElementwiseMulActivationCompute::Run() { auto* out_data = param.Out->template mutable_data<T>();
auto& param = Param<operators::FusionElementwiseActivationParam>();
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
float* out_data = param.Out->mutable_data<float>();
int axis = param.axis; int axis = param.axis;
std::string act_type = param.act_type; std::string act_type = param.act_type;
auto x_dims = param.X->dims(); auto x_dims = param.X->dims();
...@@ -221,21 +217,21 @@ void ElementwiseMulActivationCompute::Run() { ...@@ -221,21 +217,21 @@ void ElementwiseMulActivationCompute::Run() {
if (x_dims.size() < y_dims.size() && if (x_dims.size() < y_dims.size() &&
is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) { is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") { if (act_type == "relu") {
lite::arm::math::elementwise_mul_relu_broadcast<float>( lite::arm::math::elementwise_mul_relu_broadcast<T>(
y_data, x_data, out_data, pre, n, post); y_data, x_data, out_data, pre, n, post);
} else { } else {
LOG(FATAL) << "unsupported Activation type: " << act_type; LOG(FATAL) << "unsupported Activation type: " << act_type;
} }
} else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") { if (act_type == "relu") {
lite::arm::math::elementwise_mul_relu_broadcast( lite::arm::math::elementwise_mul_relu_broadcast<T>(
x_data, y_data, out_data, pre, n, post); x_data, y_data, out_data, pre, n, post);
} else { } else {
LOG(FATAL) << "unsupported Activation type: " << act_type; LOG(FATAL) << "unsupported Activation type: " << act_type;
} }
} else { } else {
if (act_type == "relu") { if (act_type == "relu") {
lite::arm::math::elementwise_mul_relu( lite::arm::math::elementwise_mul_relu<T>(
x_data, y_data, out_data, x_dims.production()); x_data, y_data, out_data, x_dims.production());
} else { } else {
LOG(FATAL) << "unsupported Activation type: " << act_type; LOG(FATAL) << "unsupported Activation type: " << act_type;
...@@ -426,46 +422,60 @@ REGISTER_LITE_KERNEL( ...@@ -426,46 +422,60 @@ REGISTER_LITE_KERNEL(
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
using elementwise_mul_float = using elementwise_mul_float_t =
paddle::lite::kernels::arm::ElementwiseMulCompute<float, PRECISION(kFloat)>; paddle::lite::kernels::arm::ElementwiseMulCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
elementwise_mul, kARM, kFloat, kNCHW, elementwise_mul_float, def) elementwise_mul, kARM, kFloat, kNCHW, elementwise_mul_float_t, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
using elementwise_mul_int32 = using elementwise_mul_int32_t =
paddle::lite::kernels::arm::ElementwiseMulCompute<int, PRECISION(kInt32)>; paddle::lite::kernels::arm::ElementwiseMulCompute<int, PRECISION(kInt32)>;
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
elementwise_mul, kARM, kInt32, kNCHW, elementwise_mul_int32, def) elementwise_mul, kARM, kInt32, kNCHW, elementwise_mul_int32_t, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.Finalize(); .Finalize();
using elementwise_mul_int64 = using elementwise_mul_int64_t =
paddle::lite::kernels::arm::ElementwiseMulCompute<int64_t, paddle::lite::kernels::arm::ElementwiseMulCompute<int64_t,
PRECISION(kInt64)>; PRECISION(kInt64)>;
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
elementwise_mul, kARM, kInt64, kNCHW, elementwise_mul_int64, def) elementwise_mul, kARM, kInt64, kNCHW, elementwise_mul_int64_t, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL( using fusion_elementwise_mul_activation_float_t = paddle::lite::kernels::arm::
fusion_elementwise_mul_activation, ElementwiseMulActivationCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
kARM, kARM,
kFloat, kFloat,
kNCHW, kNCHW,
paddle::lite::kernels::arm::ElementwiseMulActivationCompute, fusion_elementwise_mul_activation_float_t,
def) def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
using fusion_elementwise_mul_activation_int64_t = paddle::lite::kernels::arm::
ElementwiseMulActivationCompute<int64_t, PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
kARM,
kInt64,
kNCHW,
fusion_elementwise_mul_activation_int64_t,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_max, REGISTER_LITE_KERNEL(elementwise_max,
kARM, kARM,
kFloat, kFloat,
...@@ -489,22 +499,22 @@ REGISTER_LITE_KERNEL( ...@@ -489,22 +499,22 @@ REGISTER_LITE_KERNEL(
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
using elementwise_div_fp32 = using elementwise_div_fp32_t =
paddle::lite::kernels::arm::ElementwiseDivCompute<float, PRECISION(kFloat)>; paddle::lite::kernels::arm::ElementwiseDivCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
elementwise_div, kARM, kFloat, kNCHW, elementwise_div_fp32, def) elementwise_div, kARM, kFloat, kNCHW, elementwise_div_fp32_t, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
using elementwise_div_int64 = using elementwise_div_int64_t =
paddle::lite::kernels::arm::ElementwiseDivCompute<int64_t, paddle::lite::kernels::arm::ElementwiseDivCompute<int64_t,
PRECISION(kInt64)>; PRECISION(kInt64)>;
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
elementwise_div, kARM, kInt64, kNCHW, elementwise_div_int64, def) elementwise_div, kARM, kInt64, kNCHW, elementwise_div_int64_t, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
...@@ -522,11 +532,11 @@ REGISTER_LITE_KERNEL( ...@@ -522,11 +532,11 @@ REGISTER_LITE_KERNEL(
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
using elementwise_mod_int64 = using elementwise_mod_int64_t =
paddle::lite::kernels::arm::ElementwiseModCompute<int64_t, paddle::lite::kernels::arm::ElementwiseModCompute<int64_t,
PRECISION(kInt64)>; PRECISION(kInt64)>;
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
elementwise_mod, kARM, kInt64, kNCHW, elementwise_mod_int64, def) elementwise_mod, kARM, kInt64, kNCHW, elementwise_mod_int64_t, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
......
...@@ -62,8 +62,8 @@ class ElementwiseMulCompute : public KernelLite<TARGET(kARM), PType> { ...@@ -62,8 +62,8 @@ class ElementwiseMulCompute : public KernelLite<TARGET(kARM), PType> {
virtual ~ElementwiseMulCompute() = default; virtual ~ElementwiseMulCompute() = default;
}; };
class ElementwiseMulActivationCompute template <typename T, PrecisionType PType>
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> { class ElementwiseMulActivationCompute : public KernelLite<TARGET(kARM), PType> {
public: public:
void Run() override; void Run() override;
......
...@@ -533,13 +533,15 @@ TEST(fusion_elementwise_mul_activation_arm, retrive_op) { ...@@ -533,13 +533,15 @@ TEST(fusion_elementwise_mul_activation_arm, retrive_op) {
} }
TEST(fusion_elementwise_mul_activation_arm, init) { TEST(fusion_elementwise_mul_activation_arm, init) {
ElementwiseMulActivationCompute fusion_elementwise_mul_activation; ElementwiseMulActivationCompute<float, PRECISION(kFloat)>
fusion_elementwise_mul_activation;
ASSERT_EQ(fusion_elementwise_mul_activation.precision(), PRECISION(kFloat)); ASSERT_EQ(fusion_elementwise_mul_activation.precision(), PRECISION(kFloat));
ASSERT_EQ(fusion_elementwise_mul_activation.target(), TARGET(kARM)); ASSERT_EQ(fusion_elementwise_mul_activation.target(), TARGET(kARM));
} }
TEST(fusion_elementwise_mul_activation_arm, compute) { TEST(fusion_elementwise_mul_activation_arm, compute) {
ElementwiseMulActivationCompute fusion_elementwise_mul_activation; ElementwiseMulActivationCompute<float, PRECISION(kFloat)>
fusion_elementwise_mul_activation;
operators::FusionElementwiseActivationParam param; operators::FusionElementwiseActivationParam param;
lite::Tensor x, y, output, output_ref; lite::Tensor x, y, output, output_ref;
......
...@@ -20,44 +20,45 @@ namespace lite { ...@@ -20,44 +20,45 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
template <typename T> template <typename IndexType, typename DataType>
void GatherFunc(const operators::GatherParam& param) { void GatherFunc(const operators::GatherParam& param) {
auto src_dims = param.X->dims(); auto src_dims = param.X->dims();
auto index_size = param.Index->dims()[0]; auto index_size = param.Index->dims()[0];
auto* p_src = param.X->data<T>(); auto* p_src = param.X->data<DataType>();
const int* p_index = param.Index->data<int>(); const IndexType* p_index = param.Index->data<IndexType>();
auto* p_output = param.Out->mutable_data<T>(); auto* p_output = param.Out->mutable_data<DataType>();
int slice_size = 1; int slice_size = 1;
for (size_t i = 1; i < src_dims.size(); ++i) { for (size_t i = 1; i < src_dims.size(); ++i) {
slice_size *= src_dims[i]; slice_size *= src_dims[i];
} }
for (int i = 0; i < index_size; ++i) { for (int i = 0; i < index_size; ++i) {
int index_ = p_index[i]; IndexType index_ = p_index[i];
memcpy(p_output + i * slice_size, memcpy(p_output + i * slice_size,
p_src + index_ * slice_size, p_src + index_ * slice_size,
slice_size * sizeof(T)); slice_size * sizeof(DataType));
} }
} }
void GatherCompute::Run() { template <typename IndexType>
auto& param = this->Param<operators::GatherParam>(); void GatherCompute<IndexType>::Run() {
auto& param = this->template Param<operators::GatherParam>();
switch (param.X->precision()) { switch (param.X->precision()) {
case PRECISION(kFloat): case PRECISION(kFloat):
GatherFunc<float>(param); GatherFunc<IndexType, float>(param);
break; break;
case PRECISION(kInt8): case PRECISION(kInt8):
GatherFunc<int8_t>(param); GatherFunc<IndexType, int8_t>(param);
break; break;
case PRECISION(kInt16): case PRECISION(kInt16):
GatherFunc<int16_t>(param); GatherFunc<IndexType, int16_t>(param);
break; break;
case PRECISION(kInt32): case PRECISION(kInt32):
GatherFunc<int32_t>(param); GatherFunc<IndexType, int32_t>(param);
break; break;
case PRECISION(kInt64): case PRECISION(kInt64):
GatherFunc<int64_t>(param); GatherFunc<IndexType, int64_t>(param);
break; break;
default: default:
LOG(FATAL) << "Gather does not implement for the " LOG(FATAL) << "Gather does not implement for the "
...@@ -70,9 +71,26 @@ void GatherCompute::Run() { ...@@ -70,9 +71,26 @@ void GatherCompute::Run() {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(gather,
gather, kARM, kAny, kNCHW, paddle::lite::kernels::arm::GatherCompute, def) kARM,
kAny,
kNCHW,
paddle::lite::kernels::arm::GatherCompute<int32_t>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("Index", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("Index",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.Finalize();
REGISTER_LITE_KERNEL(gather,
kARM,
kAny,
kNCHW,
paddle::lite::kernels::arm::GatherCompute<int64_t>,
def_int64_idx)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("Index",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.Finalize(); .Finalize();
...@@ -23,6 +23,7 @@ namespace lite { ...@@ -23,6 +23,7 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
template <typename IndexType>
class GatherCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> { class GatherCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public: public:
void Run() override; void Run() override;
......
...@@ -102,10 +102,14 @@ void SequenceConvCompute::Run() { ...@@ -102,10 +102,14 @@ void SequenceConvCompute::Run() {
1, 1,
1, // stride_h, stride_w, dilation_h, dilation_w 1, // stride_h, stride_w, dilation_h, dilation_w
tmp_data); tmp_data);
local_naive_transpose(tmp_data, int cols = kernel_size * hidden_dim;
sub_col_data, int rows = input_row_end - input_row_begin;
kernel_size * hidden_dim, if (cols % 4 == 0 && rows % 4 == 0) {
input_row_end - input_row_begin); paddle::lite::arm::math::local_transpose(
tmp_data, sub_col_data, cols, rows);
} else {
local_naive_transpose(tmp_data, sub_col_data, cols, rows);
}
} }
} }
......
...@@ -28,36 +28,17 @@ namespace lite { ...@@ -28,36 +28,17 @@ namespace lite {
namespace kernels { namespace kernels {
namespace bm { namespace bm {
bool SubgraphEngine::PrepareWorkspaceForDeviceProgram() {
// Obtain the origin input tensors, and create the origin output
// tensors(Don't try to access them before launch the device program or the
// origin program)
PrepareWorkspaceForOriginProgram();
// Create the device input and output tensors, but don't initialize them
// with the dimensions
device_inputs_.resize(input_names_.size());
for (int i = 0; i < input_names_.size(); i++) {
device_inputs_[i].reset(new hiai::AiTensor);
CHECK(device_inputs_[i]);
}
device_outputs_.resize(output_names_.size());
for (int i = 0; i < output_names_.size(); i++) {
device_outputs_[i].reset(new hiai::AiTensor);
CHECK(device_outputs_[i]);
}
return true;
}
bool SubgraphEngine::BuildDeviceProgram() { bool SubgraphEngine::BuildDeviceProgram() {
int status = 0; int status = 0;
subgraph::bm::Graph graph; subgraph::bm::Graph graph;
const auto& bridges = subgraph::Registry::Instance(); const auto& bridges = subgraph::Registry::Instance();
graph.CreateCompilerHandle(); graph.CreateCompilerHandle();
auto& ctx = this->ctx_->template As<BMContext>(); auto& ctx = this->ctx_->template As<BMContext>();
if (origin_program_.empty()) { if (!origin_program_) {
BuildOriginProgram(); BuildOriginProgram();
} }
for (auto& inst : origin_program_) { const auto& insts = origin_program_->instructions(kRootBlockIdx);
for (auto& inst : insts) {
auto op = const_cast<OpLite*>(inst.op()); auto op = const_cast<OpLite*>(inst.op());
CHECK(op); CHECK(op);
op->CheckShape(); op->CheckShape();
...@@ -93,13 +74,11 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -93,13 +74,11 @@ bool SubgraphEngine::BuildDeviceProgram() {
net_info_ = bmrt_get_network_info(bmrt_hd_, net_names_[0]); net_info_ = bmrt_get_network_info(bmrt_hd_, net_names_[0]);
auto& stage = net_info_->stages[0]; auto& stage = net_info_->stages[0];
// input // input
origin_idims_.resize(input_names_.size());
origin_itensors_.resize(input_names_.size());
device_inputs_.resize(input_names_.size()); device_inputs_.resize(input_names_.size());
for (size_t i = 0; i < input_names_.size(); i++) { for (size_t i = 0; i < input_names_.size(); i++) {
origin_itensors_[i] = scope_->FindMutableTensor(net_info_->input_names[i]); origin_itensors_[i] =
exec_scope_->FindMutableTensor(net_info_->input_names[i]);
CHECK(origin_itensors_[i]); CHECK(origin_itensors_[i]);
origin_idims_[i] = origin_itensors_[i]->dims();
bm_device_mem_t* p_mem = bm_device_mem_t* p_mem =
static_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t))); static_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t)));
CHECK(p_mem != nullptr); CHECK(p_mem != nullptr);
...@@ -112,8 +91,6 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -112,8 +91,6 @@ bool SubgraphEngine::BuildDeviceProgram() {
stage.input_shapes[i]); stage.input_shapes[i]);
} }
// output // output
origin_odims_.resize(output_names_.size());
origin_otensors_.resize(output_names_.size());
device_outputs_.resize(net_info_->output_num); device_outputs_.resize(net_info_->output_num);
int out_index = 0; int out_index = 0;
for (int i = 0; i < output_names_.size(); i++) { for (int i = 0; i < output_names_.size(); i++) {
...@@ -121,14 +98,13 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -121,14 +98,13 @@ bool SubgraphEngine::BuildDeviceProgram() {
} }
for (int i = 0; i < net_info_->output_num; i++) { for (int i = 0; i < net_info_->output_num; i++) {
Tensor* t_cur = scope_->FindMutableTensor(net_info_->output_names[i]); Tensor* t_cur = exec_scope_->FindMutableTensor(net_info_->output_names[i]);
CHECK(t_cur != nullptr); CHECK(t_cur != nullptr);
bm_device_mem_t* p_mem = bm_device_mem_t* p_mem =
static_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t))); static_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t)));
CHECK(p_mem != nullptr); CHECK(p_mem != nullptr);
if (outname_map_.find(net_info_->output_names[i]) != outname_map_.end()) { if (outname_map_.find(net_info_->output_names[i]) != outname_map_.end()) {
origin_otensors_[out_index] = t_cur; origin_otensors_[out_index] = t_cur;
origin_odims_[out_index] = origin_otensors_[out_index]->dims();
origin_otensors_[out_index]->mutable_data<float>(); origin_otensors_[out_index]->mutable_data<float>();
out_index += 1; out_index += 1;
} }
...@@ -173,11 +149,11 @@ bool SubgraphEngine::LaunchDeviceProgram() { ...@@ -173,11 +149,11 @@ bool SubgraphEngine::LaunchDeviceProgram() {
void SubgraphCompute::PrepareForRun() { void SubgraphCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
engine_.reset(new SubgraphEngine(ctx_.get(), engine_.reset(new SubgraphEngine(ctx_.get(),
param.sub_block_idx, param.block_idx,
param.sub_block_desc, param.program_desc,
param.exec_scope,
param.input_data_names, param.input_data_names,
param.output_data_names, param.output_data_names));
param.scope));
CHECK(engine_); CHECK(engine_);
} }
......
...@@ -36,15 +36,18 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -36,15 +36,18 @@ class SubgraphEngine : public subgraph::Engine {
public: public:
SubgraphEngine(KernelContext *ctx, SubgraphEngine(KernelContext *ctx,
int block_idx, int block_idx,
cpp::BlockDesc *block_desc, const std::shared_ptr<const cpp::ProgramDesc> &program_desc,
Scope *exec_scope,
const std::vector<std::string> &input_names, const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names, const std::vector<std::string> &output_names)
Scope *scope) : subgraph::Engine(ctx,
: subgraph::Engine( block_idx,
ctx, block_idx, block_desc, input_names, output_names, scope) {} program_desc,
exec_scope,
input_names,
output_names) {}
protected: protected:
bool PrepareWorkspaceForDeviceProgram() override;
bool BuildDeviceProgram() override; bool BuildDeviceProgram() override;
bool LaunchDeviceProgram() override; bool LaunchDeviceProgram() override;
......
...@@ -18,6 +18,9 @@ add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute. ...@@ -18,6 +18,9 @@ add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute.
add_kernel(assign_compute_host Host extra SRCS assign_compute.cc DEPS ${lite_kernel_deps}) add_kernel(assign_compute_host Host extra SRCS assign_compute.cc DEPS ${lite_kernel_deps})
add_kernel(retinanet_detection_output_compute_host Host extra SRCS retinanet_detection_output_compute.cc DEPS ${lite_kernel_deps}) add_kernel(retinanet_detection_output_compute_host Host extra SRCS retinanet_detection_output_compute.cc DEPS ${lite_kernel_deps})
add_kernel(where_index_compute_host Host extra SRCS where_index_compute.cc DEPS ${lite_kernel_deps}) add_kernel(where_index_compute_host Host extra SRCS where_index_compute.cc DEPS ${lite_kernel_deps})
add_kernel(print_compute_host Host extra SRCS print_compute.cc DEPS ${lite_kernel_deps})
add_kernel(while_compute_host Host extra SRCS while_compute.cc DEPS ${lite_kernel_deps} program)
add_kernel(conditional_block_compute_host Host extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} program)
add_kernel(activation_grad_compute_host Host train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps}) add_kernel(activation_grad_compute_host Host train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps})
if(LITE_BUILD_EXTRA) if(LITE_BUILD_EXTRA)
......
...@@ -51,3 +51,19 @@ REGISTER_LITE_KERNEL( ...@@ -51,3 +51,19 @@ REGISTER_LITE_KERNEL(
PRECISION(kAny), PRECISION(kAny),
DATALAYOUT(kAny))}) DATALAYOUT(kAny))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(assign,
kHost,
kAny,
kAny,
paddle::lite::kernels::host::AssignCompute,
def_tensor_array)
.BindInput("X",
{LiteType::GetTensorListTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorListTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
...@@ -12,28 +12,21 @@ ...@@ -12,28 +12,21 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/conditional_block_compute.h" #include "lite/kernels/host/conditional_block_compute.h"
#include <memory>
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace host {
void ConditionalBlockCompute::PrepareForRun() { void ConditionalBlockCompute::PrepareForRun() {
auto& param = Param<operators::ConditionalBlockParam>(); auto& param = this->Param<param_t>();
auto cur_scope = param.scope; program_.reset(new RuntimeProgram(
param.program_desc, param.exec_scope, param.block_idx));
executor_ =
std::make_shared<CondExecutor>(param.sub_block, cur_scope, place());
} }
void ConditionalBlockCompute::Run() { void ConditionalBlockCompute::Run() {
auto& param = Param<operators::ConditionalBlockParam>(); auto& param = this->Param<param_t>();
for (auto& out : param.outs) { for (auto& out : param.outs) {
out->clear(); out->clear();
} }
...@@ -43,32 +36,40 @@ void ConditionalBlockCompute::Run() { ...@@ -43,32 +36,40 @@ void ConditionalBlockCompute::Run() {
auto* cond_data = cond->data<bool>(); auto* cond_data = cond->data<bool>();
need_run = cond_data[0]; need_run = cond_data[0];
} else { } else {
auto x = param.x; for (auto input : param.inputs) {
for (auto pt : x) { if (input == nullptr || !input->IsInitialized() ||
if (pt == nullptr || !pt->IsInitialized() || pt->dims().empty()) { input->dims().empty()) {
need_run = false; need_run = false;
break; break;
} }
} }
} }
if (need_run) { if (need_run) {
executor_->Run(); program_->Run();
} }
} }
} // namespace arm } // namespace host
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(conditional_block, REGISTER_LITE_KERNEL(conditional_block,
kARM, kHost,
kFloat, kAny,
kNCHW, kAny,
paddle::lite::kernels::arm::ConditionalBlockCompute, paddle::lite::kernels::host::ConditionalBlockCompute,
def) def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Input",
.BindInput("Cond", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) {LiteType::GetTensorListTy(
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindOutput("Scope", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Cond",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorListTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindOutput("Scope",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.Finalize(); .Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/program.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class ConditionalBlockCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
using param_t = operators::ConditionalBlockParam;
void PrepareForRun() override;
void Run() override;
private:
std::unique_ptr<RuntimeProgram> program_;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/print_compute.h"
#include <mutex> // NOLINT
#include <string>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
const char kForward[] = "FORWARD";
const char kBackward[] = "BACKWARD";
const char kBoth[] = "BOTH";
class TensorFormatter {
public:
TensorFormatter() {}
std::string Format(const Tensor& print_tensor,
const std::string& tensor_name = "",
const std::string& message = "") {
std::stringstream log_stream;
if (!tensor_name.empty()) {
log_stream << "Variable: " << tensor_name << std::endl;
}
if (!message.empty()) {
log_stream << " - message: " << message << std::endl;
}
if (print_tensor_lod_) {
log_stream << " - lod: {";
const LoD& lod = print_tensor.lod();
for (auto level : lod) {
log_stream << "{";
bool is_first = true;
for (auto i : level) {
if (is_first) {
log_stream << i;
is_first = false;
} else {
log_stream << ", " << i;
}
}
log_stream << "}";
}
log_stream << "}" << std::endl;
}
log_stream << " - place: " << TargetToStr(print_tensor.target())
<< std::endl; // TODO(hong19860320) always kHost
if (print_tensor_shape_) {
log_stream << " - shape: " << print_tensor.dims().repr() << std::endl;
}
if (print_tensor_layout_) {
log_stream << " - layout: "
<< DataLayoutToStr(
DATALAYOUT(kNCHW)) // TODO(hong19860320) Query the data
// layout from target tensor
<< std::endl;
}
auto dtype = print_tensor.precision();
if (print_tensor_type_) {
log_stream << " - dtype: " << PrecisionToStr(dtype) << std::endl;
}
if (dtype == PRECISION(kBool)) {
FormatData<bool>(print_tensor, log_stream);
} else if (dtype == PRECISION(kInt8)) {
FormatData<int8_t>(print_tensor, log_stream);
} else if (dtype == PRECISION(kInt16)) {
FormatData<int16_t>(print_tensor, log_stream);
} else if (dtype == PRECISION(kInt32)) {
FormatData<int32_t>(print_tensor, log_stream);
} else if (dtype == PRECISION(kInt64)) {
FormatData<int64_t>(print_tensor, log_stream);
} else if (dtype == PRECISION(kFloat)) {
FormatData<float>(print_tensor, log_stream);
} else {
log_stream << "\tdata: unprintable type: " << PrecisionToStr(dtype)
<< std::endl;
}
return log_stream.str();
}
void Print(const Tensor& print_tensor,
const std::string& tensor_name = "",
const std::string& message = "") {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
std::cout << Format(print_tensor, tensor_name, message);
}
void SetPrintTensorType(bool print_tensor_type) {
print_tensor_type_ = print_tensor_type;
}
void SetPrintTensorShape(bool print_tensor_shape) {
print_tensor_shape_ = print_tensor_shape;
}
void SetPrintTensorLod(bool print_tensor_lod) {
print_tensor_lod_ = print_tensor_lod;
}
void SetPrintTensorLayout(bool print_tensor_layout) {
print_tensor_layout_ = print_tensor_layout;
}
void SetSummarize(int64_t summarize) { summarize_ = summarize; }
private:
template <typename T>
void FormatData(const Tensor& print_tensor, std::stringstream& log_stream) {
int64_t print_size = summarize_ == -1
? print_tensor.numel()
: std::min(summarize_, print_tensor.numel());
const T* data = print_tensor.data<T>(); // Always kHost, so unnessary to
// copy the data from device
log_stream << " - data: [";
if (print_size > 0) {
log_stream << data[0];
for (int64_t i = 1; i < print_size; ++i) {
log_stream << " " << data[i];
}
}
log_stream << "]" << std::endl;
}
int64_t summarize_ = -1;
bool print_tensor_type_ = true;
bool print_tensor_shape_ = true;
bool print_tensor_lod_ = true;
bool print_tensor_layout_ = true;
};
void PrintCompute::Run() {
auto& param = Param<param_t>();
param.out->CopyDataFrom(*param.in);
if ((param.is_forward && param.print_phase == kBackward) ||
(!param.is_forward && param.print_phase == kForward)) {
return;
}
int first_n = param.first_n;
if (first_n > 0 && ++times_ > first_n) return;
TensorFormatter formatter;
const std::string& name = param.print_tensor_name ? param.name : "";
formatter.SetPrintTensorType(param.print_tensor_type);
formatter.SetPrintTensorShape(param.print_tensor_shape);
formatter.SetPrintTensorLod(param.print_tensor_lod);
formatter.SetPrintTensorLayout(param.print_tensor_layout);
formatter.SetSummarize(static_cast<int64_t>(param.summarize));
formatter.Print(*param.in, name, param.message);
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
print, kHost, kAny, kAny, paddle::lite::kernels::host::PrintCompute, def)
.BindInput("In",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class PrintCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
using param_t = operators::PrintParam;
void Run() override;
virtual ~PrintCompute() = default;
private:
mutable int times_{0};
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册