diff --git a/CMakeLists.txt b/CMakeLists.txt index 55375994031850d93caa89ec7050a9e8e657d04f..2d3a643ea0ec27510fcd3eda5b146ce784f26cfd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -98,6 +98,7 @@ lite_option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" OF lite_option(LITE_WITH_PROFILE "Enable profile mode in lite framework" OFF) lite_option(LITE_WITH_PRECISION_PROFILE "Enable precision profile in profile mode ON in lite" OFF) lite_option(LITE_WITH_LOG "Enable log printing or not." ON) +lite_option(LITE_WITH_EXCEPTION "Enable throwing the exception when error occurs in lite" OFF) lite_option(LITE_WITH_NVTX "Enable nvtx or not, please enable LITE_WITH_CUDA first." OFF) lite_option(LITE_ON_TINY_PUBLISH "Publish tiny predictor lib." OFF) lite_option(LITE_ON_MODEL_OPTIMIZE_TOOL "Build the model optimize tool" OFF) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 1b0890e0dbf5e741176c293a059d809752c72a43..0e77e26bc0d8994fe5fc36b3a1f3d99b7fffa7cf 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -190,6 +190,10 @@ if (LITE_WITH_LOG) add_definitions("-DLITE_WITH_LOG") endif() +if (LITE_WITH_EXCEPTION) + add_definitions("-DLITE_WITH_EXCEPTION") +endif() + if (LITE_ON_TINY_PUBLISH) add_definitions("-DLITE_ON_TINY_PUBLISH") endif() diff --git a/cmake/cross_compiling/android.cmake b/cmake/cross_compiling/android.cmake index e6193e0bb3c93292d2264501fc4d5739ff8766ee..b89eed64a85a74899190068602e20ba982225085 100644 --- a/cmake/cross_compiling/android.cmake +++ b/cmake/cross_compiling/android.cmake @@ -80,6 +80,17 @@ if (ARM_TARGET_LANG STREQUAL "clang") elseif(ARM_TARGET_ARCH_ABI STREQUAL "armv7") set(triple arm-v7a-linux-android) set(LITE_WITH_OPENMP OFF CACHE STRING "Due to libomp's bug(For ARM64, it has been fixed by https://reviews.llvm.org/D19879, but still exists on ARM32), disable OpenMP on armv7 when cross-compiling using Clang" FORCE) + if(ANDROID_STL_TYPE MATCHES "^c\\+\\+_") + # Use CMAKE_CXX_STANDARD_LIBRARIES_INIT to ensure libunwind and libc++ is linked in the right order + set(CMAKE_CXX_STANDARD_LIBRARIES_INIT "${CMAKE_CXX_STANDARD_LIBRARIES_INIT} ${ANDROID_NDK}/sources/cxx-stl/llvm-libc++/libs/${ANDROID_ARCH_ABI}/libunwind.a") + if(ANDROID_STL_TYPE STREQUAL "c++_shared") + set(CMAKE_CXX_STANDARD_LIBRARIES_INIT "${CMAKE_CXX_STANDARD_LIBRARIES_INIT} ${ANDROID_NDK}/sources/cxx-stl/llvm-libc++/libs/${ANDROID_ARCH_ABI}/libc++_shared.so") + elseif(ANDROID_STL_TYPE STREQUAL "c++_static") + set(CMAKE_CXX_STANDARD_LIBRARIES_INIT "${CMAKE_CXX_STANDARD_LIBRARIES_INIT} ${ANDROID_NDK}/sources/cxx-stl/llvm-libc++/libs/${ANDROID_ARCH_ABI}/libc++_static.a") + else() + message(FATAL_ERROR "Invalid Android STL TYPE: ${ANDROID_STL_TYPE}.") + endif() + endif() else() message(FATAL_ERROR "Clang do not support this ${ARM_TARGET_ARCH_ABI}, use armv8 or armv7") endif() diff --git a/cmake/cross_compiling/postproject.cmake b/cmake/cross_compiling/postproject.cmake index 069923c779fbd3eed4f5f81ef3e386ff70fac215..c9c3fc9f2681b6002567d555a26ee14edefaeae5 100644 --- a/cmake/cross_compiling/postproject.cmake +++ b/cmake/cross_compiling/postproject.cmake @@ -23,6 +23,21 @@ if(ANDROID) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -llog -fPIC") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -llog -fPIC") + + # Don't re-export libgcc symbols + set(REMOVE_ATOMIC_GCC_SYMBOLS "-Wl,--exclude-libs,libatomic.a -Wl,--exclude-libs,libgcc.a") + set(CMAKE_SHARED_LINKER_FLAGS "${REMOVE_ATOMIC_GCC_SYMBOLS} ${CMAKE_SHARED_LINKER_FLAGS}") + set(CMAKE_MODULE_LINKER_FLAGS "${REMOVE_ATOMIC_GCC_SYMBOLS} ${CMAKE_MODULE_LINKER_FLAGS}") + set(CMAKE_EXE_LINKER_FLAGS "${REMOVE_ATOMIC_GCC_SYMBOLS} ${CMAKE_EXE_LINKER_FLAGS}") + + # Only the libunwind.a from clang(with libc++) provide C++ exception handling support for 32-bit ARM + # Refer to https://android.googlesource.com/platform/ndk/+/master/docs/BuildSystemMaintainers.md#Unwinding + if (ARM_TARGET_LANG STREQUAL "clang" AND ARM_TARGET_ARCH_ABI STREQUAL "armv7" AND ANDROID_STL_TYPE MATCHES "^c\\+\\+_") + set(REMOVE_UNWIND_SYMBOLS "-Wl,--exclude-libs,libunwind.a") + set(CMAKE_SHARED_LINKER_FLAGS "${REMOVE_UNWIND_SYMBOLS} ${CMAKE_SHARED_LINKER_FLAGS}") + set(CMAKE_MODULE_LINKER_FLAGS "${REMOVE_UNWIND_SYMBOLS} ${CMAKE_MODULE_LINKER_FLAGS}") + set(CMAKE_EXE_LINKER_FLAGS "${REMOVE_UNWIND_SYMBOLS} ${CMAKE_EXE_LINKER_FLAGS}") + endif() endif() if(ARMLINUX) @@ -59,14 +74,13 @@ function(check_linker_flag) endfunction() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") +if((LITE_WITH_OPENCL AND (ARM_TARGET_LANG STREQUAL "clang")) OR LITE_WITH_PYTHON OR LITE_WITH_EXCEPTION OR (NOT LITE_ON_TINY_PUBLISH)) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions -fasynchronous-unwind-tables -funwind-tables") +else () + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions -fno-asynchronous-unwind-tables -fno-unwind-tables") +endif() if (LITE_ON_TINY_PUBLISH) - if((NOT LITE_WITH_PYTHON)) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions") - endif() - if(LITE_WITH_OPENCL AND (ARM_TARGET_LANG STREQUAL "clang")) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions") - endif() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ffast-math -Ofast -Os -fomit-frame-pointer -fno-asynchronous-unwind-tables -fno-unwind-tables") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ffast-math -Ofast -Os -fomit-frame-pointer") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden -fvisibility-inlines-hidden -ffunction-sections") check_linker_flag(-Wl,--gc-sections) endif() diff --git a/cmake/device/npu.cmake b/cmake/device/npu.cmake index 88598f4690a157b20ac1873d84ad13c2f8652725..0409b6a60fc651cbaade61998a09bc0489bc978c 100644 --- a/cmake/device/npu.cmake +++ b/cmake/device/npu.cmake @@ -54,6 +54,11 @@ find_library(NPU_DDK_IR_BUILD_FILE NAMES hiai_ir_build PATHS ${NPU_DDK_ROOT}/${NPU_SUB_LIB_PATH} NO_DEFAULT_PATH) +# Added in HiAI DDK 320 or later version +find_library(NPU_DDK_HCL_FILE NAMES hcl + PATHS ${NPU_DDK_ROOT}/${NPU_SUB_LIB_PATH} + NO_DEFAULT_PATH) + if(NOT NPU_DDK_HIAI_FILE) message(FATAL_ERROR "Can not find NPU_DDK_HIAI_FILE in ${NPU_DDK_ROOT}") else() @@ -78,5 +83,13 @@ else() set_property(TARGET npu_ddk_ir_build PROPERTY IMPORTED_LOCATION ${NPU_DDK_IR_BUILD_FILE}) endif() -set(npu_runtime_libs npu_ddk_hiai CACHE INTERNAL "npu ddk runtime libs") +if(NOT NPU_DDK_HCL_FILE) +# message(FATAL_ERROR "Can not find NPU_DDK_HCL_FILE in ${NPU_DDK_ROOT}") +else() + message(STATUS "Found NPU_DDK HCL Library: ${NPU_DDK_HCL_FILE}") + add_library(npu_ddk_hcl SHARED IMPORTED GLOBAL) + set_property(TARGET npu_ddk_hcl PROPERTY IMPORTED_LOCATION ${NPU_DDK_HCL_FILE}) +endif() + +set(npu_runtime_libs npu_ddk_hiai npu_ddk_hcl CACHE INTERNAL "npu ddk runtime libs") set(npu_builder_libs npu_ddk_ir npu_ddk_ir_build CACHE INTERNAL "npu ddk builder libs") diff --git a/lite/CMakeLists.txt b/lite/CMakeLists.txt index eeea3b3adf4caf2e3ea57eb365c32f24626851e6..828d189c69f21fb46ecd3b3850e9c5f973e81f2b 100644 --- a/lite/CMakeLists.txt +++ b/lite/CMakeLists.txt @@ -45,6 +45,7 @@ if (WITH_TESTING) lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4_simple.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "MobileNetV1_quant.tar.gz") + lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "transformer_with_mask_fp32.tar.gz") endif() if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "GoogleNet_inference.tar.gz") diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index dd2fd1ed23fa58e6f8de7b65294a6fc62a3bfcce..46ef543d246da4254976db82849820481a5792c7 100644 --- a/lite/api/cxx_api.cc +++ b/lite/api/cxx_api.cc @@ -37,8 +37,7 @@ void Predictor::SaveModel(const std::string &dir, if (!program_) { GenRuntimeProgram(); } - program_->SaveOpInfosToProgram(program_desc_.get()); - program_->UpdateVarsOfProgram(program_desc_.get()); + program_->SaveToProgram(program_desc_); switch (model_type) { case lite_api::LiteModelType::kProtobuf: SaveModelPb(dir, *program_->exec_scope(), *program_desc_.get(), true); @@ -58,17 +57,21 @@ void Predictor::SaveModel(const std::string &dir, void Predictor::SaveOpKernelInfo(const std::string &model_dir) { std::set ops_info; std::set kernels_info; - const auto &instructions_ = program_->instructions(); - for (auto &node : instructions_) { - // parse op type infomation - auto op = node.op()->op_info(); - ops_info.insert(op->Type()); - // parse kernel type information - std::string kernel_type_str = - node.kernel()->op_type() + "," + TargetRepr(node.kernel()->target()) + - "," + PrecisionRepr(node.kernel()->precision()) + "," + - DataLayoutRepr(node.kernel()->layout()) + "," + node.kernel()->alias(); - kernels_info.insert(kernel_type_str); + auto block_size = program_->block_size(); + for (size_t block_idx = 0; block_idx < block_size; ++block_idx) { + const auto &insts = program_->instructions(block_idx); + for (auto &inst : insts) { + // parse op type infomation + auto op = inst.op()->op_info(); + ops_info.insert(op->Type()); + // parse kernel type information + std::string kernel_type_str = + inst.kernel()->op_type() + "," + TargetRepr(inst.kernel()->target()) + + "," + PrecisionRepr(inst.kernel()->precision()) + "," + + DataLayoutRepr(inst.kernel()->layout()) + "," + + inst.kernel()->alias(); + kernels_info.insert(kernel_type_str); + } } // get souce_file name from op type and kernel type @@ -170,9 +173,9 @@ void Predictor::PrepareFeedFetch() { std::vector feeds; std::vector fetchs; - const auto &insts = program_->instructions(); - for (size_t i = 0; i < program_->num_instructions(); i++) { - const auto &op = insts[i].op()->op_info(); + const auto &insts = program_->instructions(kRootBlockIdx); + for (auto &inst : insts) { + const auto &op = inst.op()->op_info(); if (op->Type() == "feed") { feeds.push_back(op); } else if (op->Type() == "fetch") { @@ -255,7 +258,6 @@ void Predictor::Build(const lite_api::CxxConfig &config, } else { LOG(INFO) << "Load model from file."; } - Build(model_path, model_file, param_file, @@ -296,10 +298,10 @@ void Predictor::Build(const std::string &model_path, Build(program_desc_, valid_places, passes); } -void Predictor::Build(const std::shared_ptr &desc, +void Predictor::Build(const std::shared_ptr &program_desc, const std::vector &valid_places, const std::vector &passes) { - program_desc_ = desc; + program_desc_ = program_desc; // `inner_places` is used to optimize passes std::vector inner_places = valid_places; for (auto &valid_place : valid_places) { @@ -336,7 +338,7 @@ void Predictor::Build(const std::shared_ptr &desc, Place{TARGET(kARM), PRECISION(kInt8)}); } - Program program(*desc.get(), scope_, inner_places); + Program program(program_desc_, scope_, inner_places); valid_places_ = inner_places; core::KernelPickFactor factor; diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h index 004fbae071412faeee60d18ad73594956c097297..4520c52ad8c23ba25e5f0cd1be10b19c52d5d0c0 100644 --- a/lite/api/cxx_api.h +++ b/lite/api/cxx_api.h @@ -58,13 +58,12 @@ class LITE_API Predictor { // Create a predictor with the weight variable scope set. explicit Predictor(const std::shared_ptr& root_scope) : scope_(root_scope) {} - Predictor(const std::shared_ptr& desc, - const std::shared_ptr& root, + Predictor(const std::shared_ptr& program_desc, + const std::shared_ptr& root_scope, const std::vector& valid_places, - const std::vector& var_names = {}) - : program_desc_(desc), scope_(root) { - Program program(*desc.get(), scope_, valid_places, var_names); - // TODO(wilber): rethink a new way to associate config and passes. + const std::vector& vars_to_clone = {}) + : program_desc_(program_desc), scope_(root_scope) { + Program program(program_desc_, scope_, valid_places, vars_to_clone); optimizer_ = Optimizer(std::move(program), valid_places); exec_scope_ = optimizer_.exec_scope(); valid_places_ = valid_places; @@ -86,30 +85,28 @@ class LITE_API Predictor { lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf, bool memory_from_memory = false); - void Build(const std::shared_ptr& desc, + void Build(const std::shared_ptr& program_desc, const std::vector& valid_places, const std::vector& passes = {}); std::shared_ptr Clone() const { - auto predictor = - std::make_shared(program_desc_, scope_, valid_places_); - return predictor; + return std::make_shared(program_desc_, scope_, valid_places_); } std::shared_ptr Clone( - const std::vector& var_names) const { + const std::vector& vars_to_clone) const { CHECK(program_desc_) << "Both program and scope of current predicotr " "should be not be nullptr in Clone mode."; CHECK(scope_) << "Both program and scope of current predicotr should be " "not be nullptr in Clone mode."; auto predictor = std::make_shared( - program_desc_, scope_, valid_places_, var_names); + program_desc_, scope_, valid_places_, vars_to_clone); - for (auto i : var_names) { - predictor->exec_scope_->LocalVar(i); - auto* tensor = predictor->scope_->Var(i)->GetMutable(); + for (auto var_name : vars_to_clone) { + predictor->exec_scope_->LocalVar(var_name); + auto* tensor = predictor->scope_->Var(var_name)->GetMutable(); auto* sub_tensor = - predictor->exec_scope_->Var(i)->GetMutable(); + predictor->exec_scope_->Var(var_name)->GetMutable(); sub_tensor->CopyDataFrom(*tensor); } return predictor; @@ -147,6 +144,7 @@ class LITE_API Predictor { // get a const tensor according to its name const lite::Tensor* GetTensor(const std::string& name) const; const RuntimeProgram& runtime_program() const; + Scope* scope() { return scope_.get(); } // This method is disabled in mobile, for unnecessary dependencies required. void SaveModel( diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index cce790544481c1940bd85762aa8259fff4e85b73..394bc6c0b3d5710c41e2afda8fe64ce63a9699d8 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -75,8 +75,10 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { mode_ = config.power_mode(); threads_ = config.threads(); #ifdef LITE_WITH_NPU + // Store the model-level configuration into scope for kernels, and use + // exe_scope to store the execution-level configuration Context::SetSubgraphModelCacheDir( - config.subgraph_model_cache_dir()); + raw_predictor_->scope(), config.subgraph_model_cache_dir()); #endif #if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \ !(defined LITE_ON_MODEL_OPTIMIZE_TOOL) diff --git a/lite/api/light_api.cc b/lite/api/light_api.cc index f0d1fb96fe4dfd5f8fa57808a2098cbc42db6a11..9d092a6d385fabde1379b2abecb02daa3270b883 100644 --- a/lite/api/light_api.cc +++ b/lite/api/light_api.cc @@ -22,16 +22,16 @@ namespace lite { void LightPredictor::Build(const std::string& lite_model_file, bool model_from_memory) { if (model_from_memory) { - LoadModelNaiveFromMemory(lite_model_file, scope_.get(), &cpp_program_desc_); + LoadModelNaiveFromMemory( + lite_model_file, scope_.get(), program_desc_.get()); } else { - LoadModelNaiveFromFile(lite_model_file, scope_.get(), &cpp_program_desc_); + LoadModelNaiveFromFile(lite_model_file, scope_.get(), program_desc_.get()); } // For weight quantization of post training, load the int8/16 weights // for optimized model, and dequant it to fp32. DequantizeWeight(); - - BuildRuntimeProgram(cpp_program_desc_); + BuildRuntimeProgram(program_desc_); PrepareFeedFetch(); } @@ -43,15 +43,15 @@ void LightPredictor::Build(const std::string& model_dir, switch (model_type) { #ifndef LITE_ON_TINY_PUBLISH case lite_api::LiteModelType::kProtobuf: - LoadModelPb(model_dir, "", "", scope_.get(), &cpp_program_desc_); + LoadModelPb(model_dir, "", "", scope_.get(), program_desc_.get()); break; #endif case lite_api::LiteModelType::kNaiveBuffer: { if (model_from_memory) { LoadModelNaiveFromMemory( - model_buffer, param_buffer, scope_.get(), &cpp_program_desc_); + model_buffer, param_buffer, scope_.get(), program_desc_.get()); } else { - LoadModelNaive(model_dir, scope_.get(), &cpp_program_desc_); + LoadModelNaive(model_dir, scope_.get(), program_desc_.get()); } break; } @@ -60,7 +60,7 @@ void LightPredictor::Build(const std::string& model_dir, } DequantizeWeight(); - BuildRuntimeProgram(cpp_program_desc_); + BuildRuntimeProgram(program_desc_); PrepareFeedFetch(); } @@ -109,15 +109,17 @@ std::vector LightPredictor::GetOutputNames() { } // append the names of inputs and outputs into input_names_ and output_names_ void LightPredictor::PrepareFeedFetch() { - auto current_block = cpp_program_desc_.GetBlock(0); - std::vector feeds; - std::vector fetchs; - for (size_t i = 0; i < current_block->OpsSize(); i++) { - auto op = current_block->GetOp(i); - if (op->Type() == "feed") { - feeds.push_back(op); - } else if (op->Type() == "fetch") { - fetchs.push_back(op); + std::vector feeds; + std::vector fetchs; + std::shared_ptr program_desc = program_desc_; + auto main_block = program_desc->GetBlock(kRootBlockIdx); + auto op_size = main_block->OpsSize(); + for (size_t op_idx = 0; op_idx < op_size; ++op_idx) { + auto op_desc = main_block->GetOp(op_idx); + if (op_desc->Type() == "feed") { + feeds.push_back(op_desc); + } else if (op_desc->Type() == "fetch") { + fetchs.push_back(op_desc); } } input_names_.resize(feeds.size()); @@ -132,54 +134,35 @@ void LightPredictor::PrepareFeedFetch() { } } -void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) { - std::vector insts; - // 1. Create op first - Program program(prog, scope_, {}); - -// 2. Create Instructs -#ifdef LITE_WITH_OPENCL - using OpenCLContext = Context; - std::unique_ptr local_ctx(new KernelContext()); - local_ctx->As().InitOnce(); -#endif - - // Create the kernels of the target places, and filter out the specific - // kernel with the target alias. - for (auto& op : program.ops()) { - auto kernel_type = op->op_info()->GetAttr(kKernelTypeAttr); - std::string op_type, alias; - Place place; - KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place); - auto kernels = op->CreateKernels({place}); - // filter out a kernel - auto it = std::find_if( - kernels.begin(), kernels.end(), [&](std::unique_ptr& it) { - return it->alias() == alias; - }); - CHECK(it != kernels.end()); - -#ifdef LITE_WITH_OPENCL - if ((*it)->target() == TARGET(kOpenCL)) { - std::unique_ptr ctx(new KernelContext()); - (*local_ctx).As().CopySharedTo(&ctx->As()); - (*it)->SetContext(std::move(ctx)); - } else { - (*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target())); +void LightPredictor::BuildRuntimeProgram( + const std::shared_ptr& program_desc) { + auto* exe_scope = &scope_->NewScope(); + // Prepare workspace + scope_->Var("feed")->GetMutable>(); + scope_->Var("fetch")->GetMutable>(); + CHECK(program_desc); + auto block_size = program_desc->BlocksSize(); + CHECK(block_size); + for (size_t block_idx = 0; block_idx < block_size; ++block_idx) { + auto block_desc = program_desc->GetBlock(block_idx); + auto var_size = block_desc->VarsSize(); + for (size_t var_idx = 0; var_idx < var_size; ++var_idx) { + auto var_desc = block_desc->GetVar(var_idx); + if (!var_desc->Persistable()) { + exe_scope->Var(var_desc->Name()); + } else { + if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") continue; + scope_->Var(var_desc->Name()); + } } -#else - (*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target())); -#endif - - insts.emplace_back(op, std::move(*it)); } - program_.reset(new RuntimeProgram(std::move(insts))); - - CHECK(program.exec_scope()); - program_->set_exec_scope(program.exec_scope()); + // Only extracting the ops and generate the runtime program from the main + // block desc + program_.reset(new RuntimeProgram(program_desc, exe_scope, kRootBlockIdx)); } void LightPredictor::DequantizeWeight() { + std::shared_ptr program_desc = program_desc_; #define PROCESS_CONV2D_DATA() \ for (int64_t i = 0; i < ch; ++i) { \ for (int64_t j = 0; j < offset; ++j) { \ @@ -205,10 +188,9 @@ void LightPredictor::DequantizeWeight() { } return result; }; - Tensor tmp_tensor; - for (size_t i = 0; i < cpp_program_desc_.BlocksSize(); i++) { - auto* block = cpp_program_desc_.GetBlock(i); + for (size_t i = 0; i < program_desc->BlocksSize(); i++) { + auto* block = program_desc->GetBlock(i); for (size_t k = 0; k < block->OpsSize(); ++k) { auto* op_desc = block->GetOp(k); if (is_weight_quantized_op(op_desc)) { diff --git a/lite/api/light_api.h b/lite/api/light_api.h index e651d1323a5ce6e36546e9437d06a472eb8a5137..97a46b7d28ffc84feb87283eed9786b562a45229 100644 --- a/lite/api/light_api.h +++ b/lite/api/light_api.h @@ -46,6 +46,7 @@ class LITE_API LightPredictor { LightPredictor(const std::string& lite_model_file, bool model_from_memory = false) { scope_ = std::make_shared(); + program_desc_ = std::make_shared(); Build(lite_model_file, model_from_memory); } @@ -57,6 +58,7 @@ class LITE_API LightPredictor { lite_api::LiteModelType model_type = lite_api::LiteModelType::kNaiveBuffer) { scope_ = std::make_shared(); + program_desc_ = std::make_shared(); Build(model_dir, model_buffer, param_buffer, model_type, model_from_memory); } @@ -78,6 +80,7 @@ class LITE_API LightPredictor { std::vector GetInputNames(); std::vector GetOutputNames(); void PrepareFeedFetch(); + Scope* scope() { return scope_.get(); } private: void Build(const std::string& lite_model_file, @@ -91,14 +94,15 @@ class LITE_API LightPredictor { lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf, bool model_from_memory = false); - void BuildRuntimeProgram(const cpp::ProgramDesc& prog); + void BuildRuntimeProgram( + const std::shared_ptr& program_desc); void DequantizeWeight(); private: std::shared_ptr scope_; std::unique_ptr program_; - cpp::ProgramDesc cpp_program_desc_; + std::shared_ptr program_desc_; std::vector input_names_; std::vector output_names_; }; diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc index 718ba020fb9c6daa4dc4d7263238692267335a48..206222be2818e00777eacdee055d2685f240c1c5 100644 --- a/lite/api/light_api_impl.cc +++ b/lite/api/light_api_impl.cc @@ -38,8 +38,10 @@ void LightPredictorImpl::Init(const lite_api::MobileConfig& config) { threads_ = config.threads(); #ifdef LITE_WITH_NPU + // Store the model-level configuration into scope for kernels, and use + // exe_scope to store the execution-level configuration Context::SetSubgraphModelCacheDir( - config.subgraph_model_cache_dir()); + raw_predictor_->scope(), config.subgraph_model_cache_dir()); #endif } diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h index 2ec4965d3d526c82c41b51954f9564488c5126e1..0218401cb3e34fe8d990aa06fd3b7fcedde4fc99 100644 --- a/lite/api/paddle_use_passes.h +++ b/lite/api/paddle_use_passes.h @@ -28,6 +28,7 @@ USE_MIR_PASS(graph_visualize_pass); USE_MIR_PASS(remove_tf_redundant_ops_pass); USE_MIR_PASS(lite_conv_bn_fuse_pass); +USE_MIR_PASS(lite_conv_conv_fuse_pass); USE_MIR_PASS(lite_fc_fuse_pass); USE_MIR_PASS(lite_shuffle_channel_fuse_pass); USE_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass); @@ -53,6 +54,7 @@ USE_MIR_PASS(mlu_postprocess_pass); USE_MIR_PASS(weight_quantization_preprocess_pass); USE_MIR_PASS(apu_subgraph_pass); USE_MIR_PASS(quantized_op_attributes_inference_pass); +USE_MIR_PASS(control_flow_op_unused_inputs_and_outputs_eliminate_pass) USE_MIR_PASS(lite_scale_activation_fuse_pass); USE_MIR_PASS(__xpu__resnet_fuse_pass); USE_MIR_PASS(__xpu__resnet_cbam_fuse_pass); diff --git a/lite/backends/arm/math/beam_search.cc b/lite/backends/arm/math/beam_search.cc index 32b7d3bfeba6107493d62a0c9be14a3c15ce7692..74dfa143bda97219874b0e53efc7de34b0416c0e 100644 --- a/lite/backends/arm/math/beam_search.cc +++ b/lite/backends/arm/math/beam_search.cc @@ -234,7 +234,7 @@ void beam_search(const Tensor *pre_ids, selected_ids->Resize(dims); selected_scores->Resize(dims); if (parent_idx) { - parent_idx->Resize(dims); + parent_idx->Resize({static_cast(num_instances)}); } auto *selected_ids_data = selected_ids->mutable_data(); auto *selected_scores_data = selected_scores->mutable_data(); diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h index 42a98bc9442b2a619cf5882783bb63f5c4ea7db4..c72223d2e845bc67b541e6f1790e45129deff62f 100644 --- a/lite/backends/arm/math/conv_block_utils.h +++ b/lite/backends/arm/math/conv_block_utils.h @@ -139,6 +139,71 @@ static bool conv_trans_weights_numc(const dtype* din, } return true; } +// for example: m = 4, n = 4 +// din = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9 , 10 ,11], [12, 13, 14, 15]] +// dout = [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]] +/* + m = 8 n = 8: 0 1 2 3 4 5 6 7 0 8 16 24 32 40 48 56 + 16 17 18 19 20 21 22 23 2 10 18 26 34 42 50 58 + 24 25 26 27 28 29 30 31 3 11 19 27 35 43 51 59 + 32 33 34 35 36 37 38 39 4 12 20 28 36 44 52 60 ... + } + } +*/ +template +void local_transpose(const Dtype* din, Dtype* dout, int m, int n) { + // n % 4 == 0 && m % 4 == 0 + // n * m ==> n * m data trans + int offset_m = m << 2; + const Dtype* din_ptr = din; + Dtype* dout_ptr = dout; + for (int i = 0; i < n; i += 4) { + Dtype* out_ptr0 = dout_ptr; + Dtype* out_ptr1 = dout_ptr + m; + Dtype* out_ptr2 = out_ptr1 + m; + Dtype* out_ptr3 = out_ptr2 + m; + const Dtype* in_ptr0 = din_ptr; + const Dtype* in_ptr1 = din_ptr + m; + const Dtype* in_ptr2 = in_ptr1 + m; + const Dtype* in_ptr3 = in_ptr2 + m; + for (int j = 0; j < m; j += 4) { + float32x4_t vin0 = vld1q_f32(in_ptr0); + float32x4_t vin1 = vld1q_f32(in_ptr1); + float32x4_t vin2 = vld1q_f32(in_ptr2); + float32x4_t vin3 = vld1q_f32(in_ptr3); + // a00 b00 a02 b02 a01 b01 a03 b03 + float32x4x2_t tmp0 = vtrnq_f32(vin0, vin1); + // c00 d00 c02 d02 c01 d01 c03 d03 + float32x4x2_t tmp2 = vtrnq_f32(vin2, vin3); + in_ptr0 = in_ptr3 + m; + in_ptr1 = in_ptr3 + 2 * m; + float tmp_val1 = tmp0.val[0][2]; + float tmp_val2 = tmp0.val[0][3]; + tmp0.val[0][2] = tmp2.val[0][0]; + tmp0.val[0][3] = tmp2.val[0][1]; + float tmp_val3 = tmp0.val[1][2]; + float tmp_val4 = tmp0.val[1][3]; + tmp2.val[0][0] = tmp_val1; + tmp2.val[0][1] = tmp_val2; + tmp0.val[1][2] = tmp2.val[1][0]; + tmp0.val[1][3] = tmp2.val[1][1]; + tmp2.val[1][0] = tmp_val3; + tmp2.val[1][1] = tmp_val4; + in_ptr2 = in_ptr1 + m; + in_ptr3 = in_ptr1 + 2 * m; + vst1q_f32(out_ptr0, tmp0.val[0]); + vst1q_f32(out_ptr1, tmp0.val[1]); + out_ptr0 += 4; + out_ptr1 += 4; + vst1q_f32(out_ptr2, tmp2.val[0]); + vst1q_f32(out_ptr3, tmp2.val[1]); + out_ptr2 += 4; + out_ptr3 += 4; + } + dout_ptr += offset_m; + din_ptr += 4; + } +} template void transpose(const Dtype* din, Dtype* dout, int m, int n) { // nxm == mxn diff --git a/lite/backends/arm/math/elementwise.cc b/lite/backends/arm/math/elementwise.cc index 04373992e4802a0b0c2529daac851e00ebcb56cf..a73a63ddcb67f8790f73aff3fff8368f4005b7e1 100644 --- a/lite/backends/arm/math/elementwise.cc +++ b/lite/backends/arm/math/elementwise.cc @@ -747,6 +747,16 @@ void elementwise_mul(const int* dinx, } } +template <> +void elementwise_mul(const int64_t* dinx, + const int64_t* diny, + int64_t* dout, + int num) { + for (int i = 0; i < num; i++) { + dout[i] = dinx[i] * diny[i]; + } +} + template <> void elementwise_mul_relu(const float* dinx, const float* diny, @@ -801,6 +811,17 @@ void elementwise_mul_relu(const float* dinx, } } +template <> +void elementwise_mul_relu(const int64_t* dinx, + const int64_t* diny, + int64_t* dout, + int num) { + for (int i = 0; i < num; i++) { + int64_t tmp = dinx[i] * diny[i]; + dout[i] = tmp > 0 ? tmp : 0; + } +} + template <> void elementwise_mul_broadcast(const float* dinx, const float* diny, @@ -935,6 +956,29 @@ void elementwise_mul_broadcast(const int* dinx, } } +template <> +void elementwise_mul_broadcast(const int64_t* dinx, + const int64_t* diny, + int64_t* dout, + int batch, + int channels, + int num) { +#pragma omp parallel for collapse(2) + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const int64_t* dinx_ptr = dinx + offset; + const int64_t diny_data = diny[j]; + int64_t* dout_ptr = dout + offset; + for (int k = 0; k < num; ++k) { + *dout_ptr = *dinx_ptr * diny_data; + dout_ptr++; + dinx_ptr++; + } + } + } +} + template <> void elementwise_mul_relu_broadcast(const float* dinx, const float* diny, @@ -1014,6 +1058,30 @@ void elementwise_mul_relu_broadcast(const float* dinx, } } +template <> +void elementwise_mul_relu_broadcast(const int64_t* dinx, + const int64_t* diny, + int64_t* dout, + int batch, + int channels, + int num) { +#pragma omp parallel for collapse(2) + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const int64_t* dinx_ptr = dinx + offset; + const int64_t diny_data = diny[j]; + int64_t* dout_ptr = dout + offset; + for (int k = 0; k < num; ++k) { + int64_t tmp = *dinx_ptr * diny_data; + *dout_ptr = tmp > 0 ? tmp : 0; + dout_ptr++; + dinx_ptr++; + } + } + } +} + template <> void elementwise_max(const float* dinx, const float* diny, diff --git a/lite/backends/arm/math/prior_box.cc b/lite/backends/arm/math/prior_box.cc index 6daab69ebf00da24d67132afba4b9abef0afbd39..4ef7356e67cee4c47ddf3eb16ed5286b4271b41a 100644 --- a/lite/backends/arm/math/prior_box.cc +++ b/lite/backends/arm/math/prior_box.cc @@ -21,7 +21,7 @@ namespace lite { namespace arm { namespace math { -const int MALLOC_ALIGN = 64; +const int MALLOC_ALIGN = 16; void* fast_malloc(size_t size) { size_t offset = sizeof(void*) + MALLOC_ALIGN - 1; diff --git a/lite/backends/host/target_wrapper.cc b/lite/backends/host/target_wrapper.cc index 5f020662a9d74aab6c28f79221d670e5de5ae048..00ce9dd6b349decc2f603692c2a6a0801bd4d7c0 100644 --- a/lite/backends/host/target_wrapper.cc +++ b/lite/backends/host/target_wrapper.cc @@ -19,7 +19,7 @@ namespace paddle { namespace lite { -const int MALLOC_ALIGN = 64; +const int MALLOC_ALIGN = 16; void* TargetWrapper::Malloc(size_t size) { size_t offset = sizeof(void*) + MALLOC_ALIGN - 1; @@ -30,7 +30,6 @@ void* TargetWrapper::Malloc(size_t size) { void* r = reinterpret_cast(reinterpret_cast(p + offset) & (~(MALLOC_ALIGN - 1))); static_cast(r)[-1] = p; - memset(r, 0, size); return r; } void TargetWrapper::Free(void* ptr) { diff --git a/lite/backends/npu/device.cc b/lite/backends/npu/device.cc index 22f760e39f86b29ccf025a83b2a43c87882f9e02..2b2d5321ba6dbac7ff002039c3c8a0423cbe0a6e 100644 --- a/lite/backends/npu/device.cc +++ b/lite/backends/npu/device.cc @@ -33,7 +33,7 @@ std::shared_ptr Device::Load( // Check HiAI DDK version const char* ddk_version = model_client->GetVersion(); if (ddk_version) { - LOG(INFO) << "[NPU] HiAI DDK version: " << ddk_version; + VLOG(3) << "[NPU] HiAI DDK version: " << ddk_version; } else { LOG(WARNING) << "[NPU] Unable to get HiAI DDK version!"; } diff --git a/lite/backends/xpu/debug.h b/lite/backends/xpu/debug.h index 75d18b6f4bf461a871c26c7665d8b48bc2f3db38..56bafc9c3d3a7772af8fc8afd10fc7efa3415ef7 100644 --- a/lite/backends/xpu/debug.h +++ b/lite/backends/xpu/debug.h @@ -19,7 +19,7 @@ #include #include #include -#include "lite/backends/xpu/xpu_header_sitter.h" +#include "lite/backends/xpu/target_wrapper.h" namespace paddle { namespace lite { @@ -82,8 +82,8 @@ void DumpXPUMem(const T* ptr, size_t item_per_line = 30) { size_t after_stride_len = (len + stride - 1) / stride; std::unique_ptr cpu_mem(new T[len]); - xpu_memcpy( - cpu_mem.get(), ptr, len * sizeof(T), XPUMemcpyKind::XPU_DEVICE_TO_HOST); + XPU_CALL(xpu_memcpy( + cpu_mem.get(), ptr, len * sizeof(T), XPUMemcpyKind::XPU_DEVICE_TO_HOST)); std::unique_ptr after_stride(new T[after_stride_len]); for (size_t i = 0; i < after_stride_len; ++i) { after_stride[i] = cpu_mem[i * stride]; diff --git a/lite/backends/xpu/target_wrapper.cc b/lite/backends/xpu/target_wrapper.cc index 85a0023590858ab72e9e4f258d62dce809888918..a322418ccde20a34dc6c6ba9b47601a9a658f99c 100644 --- a/lite/backends/xpu/target_wrapper.cc +++ b/lite/backends/xpu/target_wrapper.cc @@ -19,11 +19,11 @@ namespace lite { void* TargetWrapperXPU::Malloc(size_t size) { void* ptr{nullptr}; - xpu_malloc(&ptr, size); + XPU_CALL(xpu_malloc(&ptr, size)); return ptr; } -void TargetWrapperXPU::Free(void* ptr) { xpu_free(ptr); } +void TargetWrapperXPU::Free(void* ptr) { XPU_CALL(xpu_free(ptr)); } void TargetWrapperXPU::MemcpySync(void* dst, const void* src, @@ -31,10 +31,10 @@ void TargetWrapperXPU::MemcpySync(void* dst, IoDirection dir) { switch (dir) { case IoDirection::HtoD: - xpu_memcpy(dst, src, size, XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy(dst, src, size, XPU_HOST_TO_DEVICE)); break; case IoDirection::DtoH: - xpu_memcpy(dst, src, size, XPU_DEVICE_TO_HOST); + XPU_CALL(xpu_memcpy(dst, src, size, XPU_DEVICE_TO_HOST)); break; default: LOG(FATAL) << "Unsupported IoDirection " << static_cast(dir); @@ -49,7 +49,7 @@ XPUScratchPadGuard TargetWrapperXPU::MallocScratchPad(size_t size, } else { ptr = TargetWrapperXPU::Malloc(size); } - CHECK(ptr != nullptr); + CHECK(ptr != nullptr) << "size = " << size << ", use_l3 = " << use_l3; return XPUScratchPadGuard(new XPUScratchPad(ptr, use_l3)); } diff --git a/lite/backends/xpu/target_wrapper.h b/lite/backends/xpu/target_wrapper.h index b84b5d75e74a14e81091b003aa3ae5514e53a42c..070184a13088a169fe38f1b8105a0803d9915da1 100644 --- a/lite/backends/xpu/target_wrapper.h +++ b/lite/backends/xpu/target_wrapper.h @@ -16,11 +16,23 @@ #include // std::unique_ptr #include "lite/backends/xpu/xpu_header_sitter.h" // xpu_free -#include "lite/core/target_wrapper.h" +#include "lite/core/target_wrapper.h" // TargetWrapper +#include "lite/utils/cp_logging.h" // CHECK_EQ + +#define XPU_CALL(func) \ + { \ + auto e = (func); \ + CHECK_EQ(e, 0) << "XPU: (" << #func << ") returns " << e; \ + } namespace paddle { namespace lite { +// MAX(lod.size()) = 64 +const int XPU_MAX_LOD_SIZE = 64; +// MAX(lod[i + 1] - lod[i]) = 512 +const int XPU_MAX_LOD_SEQ_LEN = 512; + using TargetWrapperXPU = TargetWrapper; struct XPUScratchPad { @@ -33,7 +45,7 @@ struct XPUScratchPad { struct XPUScratchPadDeleter { void operator()(XPUScratchPad* sp) const { if (!sp->is_l3_) { - xpu_free(sp->addr_); + XPU_CALL(xpu_free(sp->addr_)); } delete sp; } @@ -55,7 +67,7 @@ class TargetWrapper { size_t size, IoDirection dir); - static XPUScratchPadGuard MallocScratchPad(size_t size, bool use_l3 = true); + static XPUScratchPadGuard MallocScratchPad(size_t size, bool use_l3 = false); static xdnn::Context* GetRawContext() { if (tls_raw_ctx_ == nullptr) { @@ -77,11 +89,10 @@ class TargetWrapper { static void SetDev(int dev_no = 0) { const char* dev_env = getenv("LITE_XPU_DEV"); if (dev_env) { - xpu_set_device(atoi(dev_env)); - return; + dev_no = atoi(dev_env); } - xpu_set_device(dev_no); + XPU_CALL(xpu_set_device(dev_no)); } static std::string multi_encoder_precision; // NOLINT diff --git a/lite/core/arena/framework.cc b/lite/core/arena/framework.cc index 1138a3bcc2e3e3f3c77d94bf8128b8231f930550..9dfb585a0ce3d2d19019cbcd5d3bc470a8d5291e 100644 --- a/lite/core/arena/framework.cc +++ b/lite/core/arena/framework.cc @@ -32,25 +32,27 @@ void TestCase::CreateInstruction() { #endif if (enable_subgraph_op) { // Create a new block desc to wrap the original op desc + auto sub_program_desc = std::make_shared(); int sub_block_idx = 0; - auto sub_block_desc = new cpp::BlockDesc(); + auto sub_block_desc = sub_program_desc->AddBlock(); sub_block_desc->ClearOps(); sub_block_desc->ClearVars(); - auto sub_block_op_desc = sub_block_desc->AddOp(); - *sub_block_op_desc = *op_desc_; + auto sub_op_desc = sub_block_desc->AddOp(); + *sub_op_desc = *op_desc_; // Add the block desc into the subgraph op which used to replace the // original op op_desc_.reset(new cpp::OpDesc()); op_desc_->SetType("subgraph"); op_desc_->SetAttr("sub_block", sub_block_idx); - auto in_names = sub_block_op_desc->input_vars(); - auto out_names = sub_block_op_desc->output_vars(); + auto in_names = sub_op_desc->input_vars(); + auto out_names = sub_op_desc->output_vars(); op_desc_->SetInput("Inputs", in_names); op_desc_->SetOutput("Outputs", out_names); op_desc_->SetAttr>("input_data_names", in_names); op_desc_->SetAttr>("output_data_names", out_names); op = LiteOpRegistry::Global().Create(op_desc().Type()); - static_cast(op.get())->SetSubBlock(sub_block_desc); + static_cast(op.get())->SetProgramDesc( + sub_program_desc); } else { op = LiteOpRegistry::Global().Create(op_desc().Type()); } @@ -60,7 +62,7 @@ void TestCase::CreateInstruction() { // filter out the target kernel CHECK(!kernels.empty()) << "No kernel found for place " << place_.DebugString(); - auto it = std::remove_if( + auto it = std::find_if( kernels.begin(), kernels.end(), [&](std::unique_ptr& k) { return k->alias() == alias_; }); @@ -234,19 +236,6 @@ bool TestCase::CheckPrecision(const std::string& var_name, return success; } -TestCase::~TestCase() { - if (op_desc_->Type() == "subgraph") { - // Release the subblock desc of Subgraph op - auto subgraph_op = const_cast( - static_cast(instruction_->op())); - CHECK(subgraph_op); - auto sub_block_desc = subgraph_op->GetSubBlock(); - if (sub_block_desc) { - delete sub_block_desc; - } - } -} - } // namespace arena } // namespace lite } // namespace paddle diff --git a/lite/core/arena/framework.h b/lite/core/arena/framework.h index 4e73768e53576f03e47158618fa4f0eac0851382..4ccb05428d38c65f8cad36f1702c034cfe62705b 100644 --- a/lite/core/arena/framework.h +++ b/lite/core/arena/framework.h @@ -46,7 +46,7 @@ class TestCase { base_scope_(new Scope) { ctx_ = ContextScheduler::Global().NewContext(place_.target); } - virtual ~TestCase(); + virtual ~TestCase() {} void Prepare() { PrepareData(); diff --git a/lite/core/context.cc b/lite/core/context.cc index f14d1dfddea806ab3839f6f897b9d4d3fe396ca8..bda8a51f2b286cea0735898884c9dccd516d0055 100644 --- a/lite/core/context.cc +++ b/lite/core/context.cc @@ -17,10 +17,6 @@ namespace paddle { namespace lite { -#ifdef LITE_WITH_NPU -std::string Context::subgraph_model_cache_dir_{""}; // NOLINT -#endif - #ifdef LITE_WITH_MLU int Context::next_queue_id_{0}; std::map Context::queue_id_map_; diff --git a/lite/core/context.h b/lite/core/context.h index c3993d9589eeac442eaa827152fd1293852396db..a4c338ab0094f88b52f8df1b94dcd4ee272e55a6 100644 --- a/lite/core/context.h +++ b/lite/core/context.h @@ -39,6 +39,7 @@ #include #include #include "lite/core/device_info.h" +#include "lite/core/scope.h" #include "lite/core/target_wrapper.h" #include "lite/core/tensor.h" #include "lite/utils/all.h" @@ -84,15 +85,19 @@ class Context { NPUContext& operator=(const NPUContext& ctx) {} std::string name() const { return "NPUContext"; } - static void SetSubgraphModelCacheDir(std::string subgraph_model_cache_dir) { - subgraph_model_cache_dir_ = subgraph_model_cache_dir; + static void SetSubgraphModelCacheDir(Scope* scope, + std::string subgraph_model_cache_dir) { + auto var = scope->Var("SUBGRAPH_MODEL_CACHE_DIR"); + CHECK(var); + auto data = var->GetMutable(); + CHECK(data); + *data = subgraph_model_cache_dir; } - static std::string SubgraphModelCacheDir() { - return subgraph_model_cache_dir_; + static std::string SubgraphModelCacheDir(Scope* scope) { + auto var = scope->FindVar("SUBGRAPH_MODEL_CACHE_DIR"); + if (!var) return ""; + return var->Get(); } - - private: - static std::string subgraph_model_cache_dir_; }; #endif diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt index be09ed4b1a63154b8561f4d39cff7d987a9fcba7..cd129b332fa79dc45d74dc8a0befc1e67a68c316 100644 --- a/lite/core/mir/CMakeLists.txt +++ b/lite/core/mir/CMakeLists.txt @@ -18,6 +18,7 @@ lite_cc_library(mir_passes fusion/conv_activation_fuse_pass.cc fusion/var_conv_2d_activation_fuse_pass.cc fusion/conv_bn_fuse_pass.cc + fusion/conv_conv_fuse_pass.cc fusion/elementwise_add_activation_fuse_pass.cc fusion/quant_dequant_fuse_pass.cc fusion/sequence_pool_concat_fuse_pass.cc @@ -32,6 +33,7 @@ lite_cc_library(mir_passes elimination/identity_dropout_eliminate_pass.cc elimination/elementwise_mul_constant_eliminate_pass.cc elimination/remove_tf_redundant_ops_pass.cc + elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.cc static_kernel_pick_pass.cc variable_place_inference_pass.cc type_target_cast_pass.cc diff --git a/lite/core/mir/elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.cc b/lite/core/mir/elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..7866cb956c4e51d3b69687751325ca3ff4eda9d6 --- /dev/null +++ b/lite/core/mir/elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.cc @@ -0,0 +1,244 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +// Remove all of the unused nodes from the contorl flow op and update the inputs +// and outputs of the op info The unused nodes are defined as the nodes which +// are only linked to the control flow op nodes but nerver linked to the other +// op nodes. +// +// For example: +// graph[0]: main block +// in_x +// in_f | in_z(unused node) +// \ | / +// \ | / +// in_w ------- while ------- in_y(unused_node) +// / | +// / | +// (unused node)out_y | +// out_x +// +// graph[1]: sub block +// in_x +// | +// | +// conv2d----in_f +// | +// | +// fc ------in_w +// | +// | +// softmax +// | +// | +// out_x +// +// After the pass is applied: +// in_x +// in_f | +// \ | +// \ | +// in_w ------- while +// | +// | +// | +// out_x + +// Remove the var node from var2rm if it is recursively referred to any op in +// the subblock +void CollectUnusedInputOutputNodes( + int block_idx, + std::vector>* graphs, + const std::unordered_set& control_flow_op_types, + std::unordered_map* in_vars2rm, + std::unordered_map* out_vars2rm) { + auto block_size = graphs->size(); + for (auto& op_node : (*graphs)[block_idx]->StmtTopologicalOrder()) { + if (!op_node->IsStmt()) continue; + auto op_info = op_node->AsStmt().op_info(); + auto op_type = op_info->Type(); + if (control_flow_op_types.count(op_type)) { + int sub_block_idx = op_info->GetAttr("sub_block"); + CHECK(block_idx >= 0 && block_idx < block_size); + CollectUnusedInputOutputNodes(sub_block_idx, + graphs, + control_flow_op_types, + in_vars2rm, + out_vars2rm); + } else { + for (auto& var_node : op_node->inlinks) { + auto& var_name = var_node->AsArg().name; + if (in_vars2rm->count(var_name)) { + in_vars2rm->erase(var_name); + } + } + for (auto& var_node : op_node->outlinks) { + auto& var_name = var_node->AsArg().name; + // Tensor array may be only used as the output vars in the sublock + if (in_vars2rm->count(var_name)) { + in_vars2rm->erase(var_name); + } + if (out_vars2rm->count(var_name)) { + out_vars2rm->erase(var_name); + } + } + } + } +} + +// Remove the unused var nodes from the graph and update the op_info of the +// control flow op +void RemoveNodesFromGraphAndUpdateOpInfo( + SSAGraph* graph, + Node* op_node, + const std::unordered_map& in_vars2rm, + const std::unordered_map& out_vars2rm) { + auto op_info = op_node->AsStmt().mutable_op_info(); + auto op_type = op_info->Type(); + // Unlink the in_vars2rm and out_vars2rm from the control flow op node, and + // remove them if nerver used. + for (auto& var_node : in_vars2rm) { + VLOG(3) << "in var node '" << var_node.first << "' is unlinked to " + << op_type; + RemoveDirectedLink(var_node.second, op_node); + } + for (auto& var_node : out_vars2rm) { + VLOG(3) << "out var node '" << var_node.first << "' is unlinked from " + << op_type; + RemoveDirectedLink(op_node, var_node.second); + // Unlink from all of the out op nodes. + std::unordered_set out_op_nodes; + for (auto* out_op_node : var_node.second->outlinks) { + if (!out_op_nodes.count(out_op_node)) { + out_op_nodes.insert(out_op_node); + } + } + for (auto* out_op_node : out_op_nodes) { + RemoveDirectedLink(var_node.second, out_op_node); + } + } + // Remove the unused nodes from the graph if their inlinks and outlinks are + // empty + std::unordered_set removed_var_nodes; + for (auto& var_node : in_vars2rm) { + if (var_node.second->inlinks.empty() && var_node.second->outlinks.empty() && + !removed_var_nodes.count(var_node.second)) { + removed_var_nodes.insert(var_node.second); + graph->RemoveNode(var_node.second); + VLOG(3) << "in var node " << var_node.first << " is removed"; + } + } + for (auto& var_node : out_vars2rm) { + if (var_node.second->inlinks.empty() && var_node.second->outlinks.empty() && + !removed_var_nodes.count(var_node.second)) { + removed_var_nodes.insert(var_node.second); + graph->RemoveNode(var_node.second); + VLOG(3) << "out var node " << var_node.first << " is removed"; + } + } + // Update the op info of the control flow op + for (auto& input : *op_info->mutable_inputs()) { + for (auto var = input.second.begin(); var != input.second.end();) { + if (in_vars2rm.count(*var)) { + var = input.second.erase(var); + } else { + ++var; + } + } + } + for (auto& output : *op_info->mutable_outputs()) { + for (auto var = output.second.begin(); var != output.second.end();) { + if (out_vars2rm.count(*var)) { + var = output.second.erase(var); + } else { + ++var; + } + } + } +} + +void ControlFlowOpUnusedInputsAndOutputsEliminatePass::SetAllGraphs( + std::vector>* graphs) { + CHECK(graphs && !graphs->empty()); + graphs_ = graphs; +} + +void ControlFlowOpUnusedInputsAndOutputsEliminatePass::Apply( + const std::unique_ptr& graph) { + // Remove the unused input and output nodes from the control flow op nodes + // Which are only linked to the control flow op nodes but nerver linked to the + // other op nodes + const std::unordered_set control_flow_op_types = { + "while", "conditional_block"}; + auto block_size = graphs_->size(); + for (auto& op_node : graph->StmtTopologicalOrder()) { + if (!op_node->IsStmt()) continue; + auto op_info = op_node->AsStmt().mutable_op_info(); + auto op_type = op_info->Type(); + if (!control_flow_op_types.count(op_type)) continue; + int sub_block_idx = op_info->GetAttr("sub_block"); + CHECK(sub_block_idx >= 0 && sub_block_idx < block_size); + // Initialize the unused nodes with all of the input and output nodes + std::unordered_map in_vars2rm, out_vars2rm; + for (auto* var_node : op_node->inlinks) { + auto& var_name = var_node->AsArg().name; + if (!in_vars2rm.count(var_name)) { + in_vars2rm.insert(std::pair(var_name, var_node)); + } + } + for (auto* var_node : op_node->outlinks) { + auto& var_name = var_node->AsArg().name; + if (!out_vars2rm.count(var_name)) { + out_vars2rm.insert(std::pair(var_name, var_node)); + } + } + // Remove the nodes which used in subblock recursively, and the remaining + // nodes are the unused one. + CollectUnusedInputOutputNodes(sub_block_idx, + graphs_, + control_flow_op_types, + &in_vars2rm, + &out_vars2rm); + if (in_vars2rm.size() > 0 || out_vars2rm.size() > 0) { + // Remove the unused nodes from graph, and update the op info of the + // control flow op + RemoveNodesFromGraphAndUpdateOpInfo( + graph.get(), op_node, in_vars2rm, out_vars2rm); + } + } +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS( + control_flow_op_unused_inputs_and_outputs_eliminate_pass, + paddle::lite::mir::ControlFlowOpUnusedInputsAndOutputsEliminatePass) + .BindTargets({TARGET(kNPU)}); diff --git a/lite/core/mir/elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.h b/lite/core/mir/elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..2863661de1e93d15bfe835e39033d4ecaee6d8cc --- /dev/null +++ b/lite/core/mir/elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.h @@ -0,0 +1,40 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "lite/core/mir/pass.h" +#include "lite/core/types.h" + +namespace paddle { +namespace lite { +namespace mir { + +class ControlFlowOpUnusedInputsAndOutputsEliminatePass : public mir::StmtPass { + public: + void Apply(const std::unique_ptr &graph) override; + void SetAllGraphs(std::vector> *graphs); + + private: + std::vector> *graphs_; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/CMakeLists.txt b/lite/core/mir/fusion/CMakeLists.txt index a7a4cee798c1e8ef5b9b8f8d9e8e5810554fc571..95723bbd21dc02ed8bb5b46c48f9836d3f9aff1f 100644 --- a/lite/core/mir/fusion/CMakeLists.txt +++ b/lite/core/mir/fusion/CMakeLists.txt @@ -16,6 +16,9 @@ lite_cc_library(fuse_var_conv_activation lite_cc_library(fuse_conv_bn SRCS conv_bn_fuser.cc DEPS pattern_matcher_high_api) +lite_cc_library(fuse_conv_conv + SRCS conv_conv_fuser.cc + DEPS pattern_matcher_high_api) lite_cc_library(fuse_elementwise_add_activation SRCS elementwise_add_activation_fuser.cc DEPS pattern_matcher_high_api) @@ -42,6 +45,7 @@ set(mir_fusers fuse_conv_activation fuse_var_conv_activation fuse_conv_bn + fuse_conv_conv fuse_quant_dequant fuse_elementwise_add_activation fuse_transpose_softmax_transpose diff --git a/lite/core/mir/fusion/__xpu__mmdnn_fuse_pass.cc b/lite/core/mir/fusion/__xpu__mmdnn_fuse_pass.cc index 61aeb2ab1f51ddcd6b153971253f8239472a1031..db950fd4b4d671ed618c8bc53010e5be6f5fd78b 100644 --- a/lite/core/mir/fusion/__xpu__mmdnn_fuse_pass.cc +++ b/lite/core/mir/fusion/__xpu__mmdnn_fuse_pass.cc @@ -326,6 +326,28 @@ class XPUMmdnnSearchAttentionFuser : public FuseBase { } }; +// 4 inputs +// ======== +// +// input_x +// input_y +// topk_row +// topk_col +// +// input_x ------- match_matrix_tensor ------- input_y +// | +// relu +// ________/ \________ +// | | +// var_conv_2d | +// | | +// relu | +// |_______ _______| +// \ / +// sequence_concat +// | +// topk_row ---- sequence_topk_avg_pooling ----- topk_col +// class XPUMmdnnMatchConvTopkFuser : public FuseBase { public: void BuildPattern() override { @@ -418,10 +440,156 @@ class XPUMmdnnMatchConvTopkFuser : public FuseBase { auto* match_op_info = matched.at("match_matrix_tensor")->stmt()->op_info(); op_desc.SetAttr("input_w_max", - match_op_info->GetAttr("w_max")); + match_op_info->GetAttr("__xpu__w_max")); + op_desc.SetAttr("dim_t", match_op_info->GetAttr("dim_t")); + auto* conv_op_info = matched.at("conv")->stmt()->op_info(); + op_desc.SetAttr("conv_w_max", + conv_op_info->GetAttr("__xpu__w_max")); + op_desc.SetAttr("output_channel", + conv_op_info->GetAttr("OutputChannel")); + auto* topk_op_info = matched.at("topk")->stmt()->op_info(); + op_desc.SetAttr>( + "topks", topk_op_info->GetAttr>("topks")); + op_desc.SetAttr("channel_num", + topk_op_info->GetAttr("channel_num")); + + auto* new_stmt = matched.at("match_matrix_tensor")->stmt(); + auto new_op = LiteOpRegistry::Global().Create(op_desc.Type()); + new_op->Attach(op_desc, new_stmt->op()->scope()); + new_op->SetValidPlaces(new_stmt->op()->valid_places()); + auto kernels = new_op->CreateKernels(new_op->valid_places()); + new_stmt->SetOp(new_op); + new_stmt->SetKernels(std::move(kernels)); + + // XXX(miaotianxiang): redundant links around |topk| are automatically + // removed as |topk| is marked intermediate. + // RemoveDirectedLink(matched.at("topk_col"), matched.at("topk")); + // RemoveDirectedLink(matched.at("topk_row"), matched.at("topk")); + std::vector arg_names{"conv_w"}; + for (auto name : arg_names) { + DirectedLink(matched.at(name), matched.at("match_matrix_tensor")); + } + std::vector out_names{"topk_out"}; + for (auto name : out_names) { + IR_OP_VAR_LINK(matched.at("match_matrix_tensor"), matched.at(name)); + } + } +}; + +// 2 inputs +// ======== +// +// input_x +// input_y +// +// input_x ------- match_matrix_tensor ------- input_y +// | | | +// | relu | +// | ________/ \________ | +// | | | | +// | var_conv_2d | | +// | | | | +// | relu | | +// | |_______ _______| | +// | \ / | +// | sequence_concat | +// | | | +// |--------- sequence_topk_avg_pooling -------| +// +class XPUMmdnnMatchConvTopkFuser2 : public FuseBase { + public: + void BuildPattern() override { + auto* input_x = VarNode("input_x") + ->assert_is_op_input("match_matrix_tensor", "X") + ->assert_is_op_input("sequence_topk_avg_pooling", "ROW") + ->AsInput(); + auto* input_y = + VarNode("input_y") + ->assert_is_op_input("match_matrix_tensor", "Y") + ->assert_is_op_input("sequence_topk_avg_pooling", "COLUMN") + ->AsInput(); + auto* input_w = VarNode("input_w") + ->assert_is_op_input("match_matrix_tensor", "W") + ->AsInput(); + + auto* match_matrix_tensor = + OpNode("match_matrix_tensor", "match_matrix_tensor"); + auto* match_out = VarNode("match_out") + ->assert_is_op_output("match_matrix_tensor", "Out") + ->AsIntermediate(); + auto* match_tmp = VarNode("match_tmp") + ->assert_is_op_output("match_matrix_tensor", "Tmp") + ->AsIntermediate(); + auto* relu0 = OpNode("relu0", "relu")->AsIntermediate(); + auto* relu0_out = VarNode("relu0_out") + ->assert_is_op_output("relu", "Out") + ->AsIntermediate(); + auto* conv_w = + VarNode("conv_w")->assert_is_op_input("var_conv_2d", "W")->AsInput(); + auto* conv = OpNode("conv", "var_conv_2d")->AsIntermediate(); + auto* conv_out = VarNode("conv_out") + ->assert_is_op_output("var_conv_2d", "Out") + ->AsIntermediate(); + auto* conv_col = VarNode("conv_col") + ->assert_is_op_output("var_conv_2d", "Col") + ->AsIntermediate(); + auto* relu1 = OpNode("relu1", "relu")->AsIntermediate(); + auto* relu1_out = VarNode("relu1_out") + ->assert_is_op_output("relu", "Out") + ->AsIntermediate(); + auto* seq_concat = + OpNode("seq_concat", "sequence_concat")->AsIntermediate(); + auto* seq_concat_out = + VarNode("seq_concat_out") + ->assert_is_op_output("sequence_concat", "Out") + ->assert_is_op_input("sequence_topk_avg_pooling", "X") + ->AsIntermediate(); + auto* topk = OpNode("topk", "sequence_topk_avg_pooling")->AsIntermediate(); + auto* topk_out = + VarNode("topk_out") + ->assert_is_op_output("sequence_topk_avg_pooling", "Out") + ->AsOutput(); + auto* topk_pos = + VarNode("topk_pos") + ->assert_is_op_output("sequence_topk_avg_pooling", "pos") + ->AsIntermediate(); + + *input_x >> *match_matrix_tensor; + *input_y >> *match_matrix_tensor; + *input_w >> *match_matrix_tensor; + *match_matrix_tensor >> *match_out >> *relu0 >> *relu0_out; + *match_matrix_tensor >> *match_tmp; + + *relu0_out >> *conv >> *conv_out >> *relu1 >> *relu1_out; + *conv_w >> *conv; + *conv >> *conv_col; + + *relu0_out >> *seq_concat; + *relu1_out >> *seq_concat; + *seq_concat >> *seq_concat_out >> *topk >> *topk_out; + *input_x >> *topk; + *input_y >> *topk; + *topk >> *topk_pos; + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + cpp::OpDesc op_desc; + op_desc.SetType("__xpu__mmdnn_match_conv_topk"); + op_desc.SetInput("input_x", {matched.at("input_x")->arg()->name}); + op_desc.SetInput("input_y", {matched.at("input_y")->arg()->name}); + op_desc.SetInput("input_w", {matched.at("input_w")->arg()->name}); + op_desc.SetInput("conv_w", {matched.at("conv_w")->arg()->name}); + op_desc.SetOutput("topk_out", {matched.at("topk_out")->arg()->name}); + + auto* match_op_info = matched.at("match_matrix_tensor")->stmt()->op_info(); + op_desc.SetAttr("input_w_max", + match_op_info->GetAttr("__xpu__w_max")); op_desc.SetAttr("dim_t", match_op_info->GetAttr("dim_t")); auto* conv_op_info = matched.at("conv")->stmt()->op_info(); - op_desc.SetAttr("conv_w_max", conv_op_info->GetAttr("w_max")); + op_desc.SetAttr("conv_w_max", + conv_op_info->GetAttr("__xpu__w_max")); + op_desc.SetAttr("output_channel", + conv_op_info->GetAttr("OutputChannel")); auto* topk_op_info = matched.at("topk")->stmt()->op_info(); op_desc.SetAttr>( "topks", topk_op_info->GetAttr>("topks")); @@ -437,8 +605,7 @@ class XPUMmdnnMatchConvTopkFuser : public FuseBase { new_stmt->SetKernels(std::move(kernels)); // XXX(miaotianxiang): redundant links around |topk| are automatically - // removed as |topk| is - // marked intermediate. + // removed as |topk| is marked intermediate. // RemoveDirectedLink(matched.at("topk_col"), matched.at("topk")); // RemoveDirectedLink(matched.at("topk_row"), matched.at("topk")); std::vector arg_names{"conv_w"}; @@ -624,6 +791,15 @@ class XPUMmdnnBidEmbAttFuser : public FuseBase { } }; +// 5 outputs +// ========= +// +// eltwise01_out +// seq_pool_right_out +// seq_pool_left_out +// seq_pool_2in1_out +// concat_3in1_out +// class XPUMmdnnBidEmbGrnnAttFuser : public FuseBase { public: void BuildPattern() override { @@ -818,17 +994,272 @@ class XPUMmdnnBidEmbGrnnAttFuser : public FuseBase { auto* grnn_fw_op_info = matched.at("grnn_left")->stmt()->op_info(); op_desc.SetAttr>( "grnn_fw_wh_maxs", - grnn_fw_op_info->GetAttr>("wh_max")); + grnn_fw_op_info->GetAttr>("__xpu__wh_max")); op_desc.SetAttr>( "grnn_fw_wi_maxs", - grnn_fw_op_info->GetAttr>("wi_max")); + grnn_fw_op_info->GetAttr>("__xpu__wi_max")); auto* grnn_rv_op_info = matched.at("grnn_right")->stmt()->op_info(); op_desc.SetAttr>( "grnn_rv_wh_maxs", - grnn_rv_op_info->GetAttr>("wh_max")); + grnn_rv_op_info->GetAttr>("__xpu__wh_max")); op_desc.SetAttr>( "grnn_rv_wi_maxs", - grnn_rv_op_info->GetAttr>("wi_max")); + grnn_rv_op_info->GetAttr>("__xpu__wi_max")); + auto* att_fc_op_info = matched.at("att_2in1")->stmt()->op_info(); + op_desc.SetAttr("att_fc_w_max", + att_fc_op_info->GetAttr("W_max")); + + auto* new_stmt = matched.at("emb0")->stmt(); + auto new_op = LiteOpRegistry::Global().Create(op_desc.Type()); + new_op->Attach(op_desc, new_stmt->op()->scope()); + new_op->SetValidPlaces(new_stmt->op()->valid_places()); + auto kernels = new_op->CreateKernels(new_op->valid_places()); + new_stmt->SetOp(new_op); + new_stmt->SetKernels(std::move(kernels)); + + std::vector arg_names{ + "input1", + "grnn_left_wh", + "grnn_left_wi", + "grnn_right_wh", + "grnn_right_wi", + "att_2in1_w", + "att_2in1_b", + }; + for (auto name : arg_names) { + DirectedLink(matched.at(name), matched.at("emb0")); + } + std::vector out_names{ + "seq_pool_left_out", + "seq_pool_right_out", + "seq_pool_2in1_out", + "concat_3in1_out", + "eltwise01_out", + }; + for (auto name : out_names) { + IR_OP_VAR_LINK(matched.at("emb0"), matched.at(name)); + } + } +}; + +// 6 outputs +// ========= +// +// emb0_out +// eltwise01_out +// seq_pool_right_out +// seq_pool_left_out +// seq_pool_2in1_out +// concat_3in1_out +// +class XPUMmdnnBidEmbGrnnAttFuser2 : public FuseBase { + public: + void BuildPattern() override { + auto* input0 = VarNode("input0")->AsInput(); + auto* input1 = VarNode("input1")->AsInput(); + auto* emb_tbl = VarNode("emb_tbl")->AsInput(); + + auto* emb0 = OpNode("emb0", "lookup_table"); + auto* emb0_out = VarNode("emb0_out") + ->assert_is_op_output("lookup_table", "Out") + ->assert_is_op_input("search_seq_arithmetic", "X") + ->AsOutput(); + auto* emb1 = OpNode("emb1", "lookup_table")->AsIntermediate(); + auto* emb1_out = VarNode("emb1_out") + ->assert_is_op_output("lookup_table", "Out") + ->assert_is_op_input("search_seq_arithmetic", "Y") + ->AsIntermediate(); + auto* eltwise01 = + OpNode("eltwise01", "search_seq_arithmetic")->AsIntermediate(); + auto* eltwise01_out = + VarNode("eltwise01_out") + ->assert_is_op_output("search_seq_arithmetic", "Out") + ->AsOutput(); + + auto* seq_rev_right0 = + OpNode("seq_rev_right0", "sequence_reverse")->AsIntermediate(); + auto* seq_rev_right0_out = + VarNode("seq_rev_right0_out") + ->assert_is_op_output("sequence_reverse", "Y") + ->AsIntermediate(); + auto* grnn_right_wh = VarNode("grnn_right_wh") + ->assert_is_op_input("search_grnn", "Wh") + ->AsInput(); + auto* grnn_right_wi = VarNode("grnn_right_wi") + ->assert_is_op_input("search_grnn", "Wi") + ->AsInput(); + auto* grnn_right = OpNode("grnn_right", "search_grnn")->AsIntermediate(); + auto* grnn_right_out = VarNode("grnn_right_out") + ->assert_is_op_output("search_grnn", "Out") + ->AsIntermediate(); + auto* grnn_right_idx_sorted_by_width = + VarNode("grnn_right_idx_sorted_by_width") + ->assert_is_op_output("search_grnn", "idx_sorted_by_width") + ->AsIntermediate(); + auto* grnn_right_layout_input = + VarNode("grnn_right_layout_input") + ->assert_is_op_output("search_grnn", "layout_input") + ->AsIntermediate(); + auto* grnn_right_tmp_buffer = + VarNode("grnn_right_tmp_buffer") + ->assert_is_op_output("search_grnn", "tmp_buffer") + ->AsIntermediate(); + auto* seq_rev_right1 = + OpNode("seq_rev_right1", "sequence_reverse")->AsIntermediate(); + auto* seq_rev_right1_out = + VarNode("seq_rev_right1_out") + ->assert_is_op_output("sequence_reverse", "Y") + ->AsIntermediate(); + auto* seq_pool_right = + OpNode("seq_pool_right", "sequence_pool")->AsIntermediate(); + auto* seq_pool_right_out = VarNode("seq_pool_right_out") + ->assert_is_op_output("sequence_pool", "Out") + ->AsOutput(); + auto* seq_pool_right_max_idx = + VarNode("seq_pool_right_max_idx") + ->assert_is_op_output("sequence_pool", "MaxIndex") + ->AsIntermediate(); + + auto* grnn_left_wh = VarNode("grnn_left_wh") + ->assert_is_op_input("search_grnn", "Wh") + ->AsInput(); + auto* grnn_left_wi = VarNode("grnn_left_wi") + ->assert_is_op_input("search_grnn", "Wi") + ->AsInput(); + auto* grnn_left = OpNode("grnn_left", "search_grnn")->AsIntermediate(); + auto* grnn_left_out = VarNode("grnn_left_out") + ->assert_is_op_output("search_grnn", "Out") + ->AsIntermediate(); + auto* grnn_left_idx_sorted_by_width = + VarNode("grnn_left_idx_sorted_by_width") + ->assert_is_op_output("search_grnn", "idx_sorted_by_width") + ->AsIntermediate(); + auto* grnn_left_layout_input = + VarNode("grnn_left_layout_input") + ->assert_is_op_output("search_grnn", "layout_input") + ->AsIntermediate(); + auto* grnn_left_tmp_buffer = + VarNode("grnn_left_tmp_buffer") + ->assert_is_op_output("search_grnn", "tmp_buffer") + ->AsIntermediate(); + auto* seq_pool_left = + OpNode("seq_pool_left", "sequence_pool")->AsIntermediate(); + auto* seq_pool_left_out = VarNode("seq_pool_left_out") + ->assert_is_op_output("sequence_pool", "Out") + ->AsOutput(); + auto* seq_pool_left_max_idx = + VarNode("seq_pool_left_max_idx") + ->assert_is_op_output("sequence_pool", "MaxIndex") + ->AsIntermediate(); + + auto* concat_2in1 = OpNode("concat_2in1", "concat")->AsIntermediate(); + auto* concat_2in1_out = VarNode("concat_2in1_out") + ->assert_is_op_output("concat", "Out") + ->AsIntermediate(); + auto* att_2in1_w = + VarNode("att_2in1_w") + ->assert_is_op_input("__xpu__mmdnn_search_attention", "W") + ->AsInput(); + auto* att_2in1_b = + VarNode("att_2in1_b") + ->assert_is_op_input("__xpu__mmdnn_search_attention", "b") + ->AsInput(); + auto* att_2in1 = + OpNode("att_2in1", "__xpu__mmdnn_search_attention")->AsIntermediate(); + auto* att_2in1_out = + VarNode("att_2in1_out") + ->assert_is_op_output("__xpu__mmdnn_search_attention", "Out") + ->AsIntermediate(); + auto* seq_pool_2in1 = + OpNode("seq_pool_2in1", "sequence_pool")->AsIntermediate(); + auto* seq_pool_2in1_out = VarNode("seq_pool_2in1_out") + ->assert_is_op_output("sequence_pool", "Out") + ->AsOutput(); + auto* seq_pool_2in1_max_idx = + VarNode("seq_pool_2in1_max_idx") + ->assert_is_op_output("sequence_pool", "MaxIndex") + ->AsIntermediate(); + + auto* concat_3in1 = OpNode("concat_3in1", "concat")->AsIntermediate(); + auto* concat_3in1_out = VarNode("concat_3in1_out") + ->assert_is_op_output("concat", "Out") + ->AsOutput(); + + *input0 >> *emb0 >> *emb0_out >> *eltwise01 >> *eltwise01_out; + *emb_tbl >> *emb0; + *input1 >> *emb1 >> *emb1_out >> *eltwise01; + *emb_tbl >> *emb1; + + *eltwise01_out >> *seq_rev_right0 >> *seq_rev_right0_out >> *grnn_right >> + *grnn_right_out >> *seq_rev_right1 >> *seq_rev_right1_out; + *grnn_right_out >> *seq_pool_right >> *seq_pool_right_out; + *seq_pool_right >> *seq_pool_right_max_idx; + *grnn_right_wh >> *grnn_right; + *grnn_right_wi >> *grnn_right; + *grnn_right >> *grnn_right_idx_sorted_by_width; + *grnn_right >> *grnn_right_layout_input; + *grnn_right >> *grnn_right_tmp_buffer; + + *eltwise01_out >> *grnn_left >> *grnn_left_out >> *seq_pool_left >> + *seq_pool_left_out; + *seq_pool_left >> *seq_pool_left_max_idx; + *grnn_left_wh >> *grnn_left; + *grnn_left_wi >> *grnn_left; + *grnn_left >> *grnn_left_idx_sorted_by_width; + *grnn_left >> *grnn_left_layout_input; + *grnn_left >> *grnn_left_tmp_buffer; + + *seq_rev_right1_out >> *concat_2in1; + *grnn_left_out >> *concat_2in1; + *concat_2in1 >> *concat_2in1_out >> *att_2in1 >> *att_2in1_out >> + *seq_pool_2in1 >> *seq_pool_2in1_out; + *seq_pool_2in1 >> *seq_pool_2in1_max_idx; + *att_2in1_w >> *att_2in1; + *att_2in1_b >> *att_2in1; + + *eltwise01_out >> *concat_3in1; + *seq_rev_right1_out >> *concat_3in1; + *grnn_left_out >> *concat_3in1; + *concat_3in1 >> *concat_3in1_out; + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + cpp::OpDesc op_desc; + op_desc.SetType("__xpu__mmdnn_bid_emb_grnn_att2"); + op_desc.SetInput("id0", {matched.at("input0")->arg()->name}); + op_desc.SetInput("id1", {matched.at("input1")->arg()->name}); + op_desc.SetInput("emb_tbl", {matched.at("emb_tbl")->arg()->name}); + op_desc.SetInput("grnn_fw_wh", {matched.at("grnn_left_wh")->arg()->name}); + op_desc.SetInput("grnn_fw_wi", {matched.at("grnn_left_wi")->arg()->name}); + op_desc.SetInput("grnn_rv_wh", {matched.at("grnn_right_wh")->arg()->name}); + op_desc.SetInput("grnn_rv_wi", {matched.at("grnn_right_wi")->arg()->name}); + op_desc.SetInput("att_fc_w", {matched.at("att_2in1_w")->arg()->name}); + op_desc.SetInput("att_fc_b", {matched.at("att_2in1_b")->arg()->name}); + op_desc.SetOutput("emb0_out", {matched.at("emb0_out")->arg()->name}); + op_desc.SetOutput("grnn_fw_pool_out", + {matched.at("seq_pool_left_out")->arg()->name}); + op_desc.SetOutput("grnn_rv_pool_out", + {matched.at("seq_pool_right_out")->arg()->name}); + op_desc.SetOutput("att_pool_out", + {matched.at("seq_pool_2in1_out")->arg()->name}); + op_desc.SetOutput("concat_3in1_out", + {matched.at("concat_3in1_out")->arg()->name}); + op_desc.SetOutput("emb_fw_out", {matched.at("eltwise01_out")->arg()->name}); + + auto* grnn_fw_op_info = matched.at("grnn_left")->stmt()->op_info(); + op_desc.SetAttr>( + "grnn_fw_wh_maxs", + grnn_fw_op_info->GetAttr>("__xpu__wh_max")); + op_desc.SetAttr>( + "grnn_fw_wi_maxs", + grnn_fw_op_info->GetAttr>("__xpu__wi_max")); + auto* grnn_rv_op_info = matched.at("grnn_right")->stmt()->op_info(); + op_desc.SetAttr>( + "grnn_rv_wh_maxs", + grnn_rv_op_info->GetAttr>("__xpu__wh_max")); + op_desc.SetAttr>( + "grnn_rv_wi_maxs", + grnn_rv_op_info->GetAttr>("__xpu__wi_max")); auto* att_fc_op_info = matched.at("att_2in1")->stmt()->op_info(); op_desc.SetAttr("att_fc_w_max", att_fc_op_info->GetAttr("W_max")); @@ -868,6 +1299,9 @@ class XPUMmdnnBidEmbGrnnAttFuser : public FuseBase { class XPUMmdnnMergeAllFuser : public FuseBase { public: + explicit XPUMmdnnMergeAllFuser(int n_concat_topk) + : n_concat_topk_(n_concat_topk) {} + void BuildPattern() override { auto* concat_7in1_input0 = VarNode("concat_7in1_input0") ->assert_is_op_nth_input("concat", "X", 0) @@ -909,16 +1343,25 @@ class XPUMmdnnMergeAllFuser : public FuseBase { ->assert_is_op_output("relu", "Out") ->AsIntermediate(); - auto* concat_2in1_input0 = VarNode("concat_2in1_input0") + auto* concat_topk_input0 = VarNode("concat_topk_input0") ->assert_is_op_nth_input("concat", "X", 0) ->AsInput(); - auto* concat_2in1_input1 = VarNode("concat_2in1_input1") + auto* concat_topk_input1 = VarNode("concat_topk_input1") ->assert_is_op_nth_input("concat", "X", 1) ->AsInput(); - auto* concat_2in1 = OpNode("concat_2in1", "concat")->AsIntermediate(); - auto* concat_2in1_out = VarNode("concat_2in1_out") + auto* concat_topk = OpNode("concat_topk", "concat")->AsIntermediate(); + auto* concat_topk_out = VarNode("concat_topk_out") ->assert_is_op_output("concat", "Out") ->AsIntermediate(); + for (int i = 2; i < n_concat_topk_; ++i) { + auto concat_topk_input_name = + paddle::lite::string_format("concat_topk_input%d", i); + auto* concat_topk_inputx = VarNode(concat_topk_input_name) + ->assert_is_op_nth_input("concat", "X", i) + ->AsInput(); + *concat_topk_inputx >> *concat_topk; + } + auto* seq_rev = OpNode("seq_rev", "sequence_reverse")->AsIntermediate(); auto* seq_rev_out = VarNode("seq_rev_out") ->assert_is_op_output("sequence_reverse", "Y") @@ -1034,9 +1477,9 @@ class XPUMmdnnMergeAllFuser : public FuseBase { *search_fc0_w >> *search_fc0; *search_fc0_b >> *search_fc0; - *concat_2in1_input0 >> *concat_2in1; - *concat_2in1_input1 >> *concat_2in1; - *concat_2in1 >> *concat_2in1_out >> *seq_rev >> *seq_rev_out; + *concat_topk_input0 >> *concat_topk; + *concat_topk_input1 >> *concat_topk; + *concat_topk >> *concat_topk_out >> *seq_rev >> *seq_rev_out; *seq_rev_out >> *grnn_rv >> *grnn_rv_out >> *seq_pool_rv >> *seq_pool_rv_out; @@ -1047,7 +1490,7 @@ class XPUMmdnnMergeAllFuser : public FuseBase { *grnn_rv >> *grnn_rv_layout_input; *grnn_rv >> *grnn_rv_tmp_buffer; - *concat_2in1_out >> *grnn_fw >> *grnn_fw_out >> *seq_pool_fw >> + *concat_topk_out >> *grnn_fw >> *grnn_fw_out >> *seq_pool_fw >> *seq_pool_fw_out; *seq_pool_fw >> *seq_pool_fw_max_idx; *grnn_fw_wh >> *grnn_fw; @@ -1075,8 +1518,8 @@ class XPUMmdnnMergeAllFuser : public FuseBase { op_desc.SetType("__xpu__mmdnn_merge_all"); auto* concat_7in1_op_info = matched.at("concat_7in1")->stmt()->op_info(); op_desc.SetInput("concat_7in1_x", concat_7in1_op_info->Input("X")); - auto* concat_2in1_op_info = matched.at("concat_2in1")->stmt()->op_info(); - op_desc.SetInput("concat_2in1_x", concat_2in1_op_info->Input("X")); + auto* concat_topk_op_info = matched.at("concat_topk")->stmt()->op_info(); + op_desc.SetInput("concat_topk_x", concat_topk_op_info->Input("X")); op_desc.SetInput("grnn_fw_wh", {matched.at("grnn_fw_wh")->arg()->name}); op_desc.SetInput("grnn_fw_wi", {matched.at("grnn_fw_wi")->arg()->name}); op_desc.SetInput("grnn_rv_wh", {matched.at("grnn_rv_wh")->arg()->name}); @@ -1093,23 +1536,26 @@ class XPUMmdnnMergeAllFuser : public FuseBase { auto* grnn_fw_op_info = matched.at("grnn_fw")->stmt()->op_info(); op_desc.SetAttr>( "grnn_fw_wh_maxs", - grnn_fw_op_info->GetAttr>("wh_max")); + grnn_fw_op_info->GetAttr>("__xpu__wh_max")); op_desc.SetAttr>( "grnn_fw_wi_maxs", - grnn_fw_op_info->GetAttr>("wi_max")); + grnn_fw_op_info->GetAttr>("__xpu__wi_max")); auto* grnn_rv_op_info = matched.at("grnn_rv")->stmt()->op_info(); op_desc.SetAttr>( "grnn_rv_wh_maxs", - grnn_rv_op_info->GetAttr>("wh_max")); + grnn_rv_op_info->GetAttr>("__xpu__wh_max")); op_desc.SetAttr>( "grnn_rv_wi_maxs", - grnn_rv_op_info->GetAttr>("wi_max")); + grnn_rv_op_info->GetAttr>("__xpu__wi_max")); auto* fc0_op_info = matched.at("search_fc0")->stmt()->op_info(); - op_desc.SetAttr("fc0_w_max", fc0_op_info->GetAttr("w_max")); + op_desc.SetAttr("fc0_w_max", + fc0_op_info->GetAttr("__xpu__w_max")); auto* fc1_op_info = matched.at("search_fc1")->stmt()->op_info(); - op_desc.SetAttr("fc1_w_max", fc1_op_info->GetAttr("w_max")); + op_desc.SetAttr("fc1_w_max", + fc1_op_info->GetAttr("__xpu__w_max")); auto* fc2_op_info = matched.at("search_fc2")->stmt()->op_info(); - op_desc.SetAttr("fc2_w_max", fc2_op_info->GetAttr("w_max")); + op_desc.SetAttr("fc2_w_max", + fc2_op_info->GetAttr("__xpu__w_max")); auto* new_stmt = matched.at("concat_7in1")->stmt(); auto new_op = LiteOpRegistry::Global().Create(op_desc.Type()); @@ -1120,8 +1566,8 @@ class XPUMmdnnMergeAllFuser : public FuseBase { new_stmt->SetKernels(std::move(kernels)); std::vector arg_names{ - "concat_2in1_input0", - "concat_2in1_input1", + "concat_topk_input0", + "concat_topk_input1", "grnn_fw_wh", "grnn_fw_wi", "grnn_rv_wh", @@ -1133,6 +1579,11 @@ class XPUMmdnnMergeAllFuser : public FuseBase { "search_fc2_w", "search_fc2_b", }; + for (int i = 2; i < n_concat_topk_; ++i) { + auto concat_topk_input_name = + paddle::lite::string_format("concat_topk_input%d", i); + arg_names.push_back(concat_topk_input_name); + } for (auto name : arg_names) { DirectedLink(matched.at(name), matched.at("concat_7in1")); } @@ -1143,6 +1594,9 @@ class XPUMmdnnMergeAllFuser : public FuseBase { IR_OP_VAR_LINK(matched.at("concat_7in1"), matched.at(name)); } } + + private: + int n_concat_topk_; }; } // namespace fusion @@ -1158,15 +1612,21 @@ class XPUMmdnnFusePass : public ProgramPass { search_att_fuser(graph.get()); fusion::XPUMmdnnMatchConvTopkFuser match_conv_topk_fuser; match_conv_topk_fuser(graph.get()); + fusion::XPUMmdnnMatchConvTopkFuser2 match_conv_topk_fuser2; + match_conv_topk_fuser2(graph.get()); fusion::XPUMmdnnBidSeqRevEmbEltwiseFuser bi_seq_rev_emb_eltwise_fuser; bi_seq_rev_emb_eltwise_fuser(graph.get()); fusion::XPUMmdnnBidEmbGrnnAttFuser bid_emb_grnn_att_fuser; bid_emb_grnn_att_fuser(graph.get()); + fusion::XPUMmdnnBidEmbGrnnAttFuser2 bid_emb_grnn_att_fuser2; + bid_emb_grnn_att_fuser2(graph.get()); fusion::XPUMmdnnBidEmbAttFuser bid_emb_att_fuser; bid_emb_att_fuser(graph.get()); - fusion::XPUMmdnnMergeAllFuser merge_all_fuser; - merge_all_fuser(graph.get()); + for (int n_concat_topk : {3, 2}) { + fusion::XPUMmdnnMergeAllFuser merge_all_fuser(n_concat_topk); + merge_all_fuser(graph.get()); + } } }; @@ -1178,6 +1638,7 @@ REGISTER_MIR_PASS(__xpu__mmdnn_fuse_pass, paddle::lite::mir::XPUMmdnnFusePass) .BindTargets({TARGET(kXPU)}) .BindKernel("__xpu__mmdnn_search_attention") .BindKernel("__xpu__mmdnn_bid_emb_grnn_att") + .BindKernel("__xpu__mmdnn_bid_emb_grnn_att2") .BindKernel("__xpu__mmdnn_bid_emb_att") .BindKernel("__xpu__mmdnn_match_conv_topk") .BindKernel("__xpu__mmdnn_merge_all"); diff --git a/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc b/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc index 04988612192b79824b1294428fa9b1c38d784979..21bc266204d95c0f7faa8c3796e4b6255a3fe741 100644 --- a/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc +++ b/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc @@ -383,10 +383,10 @@ class XPUSingleEncoderFuser : public FuseBase { op_desc.SetAttr("act_type", act_type_); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); - // XXX: memleak? - auto sub_block_desc = new cpp::BlockDesc(); + auto sub_program_desc = std::make_shared(); + sub_program_desc->AddBlock(); static_cast(fake_subgraph_op.get()) - ->SetSubBlock(sub_block_desc); + ->SetProgramDesc(sub_program_desc); auto* single_encoder_stmt = matched.at("q_mul")->stmt(); fake_subgraph_op->Attach(op_desc, single_encoder_stmt->op()->scope()); fake_subgraph_op->SetValidPlaces(single_encoder_stmt->op()->valid_places()); diff --git a/lite/core/mir/fusion/__xpu__resnet_cbam_fuse_pass.cc b/lite/core/mir/fusion/__xpu__resnet_cbam_fuse_pass.cc index b25eb084f286fccfa4afe8832f9dc1ff8384d552..f017cc8c72f93a772f8bcbdc9aa96d5b0ad215d8 100644 --- a/lite/core/mir/fusion/__xpu__resnet_cbam_fuse_pass.cc +++ b/lite/core/mir/fusion/__xpu__resnet_cbam_fuse_pass.cc @@ -373,10 +373,10 @@ class XPUResNetCbamBlock0Fuser : public FuseBase { auto block0_stmt = matched.at("left_conv1")->stmt(); // block0_stmt->ResetOp(op_desc, graph->valid_places()); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); - // XXX: memleak? - auto sub_block_desc = new cpp::BlockDesc(); + auto sub_program_desc = std::make_shared(); + sub_program_desc->AddBlock(); static_cast(fake_subgraph_op.get()) - ->SetSubBlock(sub_block_desc); + ->SetProgramDesc(sub_program_desc); fake_subgraph_op->Attach(op_desc, block0_stmt->op()->scope()); fake_subgraph_op->SetValidPlaces(block0_stmt->op()->valid_places()); block0_stmt->SetOp(fake_subgraph_op); @@ -693,10 +693,10 @@ class XPUResNetCbamBlock1Fuser : public FuseBase { auto block1_stmt = matched.at("right_conv1")->stmt(); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); - // XXX: memleak? - auto sub_block_desc = new cpp::BlockDesc(); + auto sub_program_desc = std::make_shared(); + sub_program_desc->AddBlock(); static_cast(fake_subgraph_op.get()) - ->SetSubBlock(sub_block_desc); + ->SetProgramDesc(sub_program_desc); fake_subgraph_op->Attach(op_desc, block1_stmt->op()->scope()); fake_subgraph_op->SetValidPlaces(block1_stmt->op()->valid_places()); block1_stmt->SetOp(fake_subgraph_op); @@ -932,10 +932,10 @@ class XPUResNetCbamBlock2Fuser : public FuseBase { << "Y of last fc must have been transposed"; auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); - // XXX: memleak? - auto sub_block_desc = new cpp::BlockDesc(); + auto sub_program_desc = std::make_shared(); + sub_program_desc->AddBlock(); static_cast(fake_subgraph_op.get()) - ->SetSubBlock(sub_block_desc); + ->SetProgramDesc(sub_program_desc); fake_subgraph_op->Attach(op_desc, scope); fake_subgraph_op->SetValidPlaces(block2_stmt->op()->valid_places()); block2_stmt->SetOp(fake_subgraph_op); diff --git a/lite/core/mir/fusion/__xpu__resnet_fuse_pass.cc b/lite/core/mir/fusion/__xpu__resnet_fuse_pass.cc index de2210a76ea0647cb02131a088ceb754afd0ef9c..7024a872f30d3c78affe82648c902a6128de7070 100644 --- a/lite/core/mir/fusion/__xpu__resnet_fuse_pass.cc +++ b/lite/core/mir/fusion/__xpu__resnet_fuse_pass.cc @@ -315,10 +315,10 @@ class XPUResNetBlock0Fuser : public FuseBase { auto block0_stmt = matched.at("left_conv1")->stmt(); // block0_stmt->ResetOp(op_desc, graph->valid_places()); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); - // XXX: memleak? - auto sub_block_desc = new cpp::BlockDesc(); + auto sub_program_desc = std::make_shared(); + sub_program_desc->AddBlock(); static_cast(fake_subgraph_op.get()) - ->SetSubBlock(sub_block_desc); + ->SetProgramDesc(sub_program_desc); fake_subgraph_op->Attach(op_desc, block0_stmt->op()->scope()); fake_subgraph_op->SetValidPlaces(block0_stmt->op()->valid_places()); block0_stmt->SetOp(fake_subgraph_op); @@ -577,10 +577,10 @@ class XPUResNetBlock1Fuser : public FuseBase { auto block1_stmt = matched.at("right_conv1")->stmt(); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); - // XXX: memleak? - auto sub_block_desc = new cpp::BlockDesc(); + auto sub_program_desc = std::make_shared(); + sub_program_desc->AddBlock(); static_cast(fake_subgraph_op.get()) - ->SetSubBlock(sub_block_desc); + ->SetProgramDesc(sub_program_desc); fake_subgraph_op->Attach(op_desc, block1_stmt->op()->scope()); fake_subgraph_op->SetValidPlaces(block1_stmt->op()->valid_places()); block1_stmt->SetOp(fake_subgraph_op); diff --git a/lite/core/mir/fusion/conv_conv_fuse_pass.cc b/lite/core/mir/fusion/conv_conv_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..e9c4f0c02cd89e04d93af8e4dab71acc5d24e411 --- /dev/null +++ b/lite/core/mir/fusion/conv_conv_fuse_pass.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/conv_conv_fuse_pass.h" +#include +#include +#include "lite/core/mir/fusion/conv_conv_fuser.h" +#include "lite/core/mir/graph_visualize_pass.h" +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void ConvConvFusePass::Apply(const std::unique_ptr& graph) { + // initialze fuser params + std::vector conv_has_bias_cases{true, false}; + std::vector conv_type_cases{"conv2d", "depthwise_conv2d"}; + bool has_arm = false; + for (auto& place : graph->valid_places()) { + if (place.target == TARGET(kARM) && place.precision == PRECISION(kFloat)) { + has_arm = true; + break; + } + } + if (!has_arm) { + return; + } + // only support fp32 fusion + for (auto conv_has_bias0 : conv_has_bias_cases) { + for (auto conv_has_bias1 : conv_has_bias_cases) { + for (auto conv_type0 : conv_type_cases) { + for (auto conv_type1 : conv_type_cases) { + VLOG(4) << "conv_has_bias0:" << conv_has_bias0 + << " conv_type0:" << conv_type0; + VLOG(4) << "conv_has_bias1:" << conv_has_bias1 + << " conv_type1:" << conv_type1; + fusion::ConvConvFuser fuser( + conv_type0, conv_type1, conv_has_bias0, conv_has_bias1); + fuser(graph.get()); + } + } + } + } +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_conv_conv_fuse_pass, paddle::lite::mir::ConvConvFusePass) + .BindTargets({TARGET(kARM)}); diff --git a/lite/kernels/xpu/utils.h b/lite/core/mir/fusion/conv_conv_fuse_pass.h similarity index 76% rename from lite/kernels/xpu/utils.h rename to lite/core/mir/fusion/conv_conv_fuse_pass.h index d410cb1567d5c60aeb52b798d9f17c7f5692e096..64e1b87ec9a8618572d6044f6dde2ab25c5a11c4 100644 --- a/lite/kernels/xpu/utils.h +++ b/lite/core/mir/fusion/conv_conv_fuse_pass.h @@ -14,18 +14,19 @@ #pragma once -#include "lite/backends/xpu/xpu_header_sitter.h" +#include +#include +#include "lite/core/mir/pass.h" namespace paddle { namespace lite { -namespace kernels { -namespace xpu { +namespace mir { -struct XPUFreeDeleter { - void operator()(void* p) const { xpu_free(p); } +class ConvConvFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; }; -} // namespace xpu -} // namespace kernels +} // namespace mir } // namespace lite } // namespace paddle diff --git a/lite/core/mir/fusion/conv_conv_fuser.cc b/lite/core/mir/fusion/conv_conv_fuser.cc new file mode 100644 index 0000000000000000000000000000000000000000..737f96e69baa8953c0231fcc4c9e104907b17381 --- /dev/null +++ b/lite/core/mir/fusion/conv_conv_fuser.cc @@ -0,0 +1,211 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/conv_conv_fuser.h" +#include +#include +#include + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +void ConvConvFuser::BuildPattern() { + auto* conv_input0 = VarNode("conv_input0") + ->assert_is_op_input(conv_type0_, "Input") + ->AsInput(); + auto* conv_weight0 = VarNode("conv_weight0") + ->assert_is_op_input(conv_type0_, "Filter") + ->AsInput(); + auto* conv0 = OpNode("conv2d0", conv_type0_)->assert_is_op(conv_type0_); + auto* conv_out0 = VarNode("conv_out0") + ->assert_is_op_output(conv_type0_, "Output") + ->assert_is_op_input(conv_type1_, "Input") + ->AsIntermediate(); + + auto* conv_weight1 = VarNode("conv_weight1") + ->assert_is_op_input(conv_type1_, "Filter") + ->AsIntermediate(); + auto* conv1 = OpNode("conv2d1", conv_type1_) + ->assert_is_op(conv_type1_) + ->assert_op_attr("groups", 1) + ->AsIntermediate(); + + auto* conv_out1 = VarNode("conv_out1") + ->assert_is_op_output(conv_type1_, "Output") + ->AsOutput(); + + if (conv_has_bias0_) { + if (conv_has_bias1_) { + auto* conv_bias0 = VarNode("conv_bias0") + ->assert_is_op_input(conv_type0_, "Bias") + ->AsIntermediate(); + auto* conv_bias1 = VarNode("conv_bias1") + ->assert_is_op_input(conv_type1_, "Bias") + ->AsInput(); + conv0->LinksFrom({conv_input0, conv_weight0, conv_bias0}) + .LinksTo({conv_out0}); + conv1->LinksFrom({conv_out0, conv_weight1, conv_bias1}) + .LinksTo({conv_out1}); + } else { + auto* conv_bias0 = VarNode("conv_bias0") + ->assert_is_op_input(conv_type0_, "Bias") + ->AsIntermediate(); + conv0->LinksFrom({conv_input0, conv_weight0, conv_bias0}) + .LinksTo({conv_out0}); + conv1->LinksFrom({conv_out0, conv_weight1}).LinksTo({conv_out1}); + } + } else { + conv0->LinksFrom({conv_input0, conv_weight0}).LinksTo({conv_out0}); + if (conv_has_bias1_) { + auto* conv_bias1 = VarNode("conv_bias1") + ->assert_is_op_input(conv_type1_, "Bias") + ->AsInput(); + conv1->LinksFrom({conv_out0, conv_weight1, conv_bias1}) + .LinksTo({conv_out1}); + } else { + conv1->LinksFrom({conv_out0, conv_weight1}).LinksTo({conv_out1}); + } + } +} + +void ConvConvFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { + auto conv_instruct = matched.at("conv2d0")->stmt(); + auto conv_op_desc = conv_instruct->mutable_op_info(); + auto conv = conv_instruct->op(); + auto* scope = conv->scope(); + auto conv_op_desc1 = matched.at("conv2d1")->stmt()->mutable_op_info(); + + // conv0 + auto weight0_t = scope->FindVar(matched.at("conv_weight0")->arg()->name) + ->GetMutable(); + + // conv1 + auto weight1_t = scope->FindVar(matched.at("conv_weight1")->arg()->name) + ->GetMutable(); + // auto groups0 = conv_op_desc->GetAttr("groups"); + auto groups1 = conv_op_desc1->GetAttr("groups"); + auto strides1 = conv_op_desc1->GetAttr>("strides"); + auto paddings1 = conv_op_desc1->GetAttr>("paddings"); + auto dilations1 = conv_op_desc1->GetAttr>("dilations"); + + bool enable0_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false; + bool enable1_int8 = conv_op_desc1->HasAttr("enable_int8") ? true : false; + int kw = weight1_t->dims()[2]; + int kh = weight1_t->dims()[3]; + if (!(kw == 1 && kh == 1)) { + return; + } + CHECK_EQ(enable0_int8, enable1_int8) << "The Conv compute type must be same"; + CHECK_EQ(groups1, 1) << "The groups of weight1_dim must be 1"; + CHECK_EQ(weight0_t->dims()[0], weight1_t->dims()[1]) + << "weight0_dims[0] == weight1_dim[1]"; + for (int i = 0; i < strides1.size(); i++) { + CHECK_EQ(strides1[i], 1) << "strides[" << i << "]: " << strides1[i] + << " must be 1"; + } + for (int i = 0; i < paddings1.size(); i++) { + CHECK_EQ(paddings1[i], 0) << "paddings1[" << i << "]: " << paddings1[i] + << " must be 0"; + } + for (int i = 0; i < dilations1.size(); i++) { + CHECK_EQ(dilations1[i], 1) << "dilations1[" << i << "]: " << dilations1[i] + << " must be 1"; + } + // comupte new_wight and new bias + /////////////////////////////////////////////////////////////////////////////// + // Compute ConvConvFuser + // Before fusion + // + // conv(x) = conv(x) = kx + z = y + // conv(y) = ay + b + // + // After fusion: + // + // conv(conv(x)) = a(kx + z) + b = akx + az + b + // + // new_weights = ak + // new_bias = az + b + /////////////////////////////////////////////////////////////////////////////// + if (enable0_int8) { + LOG(FATAL) << "it doesn't support"; + return; + } else { + // compute new conv_weight + Tensor weight_tensor; + auto in_dims = weight0_t->dims(); + auto weight_dims = weight1_t->dims(); + const float* din = weight0_t->data(); + const float* weights = weight1_t->data(); + int oc0 = in_dims[0]; + int ic = in_dims[1]; + int ih = in_dims[2]; + int iw = in_dims[3]; + int oc = weight_dims[0]; + weight_tensor.Resize({oc, ic, ih, iw}); + float* dout = weight_tensor.mutable_data(); + ComputeNewWeight(dout, din, weights, oc0, ic, ih, iw, oc); + weight0_t->CopyDataFrom(weight_tensor); + } + // compute new conv_bias + if (conv_has_bias0_ && conv_op_desc->HasInput("Bias") && + conv_op_desc->Input("Bias").size() > 0) { + auto bias_t0 = scope->FindVar(matched.at("conv_bias0")->arg()->name) + ->GetMutable(); + if (conv_has_bias1_ && conv_op_desc1->HasInput("Bias") && + conv_op_desc1->Input("Bias").size() > 0) { + auto bias_t1 = scope->FindVar(matched.at("conv_bias1")->arg()->name) + ->GetMutable(); + Tensor bias; + bias.CopyDataFrom(*bias_t1); + auto bias_data = bias.mutable_data(); + ComputeNewBias(bias_data, bias_t0, weight1_t, bias_t1); + bias_t1->CopyDataFrom(bias); + conv_op_desc->SetInput( + "Bias", {matched.at("conv_bias1")->arg()->name}); // conv_bias + IR_NODE_LINK_TO(matched.at("conv_bias1"), matched.at("conv2d0")); + } else { + Tensor bias; + auto weight_dims = weight1_t->dims(); + bias.Resize({weight_dims[0]}); + auto bias_d = bias.mutable_data(); + ComputeNewBias(bias_d, bias_t0, weight1_t, nullptr); + bias_t0->CopyDataFrom(bias); + conv_op_desc->SetInput( + "Bias", {matched.at("conv_bias0")->arg()->name}); // conv_bias + } + } else { + if (conv_has_bias1_ && conv_op_desc1->HasInput("Bias") && + conv_op_desc1->Input("Bias").size() > 0) { + conv_op_desc->SetInput( + "Bias", {matched.at("conv_bias1")->arg()->name}); // conv_bias + IR_NODE_LINK_TO(matched.at("conv_bias1"), matched.at("conv2d0")); + } + } + conv_op_desc->SetType(conv_type0_); + conv_op_desc->SetInput("Input", {matched.at("conv_input0")->arg()->name}); + conv_op_desc->SetInput("Filter", {matched.at("conv_weight0")->arg()->name}); + conv_op_desc->SetOutput("Output", {matched.at("conv_out1")->arg()->name}); + + auto update_conv_desc = *conv_instruct->mutable_op_info(); + conv_instruct->ResetOp(update_conv_desc, graph->valid_places()); + + IR_OP_VAR_LINK(matched.at("conv2d0"), matched.at("conv_out1")); +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/conv_conv_fuser.h b/lite/core/mir/fusion/conv_conv_fuser.h new file mode 100644 index 0000000000000000000000000000000000000000..5d1f58d1c8746a137e2078006016ec6007c2afbb --- /dev/null +++ b/lite/core/mir/fusion/conv_conv_fuser.h @@ -0,0 +1,120 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "lite/core/mir/pattern_matcher_high_api.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class ConvConvFuser : public FuseBase { + public: + explicit ConvConvFuser(const std::string& conv_type0, + const std::string& conv_type1, + const bool conv_has_bias0, + const bool conv_has_bias1) + : conv_type0_(conv_type0), + conv_type1_(conv_type1), + conv_has_bias0_(conv_has_bias0), + conv_has_bias1_(conv_has_bias1) {} + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + void ComputeNewWeight(float* dout, + const float* din, + const float* weights, + int oc0, + int ic, + int ih, + int iw, + int oc1) { + // input conv_weight0_t weights conv_weight1_t + // output weight_tensor + // ksize = 1 + int in_size = ih * iw; + int in_channel_size = ic * in_size; + // out = w1[j, i, ih, iw] * w2[k, j, kw, kh] + // out_dim = [oc1, ic, kh, kw], din_dim = [oc0, ic, kh, kw] + // weight_dim = [oc1, oc0, kh, kw] + for (int k = 0; k < oc1; k++) { + const float* weights_ptr = weights + k * oc0; + float* out_ptr = dout + k * in_channel_size; + for (int c = 0; c < ic; c++) { + float* out_ptr_channel = out_ptr + c * in_size; + const float* din_ptr = din + c * in_size; + for (int i = 0; i < in_size; i++) { + float sum = 0.f; + for (int j = 0; j < oc0; j++) { + sum += din_ptr[j * in_channel_size] * weights_ptr[j]; + } + *out_ptr_channel++ = sum; + } + } + } + } + + void ComputeNewBias(float* dout, + Tensor* bias0_tensor, + Tensor* weight_tensor, + Tensor* bias1_tensor) { + // input bias0_tensor weight_tensor bias1_tensor + // output bias_tensor + auto in_dims = bias0_tensor->dims(); + auto weight_dims = weight_tensor->dims(); + const float* din = bias0_tensor->data(); + const float* weights = weight_tensor->data(); + int ic = in_dims[0]; + int oc = weight_dims[0]; + // out_k = b0[num, j, 1, 1] * w2[k, j, 1, 1] + if (bias1_tensor) { + const float* din2 = bias1_tensor->data(); + for (int k = 0; k < oc; k++) { + const float* weights_ptr = weights + k * ic; + float sum = 0.f; + for (int j = 0; j < ic; j++) { + sum += din[j] * weights_ptr[j]; + } + dout[k] = sum + din2[k]; + } + } else { + for (int k = 0; k < oc; k++) { + const float* weights_ptr = weights + k * ic; + float sum = 0.f; + for (int j = 0; j < ic; j++) { + sum += din[j] * weights_ptr[j]; + } + dout[k] = sum; + } + } + } + + private: + std::string conv_type0_{"conv2d"}; + std::string conv_type1_{"conv2d"}; + bool conv_has_bias0_{false}; + bool conv_has_bias1_{false}; +}; + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc index 1335518b00db5311b4605148817faed52164fd7a..76796468da3565733143898c1ded65ea853fac3c 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc @@ -175,7 +175,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, for (int i = 0; i < weight_scale_size; i++) { weight_scale.push_back(whole_weight_scale); } - op_desc.SetAttr("enable_int8", true); + + // Arm CPU does not support conv2d_transpose + if (quantized_op_type_ != "conv2d_transpose") { + op_desc.SetAttr("enable_int8", true); + } op_desc.SetInputScale(weight_name, weight_scale); // change the weight from the float type to int8 type. @@ -280,6 +284,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph, op_desc.SetInput("X", {quantized_op_input->arg()->name}); op_desc.SetOutput("Out", {dequant_op_out->arg()->name}); } + // Arm CPU does not support conv2d_transpose if (quantized_op_type_ != "conv2d_transpose") { op_desc.SetAttr("enable_int8", true); } diff --git a/lite/core/mir/generate_program_pass.cc b/lite/core/mir/generate_program_pass.cc index d7486c0933dbbe74115bd6358962817b2b946c12..3c9bac1c5b9fbf6d48683f6423a4c670b17cb127 100644 --- a/lite/core/mir/generate_program_pass.cc +++ b/lite/core/mir/generate_program_pass.cc @@ -39,6 +39,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr& graph) { nodes_in_order = graph->StmtTopologicalOrder(); } + insts_.emplace_back(); for (auto& item : nodes_in_order) { if (item->IsStmt()) { auto& stmt = item->AsStmt(); @@ -57,7 +58,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr& graph) { .SetSyncStreams(stmt.sync_streams_); } #endif - insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front())); + insts_.back().emplace_back(stmt.op(), std::move(stmt.kernels().front())); } } } diff --git a/lite/core/mir/generate_program_pass.h b/lite/core/mir/generate_program_pass.h index b126b4aba4d09a95a0033b04ed241812c88a3287..2ef4d035710d9542b365789aeabe8a08537ff225 100644 --- a/lite/core/mir/generate_program_pass.h +++ b/lite/core/mir/generate_program_pass.h @@ -42,7 +42,7 @@ class GenerateProgramPass : public ProgramPass { } private: - std::vector insts_; + std::vector> insts_; }; } // namespace mir diff --git a/lite/core/mir/mlu_postprocess_pass.cc b/lite/core/mir/mlu_postprocess_pass.cc index 46738dd49c16fd9736d61711b4baf56d51247699..e09220d083ee8241001b6d9d55fb48eb1ba74f2e 100644 --- a/lite/core/mir/mlu_postprocess_pass.cc +++ b/lite/core/mir/mlu_postprocess_pass.cc @@ -284,13 +284,19 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph, head_node->AsArg().name, cur_node->AsArg().name); // for subgraph op, modify the BlockDesc - auto* sub_block_desc = dynamic_cast( - inst_node->AsStmt().op().get()) - ->GetSubBlock(); - for (size_t i = 0; i < sub_block_desc->OpsSize(); ++i) { - auto* sub_block_op_desc = sub_block_desc->GetOp(i); - UpdateInputTo( - sub_block_op_desc, head_node->AsArg().name, cur_node->AsArg().name); + auto sub_program_desc = dynamic_cast( + inst_node->AsStmt().op().get()) + ->GetProgramDesc(); + CHECK(sub_program_desc); + int sub_block_idx = + inst_node->AsStmt().op()->op_info()->GetAttr("sub_block"); + auto* sub_block_desc = + sub_program_desc->GetBlock(sub_block_idx); + for (size_t sub_op_idx = 0; sub_op_idx < sub_block_desc->OpsSize(); + ++sub_op_idx) { + auto* sub_op_desc = const_cast( + sub_block_desc->GetOp(sub_op_idx)); + UpdateInputTo(sub_op_desc, head_node->AsArg().name, cur_node->AsArg().name); } // recreate the op @@ -444,21 +450,27 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph, tail_node->AsArg().name, cur_node->AsArg().name); // for subgraph op, modify the BlockDesc - auto* sub_block_desc = dynamic_cast( - inst_node->AsStmt().op().get()) - ->GetSubBlock(); - for (size_t i = 0; i < sub_block_desc->OpsSize(); ++i) { - auto* sub_block_op_desc = sub_block_desc->GetOp(i); + auto sub_program_desc = dynamic_cast( + inst_node->AsStmt().op().get()) + ->GetProgramDesc(); + CHECK(sub_program_desc); + int sub_block_idx = + inst_node->AsStmt().op()->op_info()->GetAttr("sub_block"); + auto* sub_block_desc = + sub_program_desc->GetBlock(sub_block_idx); + for (size_t sub_op_idx = 0; sub_op_idx < sub_block_desc->OpsSize(); + ++sub_op_idx) { + auto* sub_op_desc = const_cast( + sub_block_desc->GetOp(sub_op_idx)); UpdateOutputTo( - sub_block_op_desc, tail_node->AsArg().name, cur_node->AsArg().name); + sub_op_desc, tail_node->AsArg().name, cur_node->AsArg().name); /* graph like this * subgraph_op_0 * / \ * / \ * subgraph_op_1 host_op */ - UpdateInputTo( - sub_block_op_desc, tail_node->AsArg().name, cur_node->AsArg().name); + UpdateInputTo(sub_op_desc, tail_node->AsArg().name, cur_node->AsArg().name); } // recreate the op @@ -482,15 +494,22 @@ void MLUPostprocessPass::RecreateOp(Node* inst_node, SSAGraph* graph) { } } -bool MLUPostprocessPass::IsFirstConvInSubgraph(Node* arg_node, Node* inst) { - auto* block_desc = - static_cast(inst->AsStmt().op().get()) - ->GetSubBlock(); - for (size_t op_idx = 0; op_idx < block_desc->OpsSize(); op_idx++) { - auto op_desc = block_desc->GetOp(op_idx); - CHECK(op_desc); - if (op_desc->Type() == "conv2d") { - for (auto& names : op_desc->inputs()) { +bool MLUPostprocessPass::IsFirstConvInSubgraph(Node* arg_node, + Node* inst_node) { + auto sub_program_desc = dynamic_cast( + inst_node->AsStmt().op().get()) + ->GetProgramDesc(); + CHECK(sub_program_desc); + int sub_block_idx = + inst_node->AsStmt().op()->op_info()->GetAttr("sub_block"); + auto* sub_block_desc = + sub_program_desc->GetBlock(sub_block_idx); + for (size_t sub_op_idx = 0; sub_op_idx < sub_block_desc->OpsSize(); + sub_op_idx++) { + auto sub_op_desc = sub_block_desc->GetOp(sub_op_idx); + CHECK(sub_op_desc); + if (sub_op_desc->Type() == "conv2d") { + for (auto& names : sub_op_desc->inputs()) { if (std::find(names.second.begin(), names.second.end(), arg_node->AsArg().name) != names.second.end()) { @@ -746,19 +765,23 @@ std::pair CheckOutputAndInsert( // insert cast op on mlu, to avoid cast on cpu void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node, const Type* subgraph_type) { - auto subgraph_op = subgraph_node->AsStmt().op(); - CHECK_EQ(subgraph_op->Type(), "subgraph"); - auto op = dynamic_cast(subgraph_op.get()); - CHECK(op); - auto block_desc = op->GetSubBlock(); + CHECK_EQ(subgraph_node->AsStmt().op()->Type(), "subgraph"); + auto subgraph_op = + dynamic_cast(subgraph_node->AsStmt().op().get()); + CHECK(subgraph_op); + auto sub_program_desc = subgraph_op->GetProgramDesc(); + CHECK(sub_program_desc); + int sub_block_idx = subgraph_op->op_info()->GetAttr("sub_block"); + auto* sub_block_desc = const_cast( + sub_program_desc->GetBlock(sub_block_idx)); // create a new block desc to keep op sequence correct - cpp::BlockDesc* new_block_desc = new cpp::BlockDesc(); - new_block_desc->ClearOps(); - new_block_desc->ClearVars(); - new_block_desc->SetIdx(block_desc->Idx()); - new_block_desc->SetParentIdx(block_desc->ParentIdx()); - new_block_desc->SetForwardBlockIdx(block_desc->ForwardBlockIdx()); + cpp::BlockDesc new_block_desc; + new_block_desc.ClearOps(); + new_block_desc.ClearVars(); + new_block_desc.SetIdx(sub_block_desc->Idx()); + new_block_desc.SetParentIdx(sub_block_desc->ParentIdx()); + new_block_desc.SetForwardBlockIdx(sub_block_desc->ForwardBlockIdx()); // find all IO that is not weight or persist std::list i_names, o_names; @@ -769,8 +792,8 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node, auto input_name = input->AsArg().name; if (!(input->AsArg().is_weight || input->AsArg().is_persist)) { i_names.emplace_back(input_name); - auto ret = CheckInputAndInsert(op->scope(), - new_block_desc, + auto ret = CheckInputAndInsert(subgraph_op->scope(), + &new_block_desc, input_name, input->AsArg().type, subgraph_type); @@ -783,8 +806,8 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node, auto output_name = output->AsArg().name; if (!(output->AsArg().is_weight || output->AsArg().is_persist)) { o_names.emplace_back(output_name); - auto ret = CheckOutputAndInsert(op->scope(), - block_desc, + auto ret = CheckOutputAndInsert(subgraph_op->scope(), + sub_block_desc, output_name, output->AsArg().type, subgraph_type); @@ -795,46 +818,48 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node, } // update input and output - for (size_t op_idx = 0; op_idx < block_desc->OpsSize(); ++op_idx) { - auto desc = block_desc->GetOp(op_idx); - auto new_desc = new_block_desc->AddOp(); - *new_desc = *desc; - - if (desc->Type() != "layout" && desc->Type() != "cast") { - auto op_input_args = new_desc->InputArgumentNames(); + for (size_t sub_op_idx = 0; sub_op_idx < sub_block_desc->OpsSize(); + ++sub_op_idx) { + auto sub_op_desc = sub_block_desc->GetOp(sub_op_idx); + auto new_op_desc = new_block_desc.AddOp(); + *new_op_desc = *sub_op_desc; + + if (sub_op_desc->Type() != "layout" && sub_op_desc->Type() != "cast") { + auto op_input_args = new_op_desc->InputArgumentNames(); for (auto& input_arg : op_input_args) { - auto op_input = new_desc->Input(input_arg); + auto op_input = new_op_desc->Input(input_arg); for (auto& it : i_names) { auto index = std::find(op_input.begin(), op_input.end(), it); if (index != op_input.end() && node_replace.find(it) != node_replace.end()) { index = op_input.erase(index); op_input.emplace(index, node_replace.at(it)); - VLOG(4) << new_desc->Type() << "] change input from " << it + VLOG(4) << new_op_desc->Type() << "] change input from " << it << " to " << node_replace.at(it); } } - new_desc->SetInput(input_arg, op_input); + new_op_desc->SetInput(input_arg, op_input); } - auto op_output_args = new_desc->OutputArgumentNames(); + auto op_output_args = new_op_desc->OutputArgumentNames(); for (auto& output_arg : op_output_args) { - auto op_output = new_desc->Output(output_arg); + auto op_output = new_op_desc->Output(output_arg); for (auto& it : o_names) { auto index = std::find(op_output.begin(), op_output.end(), it); if (index != op_output.end() && node_replace.find(it) != node_replace.end()) { index = op_output.erase(index); op_output.emplace(index, node_replace.at(it)); - VLOG(4) << new_desc->Type() << "] change output from " << it + VLOG(4) << new_op_desc->Type() << "] change output from " << it << " to " << node_replace.at(it); } } - new_desc->SetOutput(output_arg, op_output); + new_op_desc->SetOutput(output_arg, op_output); } } } - op->SetSubBlock(new_block_desc); + + *sub_block_desc = new_block_desc; } void ModifyValidPlaces(SSAGraph* graph, bool use_mlu_cast) { diff --git a/lite/core/mir/ssa_graph.cc b/lite/core/mir/ssa_graph.cc index f8991a359b177799cc5f59651c5d305fe64231ef..9cf7bc8995766e47895ce3dd2ef6bf7bcb614e5c 100644 --- a/lite/core/mir/ssa_graph.cc +++ b/lite/core/mir/ssa_graph.cc @@ -153,60 +153,61 @@ Node *SSAGraph::GraphCreateInstructNode( } void SSAGraph::Build(const Program &program, - const std::vector &valid_places) { + const std::vector &valid_places, + int block_idx) { CHECK(node_storage_.empty()); - auto weights_name = program.weights(); - auto is_weights = [&](const std::string &name) -> bool { - auto it = std::find(weights_name.begin(), weights_name.end(), name); - if (it == weights_name.end()) return false; + auto weights = program.weights(); + auto is_weight = [&](const std::string &name) -> bool { + auto it = std::find(weights.begin(), weights.end(), name); + if (it == weights.end()) return false; return true; }; - std::map var_types = program.var_data_type(); - - std::map arg_update_node_map_; - for (auto &op : program.ops()) { + auto var_type_map = program.var_type_map(); + std::map arg_update_node_map; + for (auto &op : program.ops(block_idx)) { VLOG(3) << op->op_info()->Type(); auto *op_node = GraphCreateInstructNode(op, valid_places); - for (const std::string &name : op->op_info()->input_names()) { + auto *op_info = op->op_info(); + const auto &op_type = op_info->Type(); + for (const auto &var_name : op_info->input_names()) { mir::Node *arg_node = nullptr; - if (arg_update_node_map_.count(name)) { - arg_node = arg_update_node_map_.at(name); + if (arg_update_node_map.count(var_name)) { + arg_node = arg_update_node_map.at(var_name); } else { node_storage_.emplace_back(); arg_node = &node_storage_.back(); - arg_node->AsArg(name, node_storage_.size() - 1); - arg_update_node_map_[name] = arg_node; + arg_node->AsArg(var_name, node_storage_.size() - 1); + arg_update_node_map[var_name] = arg_node; } - if (var_types.count(name)) { + if (var_type_map.count(var_name)) { if (!arg_node->arg()->type) { - arg_node->arg()->type = LiteType::GetTensorTy( - TARGET(kUnk), var_types[name], DATALAYOUT(kUnk)); + arg_node->arg()->type = var_type_map[var_name]; } // Store the original data type of the output tensors for // type_precision_cast_pass, to keep the consistency between the // output types of original graph and optimized graph's - if (op->op_info()->Type() == "fetch") { + if (op_type == "fetch") { op->mutable_op_info()->SetAttr( - "data_type", static_cast(var_types[name])); + "data_type", + static_cast(var_type_map[var_name]->precision())); } } - if (is_weights(name)) arg_node->AsArg().is_weight = true; + if (is_weight(var_name)) arg_node->AsArg().is_weight = true; CHECK(arg_node->IsRoleSet()); DirectedLink(arg_node, op_node); } - for (const std::string &name : op->op_info()->output_names()) { + for (const auto &var_name : op->op_info()->output_names()) { node_storage_.emplace_back(); auto *arg_node = &node_storage_.back(); - arg_node->AsArg(name, node_storage_.size() - 1); - arg_update_node_map_[name] = arg_node; - if (var_types.count(name) && !arg_node->arg()->type) { - arg_node->arg()->type = LiteType::GetTensorTy( - TARGET(kUnk), var_types[name], DATALAYOUT(kUnk)); + arg_node->AsArg(var_name, node_storage_.size() - 1); + arg_update_node_map[var_name] = arg_node; + if (var_type_map.count(var_name) && !arg_node->arg()->type) { + arg_node->arg()->type = var_type_map[var_name]; } - if (is_weights(name)) arg_node->AsArg().is_weight = true; + if (is_weight(var_name)) arg_node->AsArg().is_weight = true; CHECK(arg_node->IsRoleSet()); DirectedLink(op_node, arg_node); } diff --git a/lite/core/mir/ssa_graph.h b/lite/core/mir/ssa_graph.h index e2967cf96a6b00ccc225ce05b043cb94f161b1d6..819b0a71ea1be04c85316e90001aef311b7d7238 100644 --- a/lite/core/mir/ssa_graph.h +++ b/lite/core/mir/ssa_graph.h @@ -35,9 +35,13 @@ class GraphBase {}; class SSAGraph : GraphBase { public: - // @param program: the op program + // @param program: the target program with vars and ops // @param valid_places: the valid places user set for the system. - void Build(const Program &program, const std::vector &valid_places); + // @param block_idx: the block index in the target program, default is 0(main + // block) + void Build(const Program &program, + const std::vector &valid_places, + int block_idx = kRootBlockIdx); void RemoveNode(const mir::Node *node); std::vector StmtTopologicalOrder(); diff --git a/lite/core/mir/subgraph/subgraph_detector.cc b/lite/core/mir/subgraph/subgraph_detector.cc index 4b9f34225f70e9050b2605b49e888ed323536b2f..13805b2b18634551d4b74ac436954fa8f6b9ed05 100644 --- a/lite/core/mir/subgraph/subgraph_detector.cc +++ b/lite/core/mir/subgraph/subgraph_detector.cc @@ -411,16 +411,17 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, cpp::OpDesc subgraph_op_desc; subgraph_op_desc.SetType("subgraph"); - // Create a new sub block desc for storing all of Ops and Vars of the target - // subgraph and sub_block_idx is set as a attribute of subgraph op, - // sub_block_idx < 0 means it's a new subgraph op - int sub_block_idx = -(subgraph_idx + 1); - auto sub_block_desc = new cpp::BlockDesc(); + // Create a program desc and a block desc for storing all of Ops and Vars of + // the target subgraph and sub_block_idx is set as a attribute of subgraph op, + // sub_block_idx = 0 means it's a new subgraph op + auto sub_program_desc = std::make_shared(); + int sub_block_idx = 0; + auto sub_block_desc = sub_program_desc->AddBlock(); sub_block_desc->ClearOps(); sub_block_desc->ClearVars(); for (auto &op_node : subgraph_nodes) { - auto sub_block_op_desc = sub_block_desc->AddOp(); - *sub_block_op_desc = *op_node->AsStmt().op_info(); + auto sub_op_desc = sub_block_desc->AddOp(); + *sub_op_desc = *op_node->AsStmt().op_info(); } subgraph_op_desc.SetAttr("sub_block", sub_block_idx); @@ -437,13 +438,13 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, &local_var_nodes, &unused_var_nodes); // A simplified model without the original weight/local/unused nodes on the - // subgraph ops will be saved only if 'SUBGRAPH_DISABLE_ONLINE_MODE' is set to - // true and Predictor->Run(...), Predictor->Save(...) is called. + // subgraph ops will be saved only if 'SUBGRAPH_ONLINE_MODE' is set to + // true(default) and Predictor->Run(...), Predictor->Save(...) is called. std::set input_var_nodes(idata_var_nodes.begin(), idata_var_nodes.end()); std::set output_var_nodes(odata_var_nodes.begin(), odata_var_nodes.end()); - if (!GetBoolFromEnv(SUBGRAPH_DISABLE_ONLINE_MODE)) { + if (GetBoolFromEnv(SUBGRAPH_ONLINE_MODE, true)) { input_var_nodes.insert(weight_var_nodes.begin(), weight_var_nodes.end()); output_var_nodes.insert(local_var_nodes.begin(), local_var_nodes.end()); output_var_nodes.insert(unused_var_nodes.begin(), unused_var_nodes.end()); @@ -476,7 +477,7 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, subgraph_op_desc.SetOutput("Outputs", output_var_names); auto subgraph_op = LiteOpRegistry::Global().Create("subgraph"); static_cast(subgraph_op.get()) - ->SetSubBlock(sub_block_desc); + ->SetProgramDesc(sub_program_desc); auto any_op = (*subgraph_nodes.begin())->AsStmt().op(); subgraph_op->Attach(subgraph_op_desc, any_op->scope()); diff --git a/lite/core/mir/subgraph/subgraph_detector_test.cc b/lite/core/mir/subgraph/subgraph_detector_test.cc index 06c9c4c78fedba7cfabcd4ff2dd3804b404f966d..872556518c002f144c7a7d3b88d683e8d484f03e 100644 --- a/lite/core/mir/subgraph/subgraph_detector_test.cc +++ b/lite/core/mir/subgraph/subgraph_detector_test.cc @@ -141,12 +141,11 @@ std::vector AddFetchDesc( } TEST(Subgraph, detect_simple_model) { - cpp::ProgramDesc program_desc; + auto program_desc = std::make_shared(); std::vector valid_places{{TARGET(kHost), PRECISION(kFloat)}}; auto scope = std::make_shared(); // Build a simple network - program_desc.ClearBlocks(); - auto* block_desc = program_desc.AddBlock(); + auto* block_desc = program_desc->AddBlock(); block_desc->ClearOps(); block_desc->ClearVars(); auto* var_desc = block_desc->AddVar(); @@ -181,13 +180,13 @@ TEST(Subgraph, detect_custom_model) { "the path of model files."; return; } - cpp::ProgramDesc program_desc; + auto program_desc = std::make_shared(); auto scope = std::make_shared(); LoadModelPb(FLAGS_model_dir, FLAGS_model_file, FLAGS_params_file, scope.get(), - &program_desc, + program_desc.get(), !FLAGS_model_file.empty() && !FLAGS_params_file.empty(), false); std::vector valid_places({ diff --git a/lite/core/mir/type_precision_cast_pass.cc b/lite/core/mir/type_precision_cast_pass.cc index 39a94cbca6bd6222da5da1d314ea07475592bf0e..40ece35993cfd2f8bce07e605387741202973614 100644 --- a/lite/core/mir/type_precision_cast_pass.cc +++ b/lite/core/mir/type_precision_cast_pass.cc @@ -36,14 +36,20 @@ void UpdateInputsForSubgraph(OpLite* op, op_desc->GetAttr>("input_data_names"); std::replace(input_data_names.begin(), input_data_names.end(), from, to); op_desc->SetAttr("input_data_names", input_data_names); - auto* subblock_desc = static_cast(op)->GetSubBlock(); - CHECK(subblock_desc); - for (size_t i = 0; i < subblock_desc->OpsSize(); i++) { - auto* subblock_op_desc = subblock_desc->GetOp(i); - for (auto& subblock_op_input : *subblock_op_desc->mutable_inputs()) { - for (auto& subblock_var_name : subblock_op_input.second) { - if (subblock_var_name == from) { - subblock_var_name = to; + auto sub_program_desc = + static_cast(op)->GetProgramDesc(); + CHECK(sub_program_desc); + int sub_block_idx = op_desc->GetAttr("sub_block"); + auto sub_block_desc = + sub_program_desc->GetBlock(sub_block_idx); + for (size_t sub_op_idx = 0; sub_op_idx < sub_block_desc->OpsSize(); + sub_op_idx++) { + auto sub_op_desc = const_cast( + sub_block_desc->GetOp(sub_op_idx)); + for (auto& sub_op_input : *sub_op_desc->mutable_inputs()) { + for (auto& sub_var_name : sub_op_input.second) { + if (sub_var_name == from) { + sub_var_name = to; } } } diff --git a/lite/core/mir/variable_place_inference_pass.h b/lite/core/mir/variable_place_inference_pass.h index d9f420cfad90d3c6a1f08072d8c5f87d2326661a..f7d35bfef3ac53903448c48300c144f8fd15652d 100644 --- a/lite/core/mir/variable_place_inference_pass.h +++ b/lite/core/mir/variable_place_inference_pass.h @@ -59,25 +59,46 @@ class VariablePlaceInferencePass : public DebugPass { } // Set the type of the weight - void SetWeightType(Node* w, + void SetWeightType(Node* weight_node, const LiteType& type, - const std::map& lite_with_targets) { + const std::map& with_targets) { VLOG(4) << "type.precision():" << PrecisionRepr(type.precision()); - if (lite_with_targets.at("kFPGA")) { - w->AsArg().type = LiteType::GetTensorTy( + if (with_targets.at("kFPGA")) { + weight_node->AsArg().type = LiteType::GetTensorTy( TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); - } else if (lite_with_targets.at("kOpenCL")) { - w->AsArg().type = LiteType::GetTensorTy( + } else if (with_targets.at("kOpenCL")) { + weight_node->AsArg().type = LiteType::GetTensorTy( TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); - } else if (lite_with_targets.at("kCUDA")) { - w->AsArg().type = LiteType::GetTensorTy( + } else if (with_targets.at("kCUDA")) { + weight_node->AsArg().type = LiteType::GetTensorTy( TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); } else { - w->AsArg().type = LiteType::GetTensorTy( + weight_node->AsArg().type = LiteType::GetTensorTy( TARGET(kHost), type.precision(), DATALAYOUT(kNCHW)); } } + // Update a's kUnk fields from b's fields. + void UpdateTypeFrom(const Type** a, const Type* b) { + auto target = (*a)->target(); + auto precision = (*a)->precision(); + auto layout = (*a)->layout(); + if (target == TARGET(kUnk)) { + target = b->target(); + } + if (precision == PRECISION(kUnk)) { + precision = b->precision(); + } + if (layout == DATALAYOUT(kUnk)) { + layout = b->layout(); + } + if ((*a)->IsTensor() && b->IsTensor()) { + *a = LiteType::GetTensorTy(target, precision, layout); + } else if ((*a)->IsTensorList() && b->IsTensorList()) { + *a = LiteType::GetTensorListTy(target, precision, layout); + } + } + void InferenceArgumentPlace(SSAGraph* graph) { auto& valid_places = graph->valid_places(); auto valid_places_has_target = [&](TargetType t) -> bool { @@ -88,122 +109,90 @@ class VariablePlaceInferencePass : public DebugPass { } return false; }; - std::map lite_with_targets{ + std::map with_targets{ {"kOpenCL", valid_places_has_target(TARGET(kOpenCL))}, {"kCUDA", valid_places_has_target(TARGET(kCUDA))}, {"kFPGA", valid_places_has_target(TARGET(kFPGA))}}; - VLOG(4) << "lite_with_targets['kOpenCL']:" << lite_with_targets["kOpenCL"]; - VLOG(4) << "lite_with_targets['kFPGA']:" << lite_with_targets["kFPGA"]; + VLOG(4) << "with_targets['kOpenCL']:" << with_targets["kOpenCL"]; + VLOG(4) << "with_targets['kFPGA']:" << with_targets["kFPGA"]; VLOG(3) << "param-type-registry:\n" << ParamTypeRegistry::Global(); - for (auto& x : graph->StmtTopologicalOrder()) { - auto& inst = x->AsStmt(); + for (auto& node : graph->StmtTopologicalOrder()) { + auto& inst = node->AsStmt(); + const auto* op_info = inst.op_info(); + const auto& op_type = op_info->Type(); + auto& kernel = inst.picked_kernel(); + // The IoCopyOp is a tool operator, it won't support the type inference. // in fpga, we has io_copy+cali+layout tool ops, so we need type inference - // for - // tool operator - if ((!lite_with_targets["kFPGA"]) && (!lite_with_targets["kOpenCL"])) { - VLOG(3) << "inst.op_type() == 'io_copy', continue"; - if (inst.op_type() == "io_copy") continue; + // for tool operator + if ((!with_targets["kFPGA"]) && (!with_targets["kOpenCL"])) { + VLOG(3) << "skip 'io_copy' if target is FPGA and OpenCL"; + if (op_type == "io_copy") continue; } - // deal with inputs - VLOG(4) << "Infering op " << inst.op_info()->Repr(); - // TODO(zhaolong): Add check if the node's name in op's arguments. - auto get_argname = [&]( - const std::string& node_name, - const std::map>& argname_map) - -> std::string { - for (auto& ele : argname_map) { - auto it = - std::find(ele.second.begin(), ele.second.end(), node_name); - if (it != ele.second.end()) return ele.first; - } - return ""; - }; - - for (auto* x_in : x->inlinks) { - std::string node_name = x_in->AsArg().name; - std::string arg_name = get_argname(node_name, inst.op_info()->inputs()); - CHECK(arg_name.size() > 0) << "can not found op arguments for node " - << node_name; - VLOG(4) << "-- input arg_name:" << arg_name << " " - << "-- node name:" << node_name; - auto type = inst.picked_kernel().GetInputDeclType(arg_name); - if (!x_in->AsArg().type) { - VLOG(4) << "set type " << *type << " " << x_in->AsArg().name; - if (x_in->AsArg().is_weight) { - SetWeightType(x_in, *type, lite_with_targets); + // Infering the input and output variable's place according to the + // declaration of I/O arguments of the picked kernel of the op + VLOG(4) << "Op " << op_info->Repr(); + for (auto* in_node : node->inlinks) { + auto& var = in_node->AsArg(); + const auto& var_name = var.name; + auto* var_type = &var.type; + std::string arg_name; + CHECK(op_info->GetInputArgname(var_name, &arg_name)) + << "Can not find the input argument for var " << var_name; + VLOG(4) << " - input arg name:" << arg_name << " var name:" << var_name; + const auto* decl_type = kernel.GetInputDeclType(arg_name); + if (!(*var_type)) { + VLOG(4) << "set type " << *decl_type << " " << var_name; + if (var.is_weight) { + SetWeightType(in_node, *decl_type, with_targets); } else { - x_in->AsArg().type = type; + *var_type = decl_type; } - } else if (x_in->AsArg().type->target() == TARGET(kUnk) && - x_in->AsArg().type->precision() != PRECISION(kUnk) && - x_in->AsArg().type->layout() == DATALAYOUT(kUnk)) { + } else if (!(*var_type)->place().is_valid()) { // If is quantization, infer the Int8 type. - if (type->precision() == PRECISION(kInt8)) { - x_in->AsArg().type = type; + if (decl_type->precision() == PRECISION(kInt8)) { + *var_type = decl_type; } else { - PrecisionType tmp_ptype = x_in->AsArg().type->precision(); - x_in->AsArg().type = LiteType::GetTensorTy( - type->target(), tmp_ptype, type->layout()); + UpdateTypeFrom(var_type, decl_type); } } } - - VLOG(4) << "inst " << inst.op_info()->Repr(); - for (auto* x_out : x->outlinks) { - std::string node_name = x_out->AsArg().name; - std::string arg_name = - get_argname(node_name, inst.op_info()->outputs()); - CHECK(arg_name.size() > 0) << "can not found op arguments for node " - << node_name << " in Inst " - << inst.op_type(); - VLOG(4) << "-- output arg_name " << arg_name; - auto type = inst.picked_kernel().GetOutputDeclType(arg_name); - if (!x_out->AsArg().type) { - VLOG(4) << "set type " << *type << " " << x_out->AsArg().name; - if (x_out->AsArg().is_weight) { - SetWeightType(x_out, *type, lite_with_targets); + for (auto* out_node : node->outlinks) { + auto& var = out_node->AsArg(); + const auto& var_name = var.name; + auto* var_type = &var.type; + std::string arg_name; + CHECK(op_info->GetOutputArgname(var_name, &arg_name)) + << "Can not find the output argument for var " << var_name; + VLOG(4) << " - output arg name:" << arg_name + << " var name:" << var_name; + const auto* decl_type = kernel.GetOutputDeclType(arg_name); + if (!(*var_type)) { + VLOG(4) << "set type " << *decl_type << " " << var_name; + if (var.is_weight) { + SetWeightType(out_node, *decl_type, with_targets); } else { - x_out->AsArg().type = type; + *var_type = decl_type; } - } else if (x_out->AsArg().type->target() == TARGET(kUnk) && - x_out->AsArg().type->precision() != PRECISION(kUnk) && - x_out->AsArg().type->layout() == DATALAYOUT(kUnk)) { + } else if (!(*var_type)->place().is_valid()) { // If is quantization, infer the Int8 type. - if (type->precision() == PRECISION(kInt8)) { - x_out->AsArg().type = type; - } else if (type->precision() == PRECISION(kFP16) && - type->target() != TARGET(kOpenCL)) { - x_out->AsArg().type = type; + if (decl_type->precision() == PRECISION(kInt8) || + (decl_type->precision() == PRECISION(kFP16) && + decl_type->target() != TARGET(kOpenCL))) { + *var_type = decl_type; } else { - PrecisionType tmp_ptype = x_out->AsArg().type->precision(); - x_out->AsArg().type = LiteType::GetTensorTy( - type->target(), tmp_ptype, type->layout()); + UpdateTypeFrom(var_type, decl_type); } } } } } - // Update me's kUnk fields by other's fields. - void UpdatePlace(Place* me, const Place& other) { - CHECK(other.is_valid()); - if (me->target == TARGET(kUnk)) { - me->target = other.target; - } - if (me->precision == PRECISION(kUnk)) { - me->precision = other.precision; - } - if (me->layout == DATALAYOUT(kUnk)) { - me->layout = other.layout; - } - } - private: - // The default target for arguments, e.g. load weights to CPU memory for CUDA - // computation by default. + // The default target for arguments, e.g. load weights to CPU memory for + // CUDA computation by default. TargetType argument_default_target_{TARGET(kHost)}; }; diff --git a/lite/core/op_lite.h b/lite/core/op_lite.h index 079586d5e0c00f261bfbf4c7658ccca97402f8ac..d94753220a1b5d963092c62c43d7e49b03243c63 100644 --- a/lite/core/op_lite.h +++ b/lite/core/op_lite.h @@ -99,7 +99,7 @@ class OpLite : public Registry { std::vector> CreateKernels( const std::vector &places, const std::string &kernel_type = ""); - lite::Scope *scope() { return scope_; } + Scope *scope() { return scope_; } // Assign op param to kernel. virtual void AttachKernel(KernelBase *kernel) = 0; @@ -169,7 +169,7 @@ class OpLite : public Registry { } protected: - lite::Scope *scope_{nullptr}; + Scope *scope_{nullptr}; std::unique_ptr kernel_; std::string op_type_; std::vector valid_places_; diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index 70905c96f08d74fc5e27c85c7ccf3d395420a5e9..7645a117045ade89489c3769c4d75666bc3f8ae7 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -19,6 +19,7 @@ #include #include #include +#include "lite/core/mir/elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.h" #include "lite/core/mir/generate_program_pass.h" #include "lite/core/mir/pass_manager.h" #include "lite/core/mir/pass_utils.h" @@ -36,6 +37,9 @@ namespace lite { * lite::Optimizer optimize a program. It utilize the mir passes to analysis the * program and export an optimized program. */ +// TODO(hong1986032) Support the following passes for the subblocks +const std::set kSubblockUnsupportedPasses( + {"memory_optimize_pass"}); class Optimizer { public: Optimizer() {} @@ -60,14 +64,20 @@ class Optimizer { program_ = &program; valid_places_ = valid_places; CHECK(!valid_places.empty()) << "At least one valid_place should be set"; - CHECK(!graph_) << "duplicate optimize found"; - - graph_.reset(new mir::SSAGraph); - graph_->Build(program, valid_places); - graph_->SetValidPlaces(valid_places); + CHECK(graphs_.empty()) << "duplicate optimize found"; + + auto block_size = program.block_size(); + for (size_t block_idx = 0; block_idx < block_size; ++block_idx) { + std::unique_ptr graph; + graph.reset(new mir::SSAGraph); + graph->Build(program, valid_places, block_idx); + graph->SetValidPlaces(valid_places); + graphs_.emplace_back(std::move(graph)); + } SpecifyKernelPickTactic(kernel_pick_factor); InitTargetTypeTransformPass(); + InitControlFlowOpUnusedInputsAndOutputsEliminatePass(); if (passes.empty() || passes.size() == 1) { std::vector passes_local{ @@ -76,6 +86,7 @@ class Optimizer { "lite_conv_elementwise_fuse_pass", // conv-elemwise-bn "lite_conv_bn_fuse_pass", // "lite_conv_elementwise_fuse_pass", // conv-bn-elemwise + "lite_conv_conv_fuse_pass", // // TODO(Superjomn) Refine the fusion related design to select fusion // kernels for devices automatically. "lite_conv_activation_fuse_pass", // @@ -111,6 +122,7 @@ class Optimizer { "apu_subgraph_pass", "rknpu_subgraph_pass", "mlu_subgraph_pass", + "control_flow_op_unused_inputs_and_outputs_eliminate_pass", "static_kernel_pick_pass", // pick original kernel from graph "remove_tf_redundant_ops_pass", @@ -175,62 +187,15 @@ class Optimizer { exec_scope_ = program.exec_scope(); } - const lite::Scope* exec_scope() const { return exec_scope_; } - - // Set shape(dims) infos of var descs to scope var. - // developer can write pass using input / output tensor dims of op. - // - // Example: If you have node `Node* softmax_node`, - // you can get dims of output tensor in passes: - // - // auto* scope = softmax_node->AsStmt().op()->scope(); - // auto softmax_out_arg_name = - // softmax_node->outlinks.front()->AsArg().name; - // auto softmax_out_tensor = - // scope->FindVar(softmax_out_arg_name)->Get(); - // softmax_out_dims = softmax_out_tensor.dims(); - void SetVarDescShapeToScopeVar() { - auto dims_to_str_func = [](std::vector shape) -> std::string { - std::string str_res; - for (size_t i = 0; i < shape.size(); ++i) { - str_res += std::to_string(shape[i]); - if (i != shape.size() - 1) { - str_res += "x"; - } - } - return str_res; - }; - - auto* program_desc = program_->program_desc(); - VLOG(5) << "program_desc->BlocksSize():" << program_desc->BlocksSize(); - auto blocks_desc = program_desc->GetBlocks(); - for (size_t bidx = 0; bidx < blocks_desc.size(); ++bidx) { - auto block_desc = blocks_desc[bidx]; - auto vars_desc = block_desc.GetVars(); - for (size_t vidx = 0; vidx < vars_desc.size(); ++vidx) { - auto var_desc = vars_desc[vidx]; - VLOG(5) << var_desc.Name() << " " - << dims_to_str_func(var_desc.GetShape()); - if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue; - auto* var = program_->exec_scope()->FindVar(var_desc.Name()); - auto tensor = var->GetMutable(); - if (tensor->dims().size() == 0 && var_desc.GetShape().size() != 0) { - VLOG(5) << "var_desc.Name():" << var_desc.Name() - << " shape:" << dims_to_str_func(var_desc.GetShape()); - tensor->Resize(var_desc.GetShape()); - } - VLOG(5) << "var_desc.Name():" << var_desc.Name() - << " shape:" << dims_to_str_func(var_desc.GetShape()) - << " tensor:" << tensor->dims(); - } - } - } + const Scope* exec_scope() const { return exec_scope_; } // Generate a new program based on the mir graph. std::unique_ptr GenRuntimeProgram() { auto pass = mir::PassManager::Global().LookUp( "generate_program_pass"); - pass->Apply(graph_); + for (auto& graph : graphs_) { + pass->Apply(graph); + } auto program = pass->GenProgram(); CHECK(exec_scope_); program->set_exec_scope(exec_scope_); @@ -246,27 +211,38 @@ class Optimizer { pass->SetValidPlaces(valid_places_); } + void InitControlFlowOpUnusedInputsAndOutputsEliminatePass() { + auto* pass = + mir::PassManager::Global() + .LookUp( + "control_flow_op_unused_inputs_and_outputs_eliminate_pass"); + CHECK(pass); + CHECK(!graphs_.empty()); + pass->SetAllGraphs(&graphs_); + } + // Generate C++ code which combines the inference program, model and weights. void GenCode(const std::string& code_dir); - const mir::SSAGraph& ssa_graph() const { - CHECK(graph_); - return *graph_; + const mir::SSAGraph& ssa_graph(int block_idx = kRootBlockIdx) const { + CHECK(!graphs_.empty()); + CHECK(graphs_[block_idx]); + return *graphs_[block_idx]; } - mir::SSAGraph* mutable_ssa_graph() { - CHECK(graph_); - return graph_.get(); + mir::SSAGraph* mutable_ssa_graph(int block_idx = kRootBlockIdx) { + CHECK(!graphs_.empty()); + CHECK(graphs_[block_idx]); + return graphs_[block_idx].get(); } - lite::Scope* exec_scope() { return exec_scope_; } + Scope* exec_scope() { return exec_scope_; } protected: void SpecifyKernelPickTactic(core::KernelPickFactor factor); // Specify the passes and run them. void RunPasses(const std::vector& passes) { - SetVarDescShapeToScopeVar(); for (auto& x : passes) { LOG(INFO) << "== Running pass: " << x; mir::Pass* pass = mir::PassManager::Global().LookUp(x); @@ -284,16 +260,23 @@ class Optimizer { LOG(INFO) << " - Skip " << x << " because the target or kernel does not match."; } else { - pass->Apply(graph_); + // Check the pass whether it is supported for processing subblocks + if (kSubblockUnsupportedPasses.count(x)) { + pass->Apply(graphs_[kRootBlockIdx]); + } else { + for (auto& graph : graphs_) { + pass->Apply(graph); + } + } LOG(INFO) << "== Finished running: " << x; } } } private: - std::unique_ptr graph_; + std::vector> graphs_; std::vector valid_places_; - lite::Scope* exec_scope_{}; + Scope* exec_scope_{}; Program* program_{}; }; diff --git a/lite/core/program.cc b/lite/core/program.cc index 5aec6ee229d19ba164f10862619493253c21f541..289787d3bff897036e824dc7e37917b127ece2bc 100644 --- a/lite/core/program.cc +++ b/lite/core/program.cc @@ -15,6 +15,7 @@ #include "lite/core/program.h" #include #include +#include #include "lite/model_parser/cpp_desc.h" #include "lite/operators/conditional_block_op.h" #include "lite/operators/subgraph_op.h" @@ -26,121 +27,219 @@ namespace paddle { namespace lite { -void RuntimeProgram::SaveOpInfosToProgram(cpp::ProgramDesc* desc) { - CHECK(desc); - // NOTE: RuntimeProgram do not has all meta info, so save model just update - // upon origin model - CHECK(desc->BlocksSize()); - auto main_block = desc->GetBlock(0); - main_block->ClearOps(); - for (auto& node : instructions_) { - auto op_type = node.op()->op_info()->Type(); - if (op_type == "subgraph") { - auto subgraph_op = const_cast( - static_cast(node.op())); - int sub_block_idx = subgraph_op->op_info()->GetAttr("sub_block"); - if (sub_block_idx < 0) { - // It's a new subgraph op when its sub_block_idx < 0, Now we add its +void RuntimeProgram::SaveToProgram( + std::shared_ptr program_desc) { + CHECK(program_desc); + auto block_size = program_desc->BlocksSize(); + CHECK_GT(block_size, 0) << "No block found!"; + // TODD(hong19860320) Only support updating the block desc which already + // exists in the origin program desc + CHECK_LE(block_size, instructions_.size()) + << "Invalid block size, expected (0," << instructions_.size() + << "] but got " << block_size; + for (size_t block_idx = 0; block_idx < block_size; ++block_idx) { + auto block_desc = program_desc->GetBlock(block_idx); + // Record all of the origin vars in the origin block + std::map origin_var_maps; + auto var_size = block_desc->VarsSize(); + for (size_t var_idx = 0; var_idx < var_size; ++var_idx) { + auto v = block_desc->GetVar(var_idx); + origin_var_maps.emplace(v->Name(), *v); + } + // Update the ops and vars for each block according to the instructions + block_desc->ClearVars(); + block_desc->ClearOps(); + std::set already_added_vars; + for (auto& inst : instructions_[block_idx]) { + auto* op = const_cast(inst.op()); + auto* op_info = op->op_info(); + auto op_type = op_info->Type(); + auto* kernel = inst.mutable_kernel(); + auto* scope = op->scope(); + // Update the origin vars which are referred by the instructions + // Add the new vars which are created in the passes and referred by the + // instructions + auto var_names = op_info->input_names(); + auto out_names = op_info->output_names(); + // Combine input and output vars and delete the duplicates + var_names.insert(var_names.end(), out_names.begin(), out_names.end()); + std::stable_sort(var_names.begin(), var_names.end()); + var_names.erase(std::unique(var_names.begin(), var_names.end()), + var_names.end()); + for (auto& var_name : var_names) { + if (already_added_vars.count(var_name)) continue; + auto* v = block_desc->AddVar(); + v->SetName(var_name); + auto it = origin_var_maps.find(var_name); + if (it != origin_var_maps.end()) { + v->SetType(it->second.GetType()); + v->SetPersistable(it->second.Persistable()); + if (var_name != "feed" && var_name != "fetch") { + v->SetShape(it->second.GetShape()); + v->SetDataType(it->second.GetDataType()); + } + } else { + std::string arg_name; + const Type* decl_type; + if (op_info->GetInputArgname(var_name, &arg_name)) { + decl_type = kernel->GetInputDeclType(arg_name); + } else { + op_info->GetOutputArgname(var_name, &arg_name); + decl_type = kernel->GetOutputDeclType(arg_name); + } + if (decl_type->IsTensor()) { + v->SetType(cpp::VarDesc::Type::LOD_TENSOR); + auto tensor = scope->FindVar(var_name)->GetMutable(); + v->SetPersistable(tensor->persistable()); + if (var_name != "feed" && var_name != "fetch") { + v->SetShape(tensor->dims().data()); + auto precision = tensor->precision(); + switch (precision) { +#define SET_DATATYPE(precision__, data_type) \ + case PrecisionType::precision__: \ + v->SetDataType(data_type); \ + LOG(INFO) << "Update var " << var_name << " done"; \ + break + SET_DATATYPE(kBool, VarDescAPI::VarDataType::BOOL); + SET_DATATYPE(kFloat, VarDescAPI::VarDataType::FP32); + SET_DATATYPE(kFP16, VarDescAPI::VarDataType::FP16); + SET_DATATYPE(kInt8, VarDescAPI::VarDataType::INT8); + SET_DATATYPE(kInt16, VarDescAPI::VarDataType::INT16); + SET_DATATYPE(kInt32, VarDescAPI::VarDataType::INT32); + SET_DATATYPE(kInt64, VarDescAPI::VarDataType::INT64); +#undef SET_DATATYPE + default: + LOG(WARNING) << "Unknown precision type " + << PrecisionToStr(precision) << " for var " + << var_name << " in op " << op_type; + } + } + } else if (decl_type->IsTensorList()) { + // Set persistable=false for tensor array + v->SetType(cpp::VarDesc::Type::LOD_TENSOR_ARRAY); + v->SetPersistable(false); + } else { + CHECK(false) << "Unsupported decl type " << *decl_type + << " for var " << var_name << " in op " << op_type; + } + } + already_added_vars.insert(var_name); + } + // Replace all of origin ops with the instructions + auto op_desc = block_desc->AddOp(); + *op_desc = *op_info; + op_desc->SetAttr(kKernelTypeAttr, kernel->SerializedKernelType()); + if (op_type == "subgraph" && !op_info->GetAttr("sub_block")) { + // It's a new subgraph op when its sub_block_idx = 0, Now we add its // subblock desc to the program desc, Then update its sub_block_idx to // the index of block desc of the program desc. - sub_block_idx = desc->BlocksSize(); - auto sub_block_desc = subgraph_op->GetSubBlock(); - CHECK(sub_block_desc); - auto new_block_desc = desc->AddBlock(); - *new_block_desc = *sub_block_desc; - delete sub_block_desc; - subgraph_op->mutable_op_info()->SetAttr("sub_block", - sub_block_idx); - subgraph_op->SetSubBlock(new_block_desc); - // Update main block desc after a new subblock desc is added - main_block = desc->GetBlock(0); + auto subgraph_op = static_cast(op); + auto sub_program_desc = subgraph_op->GetProgramDesc(); + CHECK(sub_program_desc); + auto sub_block_desc = program_desc->AddBlock(); + *sub_block_desc = *sub_program_desc->GetBlock(0); + subgraph_op->SetProgramDesc(program_desc); + op_desc->SetAttr("sub_block", program_desc->BlocksSize() - 1); + // Attach op and kernel again to update the new block_idx and + // program_desc + subgraph_op->Attach(*op_desc, scope); + subgraph_op->AttachKernel(kernel); + // Update the pointer of block desc after a new subblock desc is added + block_desc = program_desc->GetBlock(block_idx); } } - auto op = main_block->AddOp(); - *op = *node.op()->op_info(); - op->SetAttr(kKernelTypeAttr, node.kernel()->SerializedKernelType()); } } -// `UpdateVarsOfProgram` will remove unused var_descs and add new created -// vars' descs in the block 0. Now, the type of a new created var can only -// be LOD_TENSOR. -void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { - CHECK(desc); - CHECK(desc->BlocksSize()); - std::map origin_var_maps; - auto& main_block = *desc->GetBlock(0); - auto var_size = main_block.VarsSize(); - for (size_t i = 0; i < var_size; i++) { - auto v = main_block.GetVar(i); - auto name = v->Name(); - origin_var_maps.emplace(name, *v); - } - - main_block.ClearVars(); - for (auto& node : instructions_) { - auto* op = const_cast(node.op()); - auto* kernel = node.kernel(); - auto* scope = op->scope(); - auto in_names = op->op_info()->input_names(); - auto out_names = op->op_info()->output_names(); - in_names.insert(in_names.end(), out_names.begin(), out_names.end()); - std::stable_sort(in_names.begin(), in_names.end()); - in_names.erase(std::unique(in_names.begin(), in_names.end()), - in_names.end()); - for (auto& in_name : in_names) { - auto it = origin_var_maps.find(in_name); - if (it != origin_var_maps.end()) { - auto* v = main_block.AddVar(); - v->SetName((it->second).Name()); - v->SetType((it->second).GetType()); - v->SetPersistable((it->second).Persistable()); - if ((it->second).Name() != "feed" && (it->second).Name() != "fetch") { - v->SetShape((it->second).GetShape()); - v->SetDataType((it->second).GetDataType()); - } +// Create runtime program from sub_block desc according to block_idx and +// program_desc, which is used for while/conditional_block/subgraph op. +RuntimeProgram::RuntimeProgram( + const std::shared_ptr& program_desc, + Scope* exec_scope, + int block_idx) + : exec_scope_(exec_scope) { +#ifdef LITE_WITH_OPENCL + using OpenCLContext = Context; + std::unique_ptr local_ctx(new KernelContext()); + local_ctx->As().InitOnce(); +#endif + CHECK(program_desc); + auto block_size = program_desc->BlocksSize(); + CHECK(block_size) << "No block found!"; + CHECK(block_idx >= 0 && block_idx < block_size) + << "Invalid block index, expected [0," << (block_size - 1) << "] but got " + << block_idx; + auto block_desc = program_desc->GetBlock(block_idx); + instructions_.resize(kRootBlockIdx + 1); + auto op_size = block_desc->OpsSize(); + for (size_t op_idx = 0; op_idx < op_size; op_idx++) { + auto op_desc = block_desc->GetOp(op_idx); + CHECK(op_desc); + std::string op_type = op_desc->Type(); + // if (op_type == "feed" || op_type == "fetch") continue; + // Create op and pick up the best kernel + auto op = LiteOpRegistry::Global().Create(op_type); + CHECK(op) << "no Op found for " << op_type; + if (op_type == "while") { + static_cast(op.get())->SetProgramDesc(program_desc); + } else if (op_type == "conditional_block") { + static_cast(op.get())->SetProgramDesc( + program_desc); + } else if (op_type == "subgraph") { + static_cast(op.get())->SetProgramDesc( + program_desc); + } + op->Attach(*op_desc, exec_scope_); + std::unique_ptr kernel; + if (op_desc->HasAttr(kKernelTypeAttr)) { + // Create op and pick up the best kernel according to the + // kKernelTypeAttr attribute + auto kernel_type = op_desc->GetAttr(kKernelTypeAttr); + std::string alias; + Place place; + KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place); + VLOG(3) << "Found the attr '" << kKernelTypeAttr << "': " << kernel_type + << " for " << op_type; + auto kernels = op->CreateKernels({place}); + CHECK_GT(kernels.size(), 0) << "No kernels found for " << op_type; + auto it = std::find_if( + kernels.begin(), kernels.end(), [&](std::unique_ptr& it) { + return it->alias() == alias; + }); + CHECK(it != kernels.end()); + kernel = std::move(*it); + } else { + // TODO(hong19860320) add kernel picking according to the type of input + // and output tensors + VLOG(3) << "The attr '" << kKernelTypeAttr + << "' not found, pick the first kernel for " << op_type; + std::vector> kernels; +#if defined(LITE_WITH_ARM) + kernels = op->CreateKernels({Place{TARGET(kARM)}, Place{TARGET(kHost)}}); +#elif defined(LITE_WITH_X86) + kernels = op->CreateKernels({Place{TARGET(kX86)}, Place{TARGET(kHost)}}); +#endif + if (kernels.size() > 0) { + kernel = std::move(kernels.front()); } else { - // New created vars must be LOD_TENSOR - auto* v = main_block.AddVar(); - v->SetName(in_name); - v->SetType(cpp::VarDesc::Type::LOD_TENSOR); - std::string in_arg_name; - const Type* type; - if (op->op_info()->GetInputArgname(in_name, &in_arg_name)) { - type = kernel->GetInputDeclType(in_arg_name); - } else { - op->op_info()->GetOutputArgname(in_name, &in_arg_name); - type = kernel->GetOutputDeclType(in_arg_name); - } - if (type->IsTensor()) { - auto tensor = scope->FindVar(in_name)->GetMutable(); - v->SetPersistable(tensor->persistable()); - if (in_name != "feed" && in_name != "fetch") { - v->SetShape(tensor->dims().data()); - switch (tensor->precision()) { -#define SET_DATATYPE(precision__, data_type) \ - case PrecisionType::precision__: \ - v->SetDataType(data_type); \ - LOG(INFO) << "update var" << (it->second).Name() << "done"; \ - break - SET_DATATYPE(kBool, VarDescAPI::VarDataType::BOOL); - SET_DATATYPE(kFloat, VarDescAPI::VarDataType::FP32); - SET_DATATYPE(kFP16, VarDescAPI::VarDataType::FP16); - SET_DATATYPE(kInt8, VarDescAPI::VarDataType::INT8); - SET_DATATYPE(kInt16, VarDescAPI::VarDataType::INT16); - SET_DATATYPE(kInt32, VarDescAPI::VarDataType::INT32); - SET_DATATYPE(kInt64, VarDescAPI::VarDataType::INT64); -#undef SET_DATATYPE - default: - VLOG(4) << "warning! unknown precision type"; - } - } - } else { - CHECK(false) << "unsupported var type"; - } + LOG(WARNING) << "No kernels found for " << op_type; } } +#ifdef LITE_WITH_OPENCL + if (kernel->target() == TARGET(kOpenCL)) { + std::unique_ptr ctx(new KernelContext()); + (*local_ctx).As().CopySharedTo(&ctx->As()); + kernel->SetContext(std::move(ctx)); + } else { + kernel->SetContext( + ContextScheduler::Global().NewContext(kernel->target())); + } +#else + kernel->SetContext(ContextScheduler::Global().NewContext(kernel->target())); +#endif + instructions_[kRootBlockIdx].emplace_back(std::move(op), std::move(kernel)); } + Init(); } #ifdef LITE_WITH_CUDA @@ -167,7 +266,8 @@ void RuntimeProgram::Run() { } #endif int idx = -1; - for (auto& inst : instructions_) { + auto& insts = instructions_[kRootBlockIdx]; + for (auto& inst : insts) { ++idx; #ifndef LITE_WITH_FPGA if (inst.is_feed_fetch_op()) continue; @@ -200,58 +300,50 @@ void RuntimeProgram::Run() { #endif } -void Program::Build(const cpp::ProgramDesc& prog) { +void Program::Build(const std::shared_ptr& program_desc) { CHECK(ops_.empty()) << "Executor duplicate Build found"; // Create operators. - auto& program = prog; - CHECK(program.BlocksSize()); - auto& main_block = *program.GetBlock(0); - for (size_t i = 0; i < main_block.OpsSize(); ++i) { - auto& op_desc = *main_block.GetOp(i); - auto op_type = op_desc.Type(); - // if (op_type == "feed" || op_type == "fetch") continue; - VLOG(4) << "create Op [" << op_type << "]"; - auto op = LiteOpRegistry::Global().Create(op_type); - CHECK(op) << "no Op found for " << op_type; - if (op_type == "while" || op_type == "conditional_block" || - op_type == "subgraph") { - auto sub_block_idx = op_desc.GetAttr("sub_block"); - CHECK(sub_block_idx >= 0 && - sub_block_idx < static_cast(program.BlocksSize())) - << "Invalid attribute sub_block(" << sub_block_idx << ") for " - << op_type; - auto sub_block_desc = - const_cast(prog).GetBlock( - sub_block_idx); - CHECK(sub_block_desc); + auto block_size = program_desc->BlocksSize(); + CHECK(block_size); + ops_.resize(block_size); + for (size_t block_idx = 0; block_idx < block_size; ++block_idx) { + auto* block_desc = program_desc->GetBlock(block_idx); + auto op_size = block_desc->OpsSize(); + for (size_t op_idx = 0; op_idx < op_size; ++op_idx) { + auto* op_desc = block_desc->GetOp(op_idx); + auto op_type = op_desc->Type(); + VLOG(4) << "create Op [" << op_type << "]"; + auto op = LiteOpRegistry::Global().Create(op_type); + CHECK(op) << "no Op found for " << op_type; if (op_type == "while") { - static_cast(op.get())->SetSubBlock( - sub_block_desc); + static_cast(op.get())->SetProgramDesc( + program_desc); } else if (op_type == "conditional_block") { - static_cast(op.get())->SetSubBlock( - sub_block_desc); + static_cast(op.get())->SetProgramDesc( + program_desc); } else if (op_type == "subgraph") { - static_cast(op.get())->SetSubBlock( - sub_block_desc); + static_cast(op.get())->SetProgramDesc( + program_desc); } + op->Attach(*op_desc, exec_scope_); + ops_[block_idx].emplace_back(std::move(op)); } - ops_.emplace_back(std::move(op)); - ops_.back()->Attach(op_desc, exec_scope_); } } -void Program::PrepareWorkspace(const cpp::ProgramDesc& prog, - const std::vector& var_names) { +void Program::PrepareWorkspace( + const std::shared_ptr& program_desc, + const std::vector& vars_to_clone) { CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found"; exec_scope_ = &scope_->NewScope(); // Create Feed and Fetch var. scope_->Var("feed")->GetMutable>(); scope_->Var("fetch")->GetMutable>(); - tmp_vars_.push_back("feed"); - tmp_vars_.push_back("fetch"); + vars_.push_back("feed"); + vars_.push_back("fetch"); - auto VarPrecision2KernlPrecision = + auto VarDescType2PrecisionType = [](const lite::VarDescAPI::Type& type) -> PrecisionType { switch (type) { case lite::VarDescAPI::Type::FP32: @@ -267,44 +359,60 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog, case lite::VarDescAPI::Type::INT64: return PRECISION(kInt64); default: - // LOG(FATAL) << "not supported type: " << static_cast(type); + LOG(WARNING) << "Unable to convert var desc type(" + << static_cast(type) << ") to precision type!"; return PRECISION(kUnk); } }; - auto& program = prog; - CHECK(program.BlocksSize()); - for (size_t b = 0; b < program.BlocksSize(); ++b) { - auto& main_block = *program.GetBlock(b); - for (size_t i = 0; i < main_block.VarsSize(); ++i) { - auto& var_desc = *main_block.GetVar(i); - if (!var_desc.Persistable()) { - if (var_desc.GetType() == lite::VarDescAPI::Type::LOD_TENSOR && - VarPrecision2KernlPrecision(var_desc.GetDataType()) != - PRECISION(kUnk)) { - var_data_type_[var_desc.Name()] = - VarPrecision2KernlPrecision(var_desc.GetDataType()); - } - tmp_vars_.push_back(var_desc.Name()); - VLOG(4) << "var name: " << var_desc.Name() << " type is " - << static_cast(var_desc.GetType()) << " data type is " - << static_cast(var_desc.GetDataType()); - exec_scope_->Var(var_desc.Name()); - if (b > 0) { - VLOG(4) << "var: " << var_desc.Name(); + auto block_size = program_desc->BlocksSize(); + CHECK(block_size); + for (size_t block_idx = 0; block_idx < block_size; ++block_idx) { + auto* block_desc = program_desc->GetBlock(block_idx); + auto var_size = block_desc->VarsSize(); + for (size_t var_idx = 0; var_idx < var_size; ++var_idx) { + auto* var_desc = block_desc->GetVar(var_idx); + const auto& var_name = var_desc->Name(); + const auto& var_type = var_desc->GetType(); + if (!var_desc->Persistable()) { + vars_.push_back(var_name); + auto* var = exec_scope_->Var(var_name); + VLOG(4) << "Var " << var_name << " in block " << block_idx; + VLOG(4) << " - type " << static_cast(var_type); + if (var_type == lite::VarDescAPI::Type::LOD_TENSOR) { + const auto& var_data_type = + VarDescType2PrecisionType(var_desc->GetDataType()); + if (var_data_type != PRECISION(kUnk)) { + var_type_map_[var_name] = LiteType::GetTensorTy( + TARGET(kUnk), var_data_type, DATALAYOUT(kUnk)); + } + VLOG(4) << " - data type " << static_cast(var_data_type); + // Create the tensor with the shape from var desc, it's convenient to + // the graph analysis in the passes, but you should resize the tensor + // with the real shape before accessing its data, because the + // var_shape may be [-1,3,224,224] + const auto& var_shape = var_desc->GetShape(); + auto* tensor = var->GetMutable(); + if (tensor->dims().empty() && !var_shape.empty()) { + tensor->Resize(var_shape); + VLOG(4) << " - dims " << tensor->dims().repr(); + } + } else if (var_type == lite::VarDescAPI::Type::LOD_TENSOR_ARRAY) { + var_type_map_[var_name] = LiteType::GetTensorListTy( + TARGET(kUnk), PRECISION(kUnk), DATALAYOUT(kUnk)); } } else { - if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue; - weights_.push_back(var_desc.Name()); - if (var_desc.Persistable()) scope_->Var(var_desc.Name()); + if (var_name == "feed" || var_name == "fetch") continue; + weights_.push_back(var_name); + scope_->Var(var_name); } } } - for (auto i : var_names) { - exec_scope_->LocalVar(i); - auto* tensor = scope_->Var(i)->GetMutable(); - auto* sub_tensor = exec_scope_->Var(i)->GetMutable(); + for (auto var_name : vars_to_clone) { + exec_scope_->LocalVar(var_name); + auto* tensor = scope_->Var(var_name)->GetMutable(); + auto* sub_tensor = exec_scope_->Var(var_name)->GetMutable(); sub_tensor->CopyDataFrom(*tensor); } } diff --git a/lite/core/program.h b/lite/core/program.h index d47bb1c09e039af1abcd6c2d94b9aa9f3063f977..50c4bb37d64d188b721e8f46b2dcd01872e38704 100644 --- a/lite/core/program.h +++ b/lite/core/program.h @@ -41,61 +41,72 @@ static const char kKernelTypeAttr[] = "__@kernel_type_attr@__"; // - scope: which contains all the weights struct Program { public: - explicit Program(const std::shared_ptr& root) { scope_ = root; } - Program(const cpp::ProgramDesc& desc, - const std::shared_ptr& root, + explicit Program(const std::shared_ptr& root_scope) { + scope_ = root_scope; + } + Program(const std::shared_ptr& program_desc, + const std::shared_ptr& root_scope, const std::vector& valid_places, - const std::vector& var_names = {}) - : scope_(root), valid_places_(valid_places) { - desc_.CopyFrom(desc); + const std::vector& vars_to_clone = {}) + : scope_(root_scope), + valid_places_(valid_places), + program_desc_(program_desc) { CHECK(scope_) << "scope should be init first"; VLOG(4) << "prepare work"; - PrepareWorkspace(desc, var_names); + PrepareWorkspace(program_desc_, vars_to_clone); VLOG(4) << "build desc"; - Build(desc); + Build(program_desc_); VLOG(4) << "build desc finished"; } std::unique_ptr Clone() const { - std::unique_ptr res(new Program(desc_, scope_, valid_places_)); - return res; + return std::unique_ptr( + new Program(program_desc_, scope_, valid_places_)); } const std::list& weights() const { return weights_; } - const std::list& tmp_vars() const { return tmp_vars_; } + const std::list& vars() const { return vars_; } std::list* mutable_weights() { return &weights_; } - std::list* mutable_tmp_vars() { return &tmp_vars_; } + std::list* mutable_vars() { return &vars_; } - const std::list>& ops() const { return ops_; } - std::list>* mutable_ops() { return &ops_; } + const std::list>& ops( + int block_idx = kRootBlockIdx) const { + return ops_[block_idx]; + } + std::list>* mutable_ops( + int block_idx = kRootBlockIdx) { + return &ops_[block_idx]; + } - lite::Scope* exec_scope() { return exec_scope_; } - lite::Scope* scope() { return scope_.get(); } + size_t block_size() { return ops_.size(); } - cpp::ProgramDesc* program_desc() { return &desc_; } + Scope* exec_scope() { return exec_scope_; } + Scope* scope() { return scope_.get(); } - const std::map& var_data_type() const { - return var_data_type_; + cpp::ProgramDesc* program_desc() { return program_desc_.get(); } + + const std::map& var_type_map() const { + return var_type_map_; } private: // Build from a program and scope. - void Build(const cpp::ProgramDesc& program); + void Build(const std::shared_ptr& program_desc); // Create temporary variables. - void PrepareWorkspace(const cpp::ProgramDesc& program, - const std::vector& var_names = {}); + void PrepareWorkspace(const std::shared_ptr& program_desc, + const std::vector& vars_to_clone = {}); private: - std::map var_data_type_; - std::list tmp_vars_; + std::map var_type_map_; + std::list vars_; std::list weights_; - std::list> ops_; + std::vector>> ops_; // the scope to run the kernels, NOTE this is the execution scope. - std::shared_ptr scope_; + std::shared_ptr scope_; std::vector valid_places_; // Runtime scope. - lite::Scope* exec_scope_{}; - cpp::ProgramDesc desc_; + Scope* exec_scope_{}; + std::shared_ptr program_desc_; }; struct Instruction { @@ -179,8 +190,22 @@ struct Instruction { */ class LITE_API RuntimeProgram { public: - explicit RuntimeProgram(std::vector&& insts) + explicit RuntimeProgram(std::vector>&& insts) : instructions_(std::move(insts)) { + Init(); + } + explicit RuntimeProgram( + const std::shared_ptr& program_desc, + Scope* exec_scope, + int block_idx = kRootBlockIdx); + ~RuntimeProgram() { +#ifdef LITE_WITH_PROFILE + LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kCreate); + LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch); +#endif // LITE_WITH_PROFILE + } + + void Init() { if (instructions_.empty()) { LOG(FATAL) << "no instructions"; } @@ -189,7 +214,7 @@ class LITE_API RuntimeProgram { #endif #ifdef LITE_WITH_NVTX const NVTXAnnotator& annotator = NVTXAnnotator::Global(); - for (auto& inst : instructions_) { + for (auto& inst : instructions_[kRootBlockIdx]) { NVTXRangeAnnotation annotation = annotator.AnnotateBlock(); register_layer_names_.push_back(annotator.RegisterString( const_cast(inst.op())->Type().c_str())); @@ -197,30 +222,27 @@ class LITE_API RuntimeProgram { register_layer_names_.push_back(annotator.RegisterString("one_loop")); #endif } - ~RuntimeProgram() { -#ifdef LITE_WITH_PROFILE - LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kCreate); - LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch); -#endif // LITE_WITH_PROFILE - } void Run(); - void set_exec_scope(lite::Scope* x) { exec_scope_ = x; } - lite::Scope* exec_scope() { return exec_scope_; } + void set_exec_scope(Scope* x) { exec_scope_ = x; } + Scope* exec_scope() { return exec_scope_; } - size_t num_instructions() const { return instructions_.size(); } + const std::vector& instructions( + int block_idx = kRootBlockIdx) const { + return instructions_[block_idx]; + } - const std::vector& instructions() const { return instructions_; } + std::vector* mutable_instructions( + int block_idx = kRootBlockIdx) { + return &instructions_[block_idx]; + } - // `SaveOpInfosToProgram` will update the op list(ops_) of the block 0 - // in ProgramDesc. - void SaveOpInfosToProgram(cpp::ProgramDesc* desc); + size_t block_size() { return instructions_.size(); } - // `UpdateVarsOfProgram` will update the var list(vars_) of the block 0 in - // ProgramDesc. Namely, if a new var created in some passes, its var_desc will - // be added in vars_. - void UpdateVarsOfProgram(cpp::ProgramDesc* desc); + // Update the ops and vars of all of blocks to the given program_desc + // according to the instructions + void SaveToProgram(std::shared_ptr program_desc); #ifdef LITE_WITH_CUDA // UpdateCudaContext will update the exec stream and io stream of all kernels @@ -230,14 +252,14 @@ class LITE_API RuntimeProgram { private: RuntimeProgram(const RuntimeProgram&) = delete; - std::vector instructions_; - lite::Scope* exec_scope_{}; + std::vector> instructions_; + Scope* exec_scope_{}; #ifdef LITE_WITH_PROFILE profile::Profiler profiler_; void set_profiler() { - for (auto i = instructions_.begin(); i != instructions_.end(); ++i) { - i->set_profiler(&profiler_); + for (auto& inst : instructions_[kRootBlockIdx]) { + inst.set_profiler(&profiler_); } } #endif diff --git a/lite/kernels/apu/subgraph_compute.cc b/lite/kernels/apu/subgraph_compute.cc index 21373811dd91d009d834a16d2c437bc722cd676a..579ed97b161dade9822250dab411cefd214b50f8 100644 --- a/lite/kernels/apu/subgraph_compute.cc +++ b/lite/kernels/apu/subgraph_compute.cc @@ -37,7 +37,7 @@ bool SubgraphEngine::BuildDeviceProgram() { subgraph::apu::Graph graph; int neuron_errCode = NeuronModel_create(&model_); if (NEURON_NO_ERROR != neuron_errCode) { - LOG(WARNING) << "Fail to create model"; + LOG(WARNING) << "[APU] Failed to create the neuron model!"; return false; } graph.set_model(model_); @@ -46,11 +46,12 @@ bool SubgraphEngine::BuildDeviceProgram() { // Convert all of ops and their input vars and weights and added into the APU // NIR graph - if (origin_program_.empty()) { + if (!origin_program_) { BuildOriginProgram(); } const auto& bridges = subgraph::Registry::Instance(); - for (auto& inst : origin_program_) { + const auto& insts = origin_program_->instructions(kRootBlockIdx); + for (auto& inst : insts) { auto op = const_cast(inst.op()); CHECK(op); op->CheckShape(); @@ -70,55 +71,38 @@ bool SubgraphEngine::BuildDeviceProgram() { } } - // Get input tensor - std::vector ins; - origin_itensors_.resize(input_names_.size()); - origin_idims_.resize(input_names_.size()); + // Get the index of input tensors + std::vector input_indices; for (int i = 0; i < input_names_.size(); i++) { - origin_itensors_[i] = scope_->FindMutableTensor(input_names_[i]); - CHECK(origin_itensors_[i]); - origin_idims_[i] = origin_itensors_[i]->dims(); - VLOG(3) << "subgraph input name: " << i << ", " << input_names_[i] << ":" - << origin_idims_[i].production(); - // Get input index - int idx; - if (graph.Has(input_names_[i])) { - ins.push_back(graph.Get(input_names_[i])->index()); - VLOG(3) << "input idx: " << graph.Get(input_names_[i])->index(); - } else { - LOG(WARNING) << "Fail to find input: " << input_names_[i]; - return false; - } + CHECK(graph.Has(input_names_[i])) << "[APU] Failed to find input node " + << input_names_[i]; + auto index = graph.Get(input_names_[i])->index(); + input_indices.push_back(index); + VLOG(3) << "[APU] Input[" << i << "] name " << input_names_[i] << " dims " + << origin_itensors_[i]->dims() << " index " << index; } - // Get output tensor - std::vector outs; - origin_otensors_.resize(output_names_.size()); - origin_odims_.resize(output_names_.size()); + // Get the index of output tensors + std::vector output_indices; for (int i = 0; i < output_names_.size(); i++) { - origin_otensors_[i] = scope_->FindMutableTensor(output_names_[i]); - CHECK(origin_otensors_[i]); - origin_odims_[i] = origin_otensors_[i]->dims(); - VLOG(3) << "subgraph output name: " << i << ", " << output_names_[i] << ":" - << origin_odims_[i].production(); + CHECK(graph.Has(output_names_[i])) << "[APU] Failed to find output node " + << output_names_[i]; origin_otensors_[i]->mutable_data(); - // Get input index - if (graph.Has(output_names_[i])) { - outs.push_back(graph.Get(output_names_[i])->index()); - VLOG(3) << "output idx: " << graph.Get(output_names_[i])->index(); - } else { - LOG(WARNING) << "Fail to find output: " << output_names_[i]; - return false; - } + auto index = graph.Get(output_names_[i])->index(); + output_indices.push_back(index); + VLOG(3) << "[APU] Output[" << i << "] name " << output_names_[i] << " dims " + << origin_otensors_[i]->dims() << " index " << index; } - VLOG(3) << "ins size: " << ins.size() << " outs size:" << outs.size(); - // Set subgraph input/output - NeuronModel_identifyInputsAndOutputs( - model_, ins.size(), &ins[0], outs.size(), &outs[0]); + // Indentify the input and output tensors of the neuron model + NeuronModel_identifyInputsAndOutputs(model_, + input_indices.size(), + &input_indices[0], + output_indices.size(), + &output_indices[0]); neuron_errCode = NeuronModel_finish(model_); if (NEURON_NO_ERROR != neuron_errCode) { - LOG(WARNING) << "Fail to create NIR model:" << neuron_errCode; + LOG(WARNING) << "[APU] Fail to create NIR model:" << neuron_errCode; return false; } VLOG(3) << "[APU] APU NIR model created!"; @@ -207,11 +191,11 @@ SubgraphEngine::~SubgraphEngine() { void SubgraphCompute::PrepareForRun() { auto& param = this->Param(); engine_.reset(new SubgraphEngine(ctx_.get(), - param.sub_block_idx, - param.sub_block_desc, + param.block_idx, + param.program_desc, + param.exec_scope, param.input_data_names, - param.output_data_names, - param.scope)); + param.output_data_names)); CHECK(engine_); } diff --git a/lite/kernels/apu/subgraph_compute.h b/lite/kernels/apu/subgraph_compute.h index beb582b8cc16e456491c28ace5e2d1695143216a..de15abdf7fdbce8001676a2bf7f651ad1e435c74 100644 --- a/lite/kernels/apu/subgraph_compute.h +++ b/lite/kernels/apu/subgraph_compute.h @@ -31,12 +31,16 @@ class SubgraphEngine : public subgraph::Engine { public: SubgraphEngine(KernelContext *ctx, int block_idx, - cpp::BlockDesc *block_desc, + const std::shared_ptr &program_desc, + Scope *exec_scope, const std::vector &input_names, - const std::vector &output_names, - Scope *scope) - : subgraph::Engine( - ctx, block_idx, block_desc, input_names, output_names, scope) {} + const std::vector &output_names) + : subgraph::Engine(ctx, + block_idx, + program_desc, + exec_scope, + input_names, + output_names) {} ~SubgraphEngine(); diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 6d1d24adcb4cf74b3c6bb991a33316e974dc0110..f4fe6ba1ebb9a7e775f0d5db1031f9fd40508c20 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -75,7 +75,6 @@ add_kernel(generate_proposals_compute_arm ARM extra SRCS generate_proposals_comp add_kernel(roi_align_compute_arm ARM extra SRCS roi_align_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(box_clip_compute_arm ARM extra SRCS box_clip_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(assign_value_compute_arm ARM basic SRCS assign_value_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(conditional_block_compute_arm ARM extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(distribute_fpn_proposals_compute_arm ARM extra SRCS distribute_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(clip_compute_arm ARM extra SRCS clip_compute.cc DEPS ${lite_kernel_deps} math_arm) @@ -87,7 +86,6 @@ add_kernel(beam_search_decode_compute_arm ARM extra SRCS beam_search_decode_comp add_kernel(lookup_table_compute_arm ARM extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(lookup_table_dequant_compute_arm ARM extra SRCS lookup_table_dequant_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(sequence_softmax_compute_arm ARM extra SRCS sequence_softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(while_compute_arm ARM extra SRCS while_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(topk_compute_arm ARM extra SRCS topk_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(increment_compute_arm ARM extra SRCS increment_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(beam_search_compute_arm ARM extra SRCS beam_search_compute.cc DEPS ${lite_kernel_deps} math_arm) diff --git a/lite/kernels/arm/conditional_block_compute.h b/lite/kernels/arm/conditional_block_compute.h deleted file mode 100644 index 91eadff931ec8aa54092347bcf18f8428130ef75..0000000000000000000000000000000000000000 --- a/lite/kernels/arm/conditional_block_compute.h +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include -#include -#include -#include "lite/core/kernel.h" -#include "lite/core/op_registry.h" -#include "lite/core/program.h" -#include "lite/operators/conditional_block_op.h" -#ifdef LITE_WITH_PROFILE -#include "lite/core/profile/basic_profiler.h" -#include "lite/core/profile/precision_profiler.h" -#include "lite/core/profile/profiler.h" -#endif - -namespace paddle { -namespace lite { -namespace kernels { -namespace arm { - -class CondExecutor { - typedef std::shared_ptr OpPtr; - - public: - CondExecutor(cpp::BlockDesc *block, Scope *scope, Place place) - : scope_(scope), place_(place) { - int32_t op_size = block->OpsSize(); - for (int32_t i = 0; i < op_size; ++i) { - auto &op_desc = *block->template GetOp(i); - auto op_type = op_desc.Type(); - auto op_handler = lite::LiteOpRegistry::Global().Create(op_desc.Type()); - op_handler->Attach(op_desc, scope); - - auto hostplace = place_; - hostplace.target = TARGET(kHost); - auto kernels = op_handler->CreateKernels({place_, hostplace}); - CHECK_GT(kernels.size(), 0) << "cannot create kernel"; - op_handler->AttachKernel(kernels[0].get()); - op_handler->SetKernel(kernels); - ops_of_block_.push_back(op_handler); - } - } - - void Run() { -#ifdef LITE_WITH_PROFILE -#ifdef LITE_WITH_PRECISION_PROFILE - lite::profile::Profiler profiler; -#endif // LITE_WITH_PRECISION_PROFILE -#endif // LITE_WITH_PROFILE - for (auto &op_handler : ops_of_block_) { - op_handler->CheckShape(); - op_handler->InferShape(); -#ifdef LITE_WITH_PROFILE -#ifdef LITE_WITH_PRECISION_PROFILE - std::unique_ptr kernel(op_handler->GetKernel()); - Instruction inst(op_handler, std::move(kernel)); - inst.set_profiler(&profiler); -#endif // LITE_WITH_PRECISION_PROFILE -#endif // LITE_WITH_PROFILE - op_handler->Run(); -#ifdef LITE_WITH_PROFILE -#ifdef LITE_WITH_PRECISION_PROFILE - LITE_PRECISION_PROFILE(inst) -#endif // LITE_WITH_PRECISION_PROFILE -#endif // LITE_WITH_PROFILE - } - } - - private: - Scope *scope_; - Place place_; - std::vector ops_of_block_; -}; - -class ConditionalBlockCompute - : public KernelLite { - public: - using param_t = operators::ConditionalBlockParam; - - void PrepareForRun() override; - void Run() override; - - virtual ~ConditionalBlockCompute() = default; - - private: - std::shared_ptr executor_; -}; - -} // namespace arm -} // namespace kernels -} // namespace lite -} // namespace paddle diff --git a/lite/kernels/arm/elementwise_compute.cc b/lite/kernels/arm/elementwise_compute.cc index 28082785e1c726097a8bfd2165f0d09b9962a5e7..3e898d9ded2153588c164d2ccd618fc77f7c3854 100644 --- a/lite/kernels/arm/elementwise_compute.cc +++ b/lite/kernels/arm/elementwise_compute.cc @@ -202,17 +202,13 @@ void ElementwiseMulCompute::Run() { } } -template <> -void ElementwiseMulCompute::Run() { - auto& param = this->template Param(); - lite::arm::math::elementwise_compute_basic(param, "mul", ""); -} - -void ElementwiseMulActivationCompute::Run() { - auto& param = Param(); - const float* x_data = param.X->data(); - const float* y_data = param.Y->data(); - float* out_data = param.Out->mutable_data(); +template +void ElementwiseMulActivationCompute::Run() { + auto& param = + this->template Param(); + auto* x_data = param.X->template data(); + auto* y_data = param.Y->template data(); + auto* out_data = param.Out->template mutable_data(); int axis = param.axis; std::string act_type = param.act_type; auto x_dims = param.X->dims(); @@ -221,21 +217,21 @@ void ElementwiseMulActivationCompute::Run() { if (x_dims.size() < y_dims.size() && is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) { if (act_type == "relu") { - lite::arm::math::elementwise_mul_relu_broadcast( + lite::arm::math::elementwise_mul_relu_broadcast( y_data, x_data, out_data, pre, n, post); } else { LOG(FATAL) << "unsupported Activation type: " << act_type; } } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { if (act_type == "relu") { - lite::arm::math::elementwise_mul_relu_broadcast( + lite::arm::math::elementwise_mul_relu_broadcast( x_data, y_data, out_data, pre, n, post); } else { LOG(FATAL) << "unsupported Activation type: " << act_type; } } else { if (act_type == "relu") { - lite::arm::math::elementwise_mul_relu( + lite::arm::math::elementwise_mul_relu( x_data, y_data, out_data, x_dims.production()); } else { LOG(FATAL) << "unsupported Activation type: " << act_type; @@ -426,46 +422,60 @@ REGISTER_LITE_KERNEL( .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); -using elementwise_mul_float = +using elementwise_mul_float_t = paddle::lite::kernels::arm::ElementwiseMulCompute; REGISTER_LITE_KERNEL( - elementwise_mul, kARM, kFloat, kNCHW, elementwise_mul_float, def) + elementwise_mul, kARM, kFloat, kNCHW, elementwise_mul_float_t, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); -using elementwise_mul_int32 = +using elementwise_mul_int32_t = paddle::lite::kernels::arm::ElementwiseMulCompute; REGISTER_LITE_KERNEL( - elementwise_mul, kARM, kInt32, kNCHW, elementwise_mul_int32, def) + elementwise_mul, kARM, kInt32, kNCHW, elementwise_mul_int32_t, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .Finalize(); -using elementwise_mul_int64 = +using elementwise_mul_int64_t = paddle::lite::kernels::arm::ElementwiseMulCompute; REGISTER_LITE_KERNEL( - elementwise_mul, kARM, kInt64, kNCHW, elementwise_mul_int64, def) + elementwise_mul, kARM, kInt64, kNCHW, elementwise_mul_int64_t, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .Finalize(); -REGISTER_LITE_KERNEL( - fusion_elementwise_mul_activation, - kARM, - kFloat, - kNCHW, - paddle::lite::kernels::arm::ElementwiseMulActivationCompute, - def) +using fusion_elementwise_mul_activation_float_t = paddle::lite::kernels::arm:: + ElementwiseMulActivationCompute; +REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation, + kARM, + kFloat, + kNCHW, + fusion_elementwise_mul_activation_float_t, + def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); +using fusion_elementwise_mul_activation_int64_t = paddle::lite::kernels::arm:: + ElementwiseMulActivationCompute; +REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation, + kARM, + kInt64, + kNCHW, + fusion_elementwise_mul_activation_int64_t, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) + .Finalize(); + REGISTER_LITE_KERNEL(elementwise_max, kARM, kFloat, @@ -489,22 +499,22 @@ REGISTER_LITE_KERNEL( .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); -using elementwise_div_fp32 = +using elementwise_div_fp32_t = paddle::lite::kernels::arm::ElementwiseDivCompute; REGISTER_LITE_KERNEL( - elementwise_div, kARM, kFloat, kNCHW, elementwise_div_fp32, def) + elementwise_div, kARM, kFloat, kNCHW, elementwise_div_fp32_t, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); -using elementwise_div_int64 = +using elementwise_div_int64_t = paddle::lite::kernels::arm::ElementwiseDivCompute; REGISTER_LITE_KERNEL( - elementwise_div, kARM, kInt64, kNCHW, elementwise_div_int64, def) + elementwise_div, kARM, kInt64, kNCHW, elementwise_div_int64_t, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) @@ -522,11 +532,11 @@ REGISTER_LITE_KERNEL( .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); -using elementwise_mod_int64 = +using elementwise_mod_int64_t = paddle::lite::kernels::arm::ElementwiseModCompute; REGISTER_LITE_KERNEL( - elementwise_mod, kARM, kInt64, kNCHW, elementwise_mod_int64, def) + elementwise_mod, kARM, kInt64, kNCHW, elementwise_mod_int64_t, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) diff --git a/lite/kernels/arm/elementwise_compute.h b/lite/kernels/arm/elementwise_compute.h index 7d7a93bf6954de9bbcd1b44061e614cd041fafe8..89d9898648d25fec98568f2456fe96903da0a69d 100644 --- a/lite/kernels/arm/elementwise_compute.h +++ b/lite/kernels/arm/elementwise_compute.h @@ -62,8 +62,8 @@ class ElementwiseMulCompute : public KernelLite { virtual ~ElementwiseMulCompute() = default; }; -class ElementwiseMulActivationCompute - : public KernelLite { +template +class ElementwiseMulActivationCompute : public KernelLite { public: void Run() override; diff --git a/lite/kernels/arm/elementwise_compute_test.cc b/lite/kernels/arm/elementwise_compute_test.cc index 62a5bc423ca6e72098332963713e8baffb366325..79262fb4ef75283eba12efa0a4ad8dc048681338 100644 --- a/lite/kernels/arm/elementwise_compute_test.cc +++ b/lite/kernels/arm/elementwise_compute_test.cc @@ -533,13 +533,15 @@ TEST(fusion_elementwise_mul_activation_arm, retrive_op) { } TEST(fusion_elementwise_mul_activation_arm, init) { - ElementwiseMulActivationCompute fusion_elementwise_mul_activation; + ElementwiseMulActivationCompute + fusion_elementwise_mul_activation; ASSERT_EQ(fusion_elementwise_mul_activation.precision(), PRECISION(kFloat)); ASSERT_EQ(fusion_elementwise_mul_activation.target(), TARGET(kARM)); } TEST(fusion_elementwise_mul_activation_arm, compute) { - ElementwiseMulActivationCompute fusion_elementwise_mul_activation; + ElementwiseMulActivationCompute + fusion_elementwise_mul_activation; operators::FusionElementwiseActivationParam param; lite::Tensor x, y, output, output_ref; diff --git a/lite/kernels/arm/gather_compute.cc b/lite/kernels/arm/gather_compute.cc index 2a9c70aede7475b36f70c628ff6ccaa823f030b2..f5a87e5431955252e47143252ce13ba4056c4a7f 100644 --- a/lite/kernels/arm/gather_compute.cc +++ b/lite/kernels/arm/gather_compute.cc @@ -20,44 +20,45 @@ namespace lite { namespace kernels { namespace arm { -template +template void GatherFunc(const operators::GatherParam& param) { auto src_dims = param.X->dims(); auto index_size = param.Index->dims()[0]; - auto* p_src = param.X->data(); - const int* p_index = param.Index->data(); - auto* p_output = param.Out->mutable_data(); + auto* p_src = param.X->data(); + const IndexType* p_index = param.Index->data(); + auto* p_output = param.Out->mutable_data(); int slice_size = 1; for (size_t i = 1; i < src_dims.size(); ++i) { slice_size *= src_dims[i]; } for (int i = 0; i < index_size; ++i) { - int index_ = p_index[i]; + IndexType index_ = p_index[i]; memcpy(p_output + i * slice_size, p_src + index_ * slice_size, - slice_size * sizeof(T)); + slice_size * sizeof(DataType)); } } -void GatherCompute::Run() { - auto& param = this->Param(); +template +void GatherCompute::Run() { + auto& param = this->template Param(); switch (param.X->precision()) { case PRECISION(kFloat): - GatherFunc(param); + GatherFunc(param); break; case PRECISION(kInt8): - GatherFunc(param); + GatherFunc(param); break; case PRECISION(kInt16): - GatherFunc(param); + GatherFunc(param); break; case PRECISION(kInt32): - GatherFunc(param); + GatherFunc(param); break; case PRECISION(kInt64): - GatherFunc(param); + GatherFunc(param); break; default: LOG(FATAL) << "Gather does not implement for the " @@ -70,9 +71,26 @@ void GatherCompute::Run() { } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL( - gather, kARM, kAny, kNCHW, paddle::lite::kernels::arm::GatherCompute, def) +REGISTER_LITE_KERNEL(gather, + kARM, + kAny, + kNCHW, + paddle::lite::kernels::arm::GatherCompute, + def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) - .BindInput("Index", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) + .BindInput("Index", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) + .Finalize(); + +REGISTER_LITE_KERNEL(gather, + kARM, + kAny, + kNCHW, + paddle::lite::kernels::arm::GatherCompute, + def_int64_idx) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) + .BindInput("Index", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .Finalize(); diff --git a/lite/kernels/arm/gather_compute.h b/lite/kernels/arm/gather_compute.h index 9753f42972407b250886afa6bada8861a642e189..0226e5f68eee3f23dbd945af6f4f455ab79190c5 100644 --- a/lite/kernels/arm/gather_compute.h +++ b/lite/kernels/arm/gather_compute.h @@ -23,6 +23,7 @@ namespace lite { namespace kernels { namespace arm { +template class GatherCompute : public KernelLite { public: void Run() override; diff --git a/lite/kernels/arm/sequence_conv_compute.cc b/lite/kernels/arm/sequence_conv_compute.cc index 71826ae6732b1a0b829a08e72d4f46cc47832ae4..455615e66de53a4a6f235f8ab803394962292936 100644 --- a/lite/kernels/arm/sequence_conv_compute.cc +++ b/lite/kernels/arm/sequence_conv_compute.cc @@ -102,10 +102,14 @@ void SequenceConvCompute::Run() { 1, 1, // stride_h, stride_w, dilation_h, dilation_w tmp_data); - local_naive_transpose(tmp_data, - sub_col_data, - kernel_size * hidden_dim, - input_row_end - input_row_begin); + int cols = kernel_size * hidden_dim; + int rows = input_row_end - input_row_begin; + if (cols % 4 == 0 && rows % 4 == 0) { + paddle::lite::arm::math::local_transpose( + tmp_data, sub_col_data, cols, rows); + } else { + local_naive_transpose(tmp_data, sub_col_data, cols, rows); + } } } diff --git a/lite/kernels/arm/while_compute.h b/lite/kernels/arm/while_compute.h deleted file mode 100644 index f735d96f9190755daacdf846a2d99901c1a14493..0000000000000000000000000000000000000000 --- a/lite/kernels/arm/while_compute.h +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include -#include -#include "lite/core/kernel.h" -#include "lite/core/op_registry.h" -#include "lite/operators/while_op.h" - -namespace paddle { -namespace lite { -namespace kernels { -namespace arm { - -class StepExecutor { - typedef std::shared_ptr OpPtr; - - public: - StepExecutor(cpp::BlockDesc *block, Scope *scope, Place place) - : scope_(scope), place_(place) { - int32_t op_size = block->OpsSize(); - for (int32_t i = 0; i < op_size; ++i) { - auto &op_desc = *block->template GetOp(i); - auto op_type = op_desc.Type(); - auto op_handler = lite::LiteOpRegistry::Global().Create(op_desc.Type()); - // VLOG(4) << "while: creating Op [" << op_type << "]"; - op_handler->Attach(op_desc, scope); - - auto hostplace = place_; - hostplace.target = TARGET(kHost); - auto kernels = op_handler->CreateKernels({place_, hostplace}); - CHECK_GT(kernels.size(), 0) << "cannot create kernel"; - op_handler->AttachKernel(kernels[0].get()); - op_handler->SetKernel(kernels); - ops_of_block_.push_back(op_handler); - } - } - - void Run() { - for (auto &op_handler : ops_of_block_) { - // VLOG(4) << op_handler->op_info()->Repr(); - op_handler->InferShape(); - // VLOG(4) << "while: infered shape"; - op_handler->Run(); - } - } - - private: - Scope *scope_; - Place place_; - std::vector ops_of_block_; -}; - -class WhileCompute : public KernelLite { - public: - using param_t = operators::WhileParam; - - void Run() override; - void PrepareForRun() override; - - virtual ~WhileCompute() = default; - - private: - std::shared_ptr executor_; -}; - -} // namespace arm -} // namespace kernels -} // namespace lite -} // namespace paddle diff --git a/lite/kernels/bm/subgraph_compute.cc b/lite/kernels/bm/subgraph_compute.cc index 664198cf9fb45664fdc088df382b9b94a1924e9b..ea0dd82325976f33f123f21e0eb4aeb5dfdbfa9d 100644 --- a/lite/kernels/bm/subgraph_compute.cc +++ b/lite/kernels/bm/subgraph_compute.cc @@ -28,36 +28,17 @@ namespace lite { namespace kernels { namespace bm { -bool SubgraphEngine::PrepareWorkspaceForDeviceProgram() { - // Obtain the origin input tensors, and create the origin output - // tensors(Don't try to access them before launch the device program or the - // origin program) - PrepareWorkspaceForOriginProgram(); - // Create the device input and output tensors, but don't initialize them - // with the dimensions - device_inputs_.resize(input_names_.size()); - for (int i = 0; i < input_names_.size(); i++) { - device_inputs_[i].reset(new hiai::AiTensor); - CHECK(device_inputs_[i]); - } - device_outputs_.resize(output_names_.size()); - for (int i = 0; i < output_names_.size(); i++) { - device_outputs_[i].reset(new hiai::AiTensor); - CHECK(device_outputs_[i]); - } - return true; -} - bool SubgraphEngine::BuildDeviceProgram() { int status = 0; subgraph::bm::Graph graph; const auto& bridges = subgraph::Registry::Instance(); graph.CreateCompilerHandle(); auto& ctx = this->ctx_->template As(); - if (origin_program_.empty()) { + if (!origin_program_) { BuildOriginProgram(); } - for (auto& inst : origin_program_) { + const auto& insts = origin_program_->instructions(kRootBlockIdx); + for (auto& inst : insts) { auto op = const_cast(inst.op()); CHECK(op); op->CheckShape(); @@ -93,13 +74,11 @@ bool SubgraphEngine::BuildDeviceProgram() { net_info_ = bmrt_get_network_info(bmrt_hd_, net_names_[0]); auto& stage = net_info_->stages[0]; // input - origin_idims_.resize(input_names_.size()); - origin_itensors_.resize(input_names_.size()); device_inputs_.resize(input_names_.size()); for (size_t i = 0; i < input_names_.size(); i++) { - origin_itensors_[i] = scope_->FindMutableTensor(net_info_->input_names[i]); + origin_itensors_[i] = + exec_scope_->FindMutableTensor(net_info_->input_names[i]); CHECK(origin_itensors_[i]); - origin_idims_[i] = origin_itensors_[i]->dims(); bm_device_mem_t* p_mem = static_cast(malloc(sizeof(bm_device_mem_t))); CHECK(p_mem != nullptr); @@ -112,8 +91,6 @@ bool SubgraphEngine::BuildDeviceProgram() { stage.input_shapes[i]); } // output - origin_odims_.resize(output_names_.size()); - origin_otensors_.resize(output_names_.size()); device_outputs_.resize(net_info_->output_num); int out_index = 0; for (int i = 0; i < output_names_.size(); i++) { @@ -121,14 +98,13 @@ bool SubgraphEngine::BuildDeviceProgram() { } for (int i = 0; i < net_info_->output_num; i++) { - Tensor* t_cur = scope_->FindMutableTensor(net_info_->output_names[i]); + Tensor* t_cur = exec_scope_->FindMutableTensor(net_info_->output_names[i]); CHECK(t_cur != nullptr); bm_device_mem_t* p_mem = static_cast(malloc(sizeof(bm_device_mem_t))); CHECK(p_mem != nullptr); if (outname_map_.find(net_info_->output_names[i]) != outname_map_.end()) { origin_otensors_[out_index] = t_cur; - origin_odims_[out_index] = origin_otensors_[out_index]->dims(); origin_otensors_[out_index]->mutable_data(); out_index += 1; } @@ -173,11 +149,11 @@ bool SubgraphEngine::LaunchDeviceProgram() { void SubgraphCompute::PrepareForRun() { auto& param = this->Param(); engine_.reset(new SubgraphEngine(ctx_.get(), - param.sub_block_idx, - param.sub_block_desc, + param.block_idx, + param.program_desc, + param.exec_scope, param.input_data_names, - param.output_data_names, - param.scope)); + param.output_data_names)); CHECK(engine_); } diff --git a/lite/kernels/bm/subgraph_compute.h b/lite/kernels/bm/subgraph_compute.h index 7a5b2552ff95681da09346ba11f40f1a6acb7f01..d1dcb3a6d3ef7eb6d9091eb45d1960862cca273a 100644 --- a/lite/kernels/bm/subgraph_compute.h +++ b/lite/kernels/bm/subgraph_compute.h @@ -36,15 +36,18 @@ class SubgraphEngine : public subgraph::Engine { public: SubgraphEngine(KernelContext *ctx, int block_idx, - cpp::BlockDesc *block_desc, + const std::shared_ptr &program_desc, + Scope *exec_scope, const std::vector &input_names, - const std::vector &output_names, - Scope *scope) - : subgraph::Engine( - ctx, block_idx, block_desc, input_names, output_names, scope) {} + const std::vector &output_names) + : subgraph::Engine(ctx, + block_idx, + program_desc, + exec_scope, + input_names, + output_names) {} protected: - bool PrepareWorkspaceForDeviceProgram() override; bool BuildDeviceProgram() override; bool LaunchDeviceProgram() override; diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt index cd91d2dc90f9f48668e1d5ab9fbe5d065cb0e191..381b9304142537da028b35c688128d34465965aa 100644 --- a/lite/kernels/host/CMakeLists.txt +++ b/lite/kernels/host/CMakeLists.txt @@ -18,6 +18,9 @@ add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute. add_kernel(assign_compute_host Host extra SRCS assign_compute.cc DEPS ${lite_kernel_deps}) add_kernel(retinanet_detection_output_compute_host Host extra SRCS retinanet_detection_output_compute.cc DEPS ${lite_kernel_deps}) add_kernel(where_index_compute_host Host extra SRCS where_index_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(print_compute_host Host extra SRCS print_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(while_compute_host Host extra SRCS while_compute.cc DEPS ${lite_kernel_deps} program) +add_kernel(conditional_block_compute_host Host extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} program) add_kernel(activation_grad_compute_host Host train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps}) if(LITE_BUILD_EXTRA) diff --git a/lite/kernels/host/assign_compute.cc b/lite/kernels/host/assign_compute.cc index e496ffbd1d9a6362d730117be949cbdab83ec62a..bfbbc32e5f3b3b4dd5936e0e296306641312cabf 100644 --- a/lite/kernels/host/assign_compute.cc +++ b/lite/kernels/host/assign_compute.cc @@ -51,3 +51,19 @@ REGISTER_LITE_KERNEL( PRECISION(kAny), DATALAYOUT(kAny))}) .Finalize(); + +REGISTER_LITE_KERNEL(assign, + kHost, + kAny, + kAny, + paddle::lite::kernels::host::AssignCompute, + def_tensor_array) + .BindInput("X", + {LiteType::GetTensorListTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) + .BindOutput("Out", + {LiteType::GetTensorListTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) + .Finalize(); diff --git a/lite/kernels/arm/conditional_block_compute.cc b/lite/kernels/host/conditional_block_compute.cc similarity index 51% rename from lite/kernels/arm/conditional_block_compute.cc rename to lite/kernels/host/conditional_block_compute.cc index f0bd43e1300d4034241c03d3e4ce27dcaa59c1e5..5bdca012dd4e838f3371bae7cf17634513d59db5 100644 --- a/lite/kernels/arm/conditional_block_compute.cc +++ b/lite/kernels/host/conditional_block_compute.cc @@ -12,28 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/kernels/arm/conditional_block_compute.h" -#include -#include -#include -#include "lite/backends/arm/math/funcs.h" -#include "lite/core/tensor.h" -#include "lite/core/type_system.h" +#include "lite/kernels/host/conditional_block_compute.h" namespace paddle { namespace lite { namespace kernels { -namespace arm { +namespace host { void ConditionalBlockCompute::PrepareForRun() { - auto& param = Param(); - auto cur_scope = param.scope; - - executor_ = - std::make_shared(param.sub_block, cur_scope, place()); + auto& param = this->Param(); + program_.reset(new RuntimeProgram( + param.program_desc, param.exec_scope, param.block_idx)); } + void ConditionalBlockCompute::Run() { - auto& param = Param(); + auto& param = this->Param(); for (auto& out : param.outs) { out->clear(); } @@ -43,32 +36,40 @@ void ConditionalBlockCompute::Run() { auto* cond_data = cond->data(); need_run = cond_data[0]; } else { - auto x = param.x; - for (auto pt : x) { - if (pt == nullptr || !pt->IsInitialized() || pt->dims().empty()) { + for (auto input : param.inputs) { + if (input == nullptr || !input->IsInitialized() || + input->dims().empty()) { need_run = false; break; } } } if (need_run) { - executor_->Run(); + program_->Run(); } } -} // namespace arm +} // namespace host } // namespace kernels } // namespace lite } // namespace paddle REGISTER_LITE_KERNEL(conditional_block, - kARM, - kFloat, - kNCHW, - paddle::lite::kernels::arm::ConditionalBlockCompute, + kHost, + kAny, + kAny, + paddle::lite::kernels::host::ConditionalBlockCompute, def) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("Cond", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindOutput("Scope", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Input", + {LiteType::GetTensorListTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .BindInput("Cond", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)}) + .BindOutput("Out", + {LiteType::GetTensorListTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .BindOutput("Scope", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) .Finalize(); diff --git a/lite/kernels/host/conditional_block_compute.h b/lite/kernels/host/conditional_block_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..8d3381ce3c4d6da076e6bb477df423bc640c56c9 --- /dev/null +++ b/lite/kernels/host/conditional_block_compute.h @@ -0,0 +1,44 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/core/program.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +class ConditionalBlockCompute + : public KernelLite { + public: + using param_t = operators::ConditionalBlockParam; + + void PrepareForRun() override; + void Run() override; + + private: + std::unique_ptr program_; +}; + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/host/print_compute.cc b/lite/kernels/host/print_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..00c8ab7b13597ad33b9fafc878cd553572462a99 --- /dev/null +++ b/lite/kernels/host/print_compute.cc @@ -0,0 +1,188 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/host/print_compute.h" + +#include // NOLINT +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +const char kForward[] = "FORWARD"; +const char kBackward[] = "BACKWARD"; +const char kBoth[] = "BOTH"; + +class TensorFormatter { + public: + TensorFormatter() {} + + std::string Format(const Tensor& print_tensor, + const std::string& tensor_name = "", + const std::string& message = "") { + std::stringstream log_stream; + if (!tensor_name.empty()) { + log_stream << "Variable: " << tensor_name << std::endl; + } + + if (!message.empty()) { + log_stream << " - message: " << message << std::endl; + } + + if (print_tensor_lod_) { + log_stream << " - lod: {"; + const LoD& lod = print_tensor.lod(); + for (auto level : lod) { + log_stream << "{"; + bool is_first = true; + for (auto i : level) { + if (is_first) { + log_stream << i; + is_first = false; + } else { + log_stream << ", " << i; + } + } + log_stream << "}"; + } + log_stream << "}" << std::endl; + } + + log_stream << " - place: " << TargetToStr(print_tensor.target()) + << std::endl; // TODO(hong19860320) always kHost + + if (print_tensor_shape_) { + log_stream << " - shape: " << print_tensor.dims().repr() << std::endl; + } + + if (print_tensor_layout_) { + log_stream << " - layout: " + << DataLayoutToStr( + DATALAYOUT(kNCHW)) // TODO(hong19860320) Query the data + // layout from target tensor + << std::endl; + } + + auto dtype = print_tensor.precision(); + if (print_tensor_type_) { + log_stream << " - dtype: " << PrecisionToStr(dtype) << std::endl; + } + + if (dtype == PRECISION(kBool)) { + FormatData(print_tensor, log_stream); + } else if (dtype == PRECISION(kInt8)) { + FormatData(print_tensor, log_stream); + } else if (dtype == PRECISION(kInt16)) { + FormatData(print_tensor, log_stream); + } else if (dtype == PRECISION(kInt32)) { + FormatData(print_tensor, log_stream); + } else if (dtype == PRECISION(kInt64)) { + FormatData(print_tensor, log_stream); + } else if (dtype == PRECISION(kFloat)) { + FormatData(print_tensor, log_stream); + } else { + log_stream << "\tdata: unprintable type: " << PrecisionToStr(dtype) + << std::endl; + } + return log_stream.str(); + } + + void Print(const Tensor& print_tensor, + const std::string& tensor_name = "", + const std::string& message = "") { + static std::mutex mutex; + std::lock_guard lock(mutex); + std::cout << Format(print_tensor, tensor_name, message); + } + + void SetPrintTensorType(bool print_tensor_type) { + print_tensor_type_ = print_tensor_type; + } + void SetPrintTensorShape(bool print_tensor_shape) { + print_tensor_shape_ = print_tensor_shape; + } + void SetPrintTensorLod(bool print_tensor_lod) { + print_tensor_lod_ = print_tensor_lod; + } + void SetPrintTensorLayout(bool print_tensor_layout) { + print_tensor_layout_ = print_tensor_layout; + } + void SetSummarize(int64_t summarize) { summarize_ = summarize; } + + private: + template + void FormatData(const Tensor& print_tensor, std::stringstream& log_stream) { + int64_t print_size = summarize_ == -1 + ? print_tensor.numel() + : std::min(summarize_, print_tensor.numel()); + const T* data = print_tensor.data(); // Always kHost, so unnessary to + // copy the data from device + log_stream << " - data: ["; + if (print_size > 0) { + log_stream << data[0]; + for (int64_t i = 1; i < print_size; ++i) { + log_stream << " " << data[i]; + } + } + log_stream << "]" << std::endl; + } + + int64_t summarize_ = -1; + bool print_tensor_type_ = true; + bool print_tensor_shape_ = true; + bool print_tensor_lod_ = true; + bool print_tensor_layout_ = true; +}; + +void PrintCompute::Run() { + auto& param = Param(); + param.out->CopyDataFrom(*param.in); + + if ((param.is_forward && param.print_phase == kBackward) || + (!param.is_forward && param.print_phase == kForward)) { + return; + } + + int first_n = param.first_n; + if (first_n > 0 && ++times_ > first_n) return; + + TensorFormatter formatter; + const std::string& name = param.print_tensor_name ? param.name : ""; + formatter.SetPrintTensorType(param.print_tensor_type); + formatter.SetPrintTensorShape(param.print_tensor_shape); + formatter.SetPrintTensorLod(param.print_tensor_lod); + formatter.SetPrintTensorLayout(param.print_tensor_layout); + formatter.SetSummarize(static_cast(param.summarize)); + formatter.Print(*param.in, name, param.message); +} + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + print, kHost, kAny, kAny, paddle::lite::kernels::host::PrintCompute, def) + .BindInput("In", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) + .Finalize(); diff --git a/lite/kernels/host/print_compute.h b/lite/kernels/host/print_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..91a54182d2d2e00250da01fcd5d62556da930198 --- /dev/null +++ b/lite/kernels/host/print_compute.h @@ -0,0 +1,42 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +class PrintCompute + : public KernelLite { + public: + using param_t = operators::PrintParam; + + void Run() override; + + virtual ~PrintCompute() = default; + + private: + mutable int times_{0}; +}; + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/while_compute.cc b/lite/kernels/host/while_compute.cc similarity index 50% rename from lite/kernels/arm/while_compute.cc rename to lite/kernels/host/while_compute.cc index 9241fd410a542cef797b57b9341f59895b0f734d..4886b5ffe0f48b231bcef59b5494fc126b8b69e2 100644 --- a/lite/kernels/arm/while_compute.cc +++ b/lite/kernels/host/while_compute.cc @@ -12,44 +12,44 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/kernels/arm/while_compute.h" -#include -#include -#include -#include "lite/backends/arm/math/funcs.h" -#include "lite/core/tensor.h" -#include "lite/core/type_system.h" +#include "lite/kernels/host/while_compute.h" +#include +#include namespace paddle { namespace lite { namespace kernels { -namespace arm { +namespace host { void WhileCompute::PrepareForRun() { - auto ¶m = Param(); - auto cur_scope = param.scope; - - executor_ = - std::make_shared(param.sub_block, cur_scope, place()); + auto ¶m = this->Param(); + program_.reset(new RuntimeProgram( + param.program_desc, param.exec_scope, param.block_idx)); } void WhileCompute::Run() { - auto ¶m = Param(); + auto ¶m = this->Param(); while (param.cond->data()[0]) { - executor_->Run(); + program_->Run(); } } -} // namespace arm +} // namespace host } // namespace kernels } // namespace lite } // namespace paddle REGISTER_LITE_KERNEL( - while, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::WhileCompute, def) - .BindInput("X", {LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kAny))}) + while, kHost, kAny, kAny, paddle::lite::kernels::host::WhileCompute, def) + .BindInput("X", + {LiteType::GetTensorListTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) .BindInput("Condition", - {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)}) .BindOutput("Out", - {LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kAny))}) - .BindOutput("StepScopes", {LiteType::GetTensorTy(TARGET(kARM))}) + {LiteType::GetTensorListTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .BindOutput("StepScopes", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) .Finalize(); diff --git a/lite/kernels/host/while_compute.h b/lite/kernels/host/while_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..42065865e45c18376034dea0e105bc6d4f1f053f --- /dev/null +++ b/lite/kernels/host/while_compute.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/core/program.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +class WhileCompute + : public KernelLite { + public: + using param_t = operators::WhileParam; + + void Run() override; + void PrepareForRun() override; + + virtual ~WhileCompute() = default; + + private: + std::unique_ptr program_; +}; + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/mlu/subgraph_compute.h b/lite/kernels/mlu/subgraph_compute.h index 044827dbf98c561b0d424a1c93b0da650ef58796..75570a6249ecaa36a94b73dafb27f655495cab87 100644 --- a/lite/kernels/mlu/subgraph_compute.h +++ b/lite/kernels/mlu/subgraph_compute.h @@ -43,13 +43,17 @@ class SubgraphEngine : public subgraph::Engine { public: SubgraphEngine(KernelContext* ctx, int block_idx, - cpp::BlockDesc* block_desc, + const std::shared_ptr& program_desc, + Scope* exec_scope, const std::vector& input_names, const std::vector& output_names, - Scope* scope, paddle::lite_api::PrecisionType type) - : subgraph::Engine( - ctx, block_idx, block_desc, input_names, output_names, scope), + : subgraph::Engine(ctx, + block_idx, + program_desc, + exec_scope, + input_names, + output_names), fp_type_(type) { VLOG(4) << "[MLU] PADDLE_LITE_MLU_SAVE_OFFLINE_MODEL is " << GetBoolFromEnv("PADDLE_LITE_MLU_SAVE_OFFLINE_MODEL"); @@ -103,7 +107,7 @@ class SubgraphEngine : public subgraph::Engine { protected: bool BuildDeviceProgram() override { - if (origin_program_.empty()) { + if (!origin_program_) { BuildOriginProgram(); } if (!error_compile_batch_size_changeable_ && @@ -128,13 +132,15 @@ class SubgraphEngine : public subgraph::Engine { origin_itensors_.clear(); origin_otensors_.clear(); - auto data_order = block_desc_->GetOp(0)->Type() == "layout" + auto* sub_block_desc = + program_desc_->GetBlock()(block_idx_); + auto data_order = sub_block_desc->GetOp(0)->Type() == "layout" ? CNML_NCHW : CNML_NHWC; // Convert all of input data vars and added into the MLU IR graph status |= subgraph::REBUILD_WHEN_SHAPE_CHANGED; for (auto& input_name : input_names_) { - auto input_tensor = scope_->FindMutableTensor(input_name); + auto input_tensor = exec_scope_->FindMutableTensor(input_name); auto data_type = input_tensor->precision(); cnmlDataType_t fp_type = PrecisionToDatatype(data_type); origin_itensors_.push_back(input_tensor); @@ -161,7 +167,8 @@ class SubgraphEngine : public subgraph::Engine { LOG(INFO) << "START TO CONVERT "; // Convert all of ops and its weights and added into the MLU IR graph const auto& bridges = subgraph::Registry::Instance(); - for (auto& inst : origin_program_) { + const auto& insts = origin_program_->instructions(kRootBlockIdx); + for (auto& inst : insts) { auto op = inst.op(); CHECK(op); std::string op_type = op->op_info()->Type(); @@ -200,7 +207,7 @@ class SubgraphEngine : public subgraph::Engine { for (auto& output_name : output_names_) { if (graph->HasNode(output_name)) { graph->AddOutput(graph->GetNode(output_name)); - auto output_tensor = scope_->FindMutableTensor(output_name); + auto output_tensor = exec_scope_->FindMutableTensor(output_name); origin_otensors_.push_back(output_tensor); VLOG(4) << "subgraph output tensor " << output_name << std::endl; @@ -257,7 +264,7 @@ class SubgraphEngine : public subgraph::Engine { for (const auto& input_name : input_names_) { tmp = input_name; name += TrimStrings(tmp) + delimiter + input_shape_str; - auto input_tensor = scope_->FindMutableTensor(input_name); + auto input_tensor = exec_scope_->FindMutableTensor(input_name); for (const auto& iterm : input_tensor->dims().Vectorize()) { name += std::to_string(iterm) + delimiter_num; } @@ -266,7 +273,7 @@ class SubgraphEngine : public subgraph::Engine { for (const auto& output_name : output_names_) { tmp = output_name; name += TrimStrings(tmp) + delimiter + output_shape_str; - auto output_tensor = scope_->FindMutableTensor(output_name); + auto output_tensor = exec_scope_->FindMutableTensor(output_name); for (const auto& iterm : output_tensor->dims().Vectorize()) { name += std::to_string(iterm) + delimiter_num; } @@ -284,7 +291,8 @@ class SubgraphEngine : public subgraph::Engine { origin_otensors_[i]->Resize(iter->second[i]); } } else { - for (auto& inst : origin_program_) { + const auto& insts = origin_program_->instructions(kRootBlockIdx); + for (auto& inst : insts) { auto op = inst.op(); CHECK(op); op->CheckShape(); @@ -475,11 +483,11 @@ class SubgraphCompute auto& param = this->template Param(); // LOG(INFO) << "SUBGRAP Prepare RUN index " << param.sub_block_idx; engine_.reset(new SubgraphEngine(this->ctx_.get(), - param.sub_block_idx, - param.sub_block_desc, + param.block_idx, + param.program_desc, + param.exec_scope, param.input_data_names, param.output_data_names, - param.scope, this->precision())); CHECK(engine_); } diff --git a/lite/kernels/npu/bridges/engine.cc b/lite/kernels/npu/bridges/engine.cc index 884ab1acce8f0927def660ae35941d85b4c85901..b9f81a74ad997966ecb79c66bceed1e84b4a91f7 100644 --- a/lite/kernels/npu/bridges/engine.cc +++ b/lite/kernels/npu/bridges/engine.cc @@ -25,11 +25,14 @@ namespace subgraph { Engine::Engine(KernelContext *ctx, int block_idx, - cpp::BlockDesc *block_desc, + const std::shared_ptr &program_desc, + Scope *exec_scope, const std::vector &input_names, - const std::vector &output_names, - lite::Scope *scope) - : ctx_(ctx), block_idx_(block_idx), block_desc_(block_desc), scope_(scope) { + const std::vector &output_names) + : ctx_(ctx), + block_idx_(block_idx), + program_desc_(program_desc), + exec_scope_(exec_scope) { input_names_ = input_names; output_names_ = output_names; // Sort the name of input and output tensors, it's convenient for us to get @@ -55,12 +58,12 @@ bool Engine::PrepareWorkspaceForOriginProgram() { origin_idims_.resize(input_names_.size()); origin_itensors_.resize(input_names_.size()); for (int i = 0; i < input_names_.size(); i++) { - origin_itensors_[i] = scope_->FindMutableTensor(input_names_[i]); + origin_itensors_[i] = exec_scope_->FindMutableTensor(input_names_[i]); CHECK(origin_itensors_[i]); } origin_otensors_.resize(output_names_.size()); for (int i = 0; i < output_names_.size(); i++) { - origin_otensors_[i] = scope_->FindMutableTensor(output_names_[i]); + origin_otensors_[i] = exec_scope_->FindMutableTensor(output_names_[i]); CHECK(origin_otensors_[i]); } return true; @@ -69,70 +72,20 @@ bool Engine::PrepareWorkspaceForOriginProgram() { bool Engine::BuildOriginProgram() { // TODO(hong19860320) The block_desc need to be divided into subgraphs during // the exection time. But only see them as a subgraph now. - origin_program_.clear(); - for (size_t op_idx = 0; op_idx < block_desc_->OpsSize(); op_idx++) { - auto op_desc = block_desc_->GetOp(op_idx); - CHECK(op_desc); - std::string op_type = op_desc->Type(); - // Create op and pick up the best kernel - auto op = LiteOpRegistry::Global().Create(op_desc->Type()); - CHECK(op) << "no Op found for " << op_type; - op->Attach(*op_desc, scope_); - std::unique_ptr picked_kernel; - if (op_desc->HasAttr(kKernelTypeAttr)) { - // Create op and pick up the best kernel according to the - // kKernelTypeAttr attribute - auto kernel_type = op_desc->GetAttr(kKernelTypeAttr); - std::string alias; - Place place; - KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place); - VLOG(3) << "Found the attr '" << kKernelTypeAttr << "': " << kernel_type - << " for " << op_type; - auto kernels = op->CreateKernels({place}); - CHECK_GT(kernels.size(), 0u) << "No kernels found for " << op_type; - auto it = std::find_if( - kernels.begin(), kernels.end(), [&](std::unique_ptr &it) { - return it->alias() == alias; - }); - CHECK(it != kernels.end()); - picked_kernel = std::move(*it); - } else { - // TODO(hong19860320) add kernel picking according to the type of input - // and output tensors - VLOG(3) << "The attr '" << kKernelTypeAttr - << "' not found, pick the first kernel for " << op_type; - std::vector> kernels; -#if defined(LITE_WITH_ARM) - kernels = op->CreateKernels({Place{TARGET(kARM)}, Place{TARGET(kHost)}}); -#elif defined(LITE_WITH_X86) - kernels = op->CreateKernels({Place{TARGET(kX86)}, Place{TARGET(kHost)}}); -#endif - if (kernels.size() > 0) { - picked_kernel = std::move(kernels.front()); - } else { - LOG(WARNING) << "No kernels found for " << op_type; - } - } - if (picked_kernel != nullptr) { - picked_kernel->SetContext( - ContextScheduler::Global().NewContext(picked_kernel->target())); - } - origin_program_.emplace_back(std::move(op), std::move(picked_kernel)); + if (!origin_program_) { + origin_program_.reset( + new RuntimeProgram(program_desc_, exec_scope_, block_idx_)); } - CHECK(!origin_program_.empty()) << "no instructions"; return true; } bool Engine::LaunchOriginProgram() { - if (origin_program_.empty()) { + if (!origin_program_) { BuildOriginProgram(); } - if (!origin_program_.empty()) { - for (auto &inst : origin_program_) { - auto op_type = inst.op()->op_info()->Type(); - if (op_type == "feed" || op_type == "fetch") continue; - inst.Run(); - } + if (origin_program_) { + VLOG(3) << "Roll back to run the origin program."; + origin_program_->Run(); return true; } return false; diff --git a/lite/kernels/npu/bridges/engine.h b/lite/kernels/npu/bridges/engine.h index b49b8fea5a6d39610ea7398e177e7d1ec5a35f92..daa02fb0d7bf8f70ebf8b21821a274b6a0ba062d 100644 --- a/lite/kernels/npu/bridges/engine.h +++ b/lite/kernels/npu/bridges/engine.h @@ -30,10 +30,10 @@ class Engine { public: Engine(KernelContext *ctx, int block_idx, - cpp::BlockDesc *block_desc, + const std::shared_ptr &program_desc, + Scope *exec_scope, const std::vector &input_names, - const std::vector &output_names, - lite::Scope *scope); + const std::vector &output_names); virtual ~Engine() = default; virtual bool Run(); @@ -54,15 +54,15 @@ class Engine { KernelContext *ctx_{nullptr}; int block_idx_{-1}; - cpp::BlockDesc *block_desc_{nullptr}; + const std::shared_ptr program_desc_{nullptr}; std::vector input_names_; std::vector output_names_; - Scope *scope_{nullptr}; + Scope *exec_scope_{nullptr}; bool is_first_epoch_{true}; std::vector> origin_idims_; std::vector origin_itensors_; std::vector origin_otensors_; - std::vector origin_program_; + std::unique_ptr origin_program_{nullptr}; }; } // namespace subgraph diff --git a/lite/kernels/npu/subgraph_compute.cc b/lite/kernels/npu/subgraph_compute.cc index 6afb445e0ed411251d203bcb0420b0fba8ab6beb..e9c5957ff6d8f026f712de04f4e32cd69baf50a9 100644 --- a/lite/kernels/npu/subgraph_compute.cc +++ b/lite/kernels/npu/subgraph_compute.cc @@ -55,7 +55,8 @@ std::string DeviceProgram::GenerateModelName( } // Deserialize the generated model, the precisions and dimensions of the origin -// output tensors of the subgraph op into files +// output tensors of the subgraph op from the cached configuration file and HiAI +// om file bool DeviceProgram::LoadFromCacheFile( const std::vector& input_names, const std::vector& output_names, @@ -71,7 +72,7 @@ bool DeviceProgram::LoadFromCacheFile( VLOG(3) << "[NPU] Load model from " << model_path; std::vector model_buffer; if (!ReadFile(model_path, &model_buffer)) { - LOG(WARNING) << "[NPU] read from " << model_path << " failed!"; + LOG(WARNING) << "[NPU] Open " << model_path << " for reading failed!"; return false; } bool model_comp = false; @@ -98,9 +99,9 @@ bool DeviceProgram::LoadFromCacheFile( LOG(WARNING) << "[NPU] read from " << config_path << " failed!"; return false; } - std::string config_str(config_buffer.begin(), config_buffer.end()); + std::string str(config_buffer.begin(), config_buffer.end()); // Parse the precision and shapes of the output tensors - auto output_options = Split(config_str, ";"); + auto output_options = Split(str, ";"); CHECK_EQ(output_options.size(), output_names.size()); origin_otypes_.resize(output_names.size()); origin_odims_.resize(output_names.size()); @@ -114,7 +115,7 @@ bool DeviceProgram::LoadFromCacheFile( } bool DeviceProgram::BuildGraphAndCacheToFile( - const std::vector& origin_program, + RuntimeProgram* origin_program, const std::vector& input_names, const std::vector& output_names, const std::vector>& origin_idims, @@ -127,10 +128,13 @@ bool DeviceProgram::BuildGraphAndCacheToFile( // Convert all of ops and their input vars and weights to HiAI IR nodes, // then added them into the HiAI IR graph int status = 0; - CHECK(!origin_program.empty()) << "no instructions"; subgraph::npu::Graph graph; const auto& bridges = subgraph::Registry::Instance(); - for (auto& inst : origin_program) { + CHECK(origin_program) << "[NPU] The origin program is not initialized!"; + CHECK_GT(origin_program->instructions(kRootBlockIdx).size(), 0) + << "[NPU] No instructions found in the origin program!"; + const auto& insts = origin_program->instructions(kRootBlockIdx); + for (auto& inst : insts) { auto op = const_cast(inst.op()); CHECK(op); op->CheckShape(); @@ -149,7 +153,8 @@ bool DeviceProgram::BuildGraphAndCacheToFile( // Collect the input and output nodes of the HiAI IR graph std::vector device_inodes; for (size_t i = 0; i < input_names.size(); i++) { - CHECK(graph.Has(input_names[i]) && graph.Get(input_names[i])->is_data()); + CHECK(graph.Has(input_names[i])); + CHECK(graph.Get(input_names[i])->is_data()); device_inodes.push_back(*graph.Get(input_names[i])->data()); } std::vector device_onodes; @@ -173,6 +178,9 @@ bool DeviceProgram::BuildGraphAndCacheToFile( LOG(WARNING) << "[NPU] Load model failed!"; return false; } + // Do not check model compatibility because it assume that the cached om model + // is always compatible with the current device + // Update the precison and dimensions of the origin output tensors // Update the precison and dimensions of the origin output tensors CHECK_EQ(origin_otensors.size(), output_names.size()); origin_otypes_.resize(output_names.size()); @@ -247,7 +255,7 @@ bool DeviceProgram::ShareBufferWithOriginTensors( device_idims_[i].GetHeight() * device_idims_[i].GetWidth()); VLOG(3) << "[NPU] Init the input tensors for the device program and share " "their buffers with the origin input tensors"; - // reinit device tensor will free shared buffer, so copy data to a tmp + // Reinit device tensor will free shared buffer, so copy data to a tmp // tensor Tensor tmp; tmp.CopyDataFrom(*(*origin_itensors)[i]); @@ -337,8 +345,9 @@ bool SubgraphEngine::BuildDeviceProgram() { if (!device_programs_.count(origin_idims_)) { auto device_program = std::make_shared(); // Obtain the model cache dir from the NPU Context of the subgraph op - auto model_cache_dir = ctx_->As().SubgraphModelCacheDir(); - VLOG(3) << "[NPU] Getting subgraph model_cache_dir is: " << model_cache_dir; + auto model_cache_dir = + ctx_->As().SubgraphModelCacheDir(exec_scope_); + VLOG(3) << "[NPU] Getting subgraph_model_cache_dir: " << model_cache_dir; // Check and load if the cached model and configuration file exists if (model_cache_dir.empty() || !device_program->LoadFromCacheFile( @@ -346,11 +355,13 @@ bool SubgraphEngine::BuildDeviceProgram() { // Build the model online, including converting the paddle ops to the HiAI // IR nodes, building the HiAI IR graph to the om model, then load it as a // new HiAI model manager client for inference. - if (origin_program_.empty()) { + if (!origin_program_) { BuildOriginProgram(); } - CHECK(!origin_program_.empty()) << "no instructions"; - if (!device_program->BuildGraphAndCacheToFile(origin_program_, + CHECK(origin_program_) << "[NPU] The origin program is not initialized!"; + CHECK_GT(origin_program_->instructions().size(), 0) + << "[NPU] No instructions found in the origin program!"; + if (!device_program->BuildGraphAndCacheToFile(origin_program_.get(), input_names_, output_names_, origin_idims_, @@ -391,11 +402,11 @@ bool SubgraphEngine::LaunchDeviceProgram() { void SubgraphCompute::PrepareForRun() { auto& param = this->Param(); engine_.reset(new SubgraphEngine(ctx_.get(), - param.sub_block_idx, - param.sub_block_desc, + param.block_idx, + param.program_desc, + param.exec_scope, param.input_data_names, - param.output_data_names, - param.scope)); + param.output_data_names)); CHECK(engine_); } diff --git a/lite/kernels/npu/subgraph_compute.h b/lite/kernels/npu/subgraph_compute.h index 33321a7789fbc1eee5ff759dcf682d8e875ffe96..2203acaee82704b2a9e93d8b14d708197d7afb1a 100644 --- a/lite/kernels/npu/subgraph_compute.h +++ b/lite/kernels/npu/subgraph_compute.h @@ -41,7 +41,7 @@ class DeviceProgram { const std::vector>& origin_idims, const std::string& model_cache_dir); bool BuildGraphAndCacheToFile( - const std::vector& origin_program, + RuntimeProgram* origin_program, const std::vector& input_names, const std::vector& output_names, const std::vector>& origin_idims, @@ -71,12 +71,16 @@ class SubgraphEngine : public subgraph::Engine { public: SubgraphEngine(KernelContext* ctx, int block_idx, - cpp::BlockDesc* block_desc, + const std::shared_ptr& program_desc, + Scope* exec_scope, const std::vector& input_names, - const std::vector& output_names, - Scope* scope) - : subgraph::Engine( - ctx, block_idx, block_desc, input_names, output_names, scope) {} + const std::vector& output_names) + : subgraph::Engine(ctx, + block_idx, + program_desc, + exec_scope, + input_names, + output_names) {} protected: bool PrepareWorkspaceForDeviceProgram() override; diff --git a/lite/kernels/opencl/conv_image_compute.h b/lite/kernels/opencl/conv_image_compute.h index 4eab7be1f1ac6459250c6df984160f0f6060ea1c..e61557a71dfbf1353decc9491b67c5e1e326512e 100644 --- a/lite/kernels/opencl/conv_image_compute.h +++ b/lite/kernels/opencl/conv_image_compute.h @@ -152,7 +152,7 @@ class ConvImageCompute : public KernelLite(1), static_cast(1), static_cast(1)}; bool use_lws_{true}; - bool use_tune_{false}; + bool use_tune_{true}; }; } // namespace opencl diff --git a/lite/kernels/opencl/nearest_interp_image_compute_test.cc b/lite/kernels/opencl/nearest_interp_image_compute_test.cc index 4a9948832d1a96d95a7f317bd3ac8245292ae02b..fb40da290d10ed49f293cf7ff78865f2e7967eab 100644 --- a/lite/kernels/opencl/nearest_interp_image_compute_test.cc +++ b/lite/kernels/opencl/nearest_interp_image_compute_test.cc @@ -155,6 +155,7 @@ TEST(nearest_interp_image2d, compute) { auto *x_data = x.mutable_data(TARGET(kOpenCL)); auto *y_data = y.mutable_data(TARGET(kOpenCL)); auto *y_data_ref = y_ref.mutable_data(TARGET(kARM)); + memset(reinterpret_cast(y_data_ref), 0, y_ref.numel()); auto *mapped_x = static_cast(TargetWrapperCL::Map( x_data, 0, sizeof(float) * x_dim.production())); auto *mapped_y = static_cast(TargetWrapperCL::Map( diff --git a/lite/kernels/rknpu/subgraph_compute.cc b/lite/kernels/rknpu/subgraph_compute.cc index a50505c38c0740f762256cd71e006caf9249838e..da01539b291d57da1501f8c3790acae8496581f3 100644 --- a/lite/kernels/rknpu/subgraph_compute.cc +++ b/lite/kernels/rknpu/subgraph_compute.cc @@ -28,26 +28,6 @@ namespace lite { namespace kernels { namespace rknpu { -bool SubgraphEngine::PrepareWorkspaceForDeviceProgram() { - // Obtain the origin input tensors, and create the origin output - // tensors(Don't try to access them before launch the device program or the - // origin program) - PrepareWorkspaceForOriginProgram(); - // Create the device input and output tensors, but don't initialize them - // with the dimensions - device_itensors_.resize(input_names_.size()); - for (int i = 0; i < input_names_.size(); i++) { - device_itensors_[i].reset(new hiai::AiTensor); - CHECK(device_itensors_[i]); - } - device_otensors_.resize(output_names_.size()); - for (int i = 0; i < output_names_.size(); i++) { - device_otensors_[i].reset(new hiai::AiTensor); - CHECK(device_otensors_[i]); - } - return true; -} - bool SubgraphEngine::BuildDeviceProgram() { LOG(INFO) << "[RKNPU]:BuildDeviceProgram"; int status = 0; @@ -55,10 +35,11 @@ bool SubgraphEngine::BuildDeviceProgram() { // RKNPU IR graph subgraph::rknpu::Graph graph; const auto& bridges = subgraph::Registry::Instance(); - if (origin_program_.empty()) { + if (!origin_program_) { BuildOriginProgram(); } - for (auto& inst : origin_program_) { + const auto& insts = origin_program_->instructions(kRootBlockIdx); + for (auto& inst : insts) { auto op = const_cast(inst.op()); CHECK(op); op->CheckShape(); @@ -76,92 +57,26 @@ bool SubgraphEngine::BuildDeviceProgram() { } // Collect the valid input and output nodes in the RKNPU IR graph and update // the input and output names - device_inames_.clear(); - device_onames_.clear(); - - for (auto& input_name : input_names_) { - LOG(INFO) << "[RKNPU] Input node " << input_name; - if (graph.Has(input_name)) { - LOG(INFO) << input_name << " Precision " - << PrecisionToStr(graph.Get(input_name)->precision()); - device_itensors_.push_back(graph.Get(input_name)->data()); - device_inames_.push_back(input_name); - } else { - LOG(WARNING) << "[RKNPU] Input node " << input_name - << " is ignored because it does not exist."; - } - } - - for (auto& output_name : output_names_) { - LOG(INFO) << "[RKNPU] Output node " << output_name; - if (graph.Has(output_name)) { - auto tensor = scope_->FindMutableTensor(output_name); - LOG(INFO) << output_name << " Precision " - << PrecisionToStr(tensor->precision()); - device_otensors_.push_back(graph.Get(output_name)->data()); - device_onames_.push_back(output_name); - } else { - LOG(WARNING) << "[RKNPU] Output node " << output_name - << " is ignored because it does not exist."; - } - } - CHECK(!device_inames_.empty()) - << "[RKNPU] No input nodes found for building NPU model"; - CHECK(!device_onames_.empty()) - << "[RKNPU] No output nodes found for building NPU model"; - - device_program_ = lite::rknpu::Device::Global().Build( - model_name_, graph.GetHandle(), device_itensors_, device_otensors_); - if (device_program_ == nullptr) { - LOG(WARNING) << "[RKNPU] Build model failed!"; - return false; - } - - // input - origin_idims_.resize(input_names_.size()); - origin_itensors_.resize(input_names_.size()); + device_itensors_.clear(); + device_otensors_.clear(); for (size_t i = 0; i < input_names_.size(); i++) { - origin_itensors_[i] = scope_->FindMutableTensor(input_names_[i]); - CHECK(origin_itensors_[i]); - origin_idims_[i] = origin_itensors_[i]->dims(); - } - // output - origin_odims_.resize(output_names_.size()); - origin_otensors_.resize(output_names_.size()); - for (size_t i = 0; i < output_names_.size(); i++) { - origin_otensors_[i] = scope_->FindMutableTensor(output_names_[i]); - CHECK(origin_otensors_[i]); - origin_odims_[i] = origin_otensors_[i]->dims(); - - auto output_dims = origin_otensors_[i]->dims(); - } - - origin_idims_.resize(device_inames_.size()); - origin_itensors_.resize(device_inames_.size()); - device_itensors_.resize(device_inames_.size()); - origin_odims_.resize(device_onames_.size()); - origin_otensors_.resize(device_onames_.size()); - device_otensors_.resize(device_onames_.size()); - for (int i = 0; i < device_inames_.size(); i++) { - auto node = graph.Get(device_inames_[i]); + CHECK(graph.Has(input_names_[i])) << "[RKNPU] Failed to find input node " + << input_names_[i]; + auto node = graph.Get(input_names_[i]); auto precision = node->precision(); auto layout = node->layout(); - origin_itensors_[i] = scope_->FindMutableTensor(device_inames_[i]); - CHECK(origin_itensors_[i]); - origin_idims_[i] = origin_itensors_[i]->dims(); - - LOG(INFO) << "[RKNPU] Inputs[" << i << "] name: " << device_inames_[i] + LOG(INFO) << "[RKNPU] Inputs[" << i << "] name: " << input_names_[i] << " precision: " << PrecisionToStr(precision) << " layout: " << DataLayoutToStr(layout); + device_itensors_.push_back(node->data()); } - for (int i = 0; i < device_onames_.size(); i++) { - auto node = graph.Get(device_onames_[i]); + for (size_t i = 0; i < output_names_.size(); i++) { + CHECK(graph.Has(output_names_[i])) << "[RKNPU] Failed to find output node " + << output_names_[i]; + auto node = graph.Get(output_names_[i]); auto precision = node->precision(); auto layout = node->layout(); - origin_otensors_[i] = scope_->FindMutableTensor(device_onames_[i]); - CHECK(origin_otensors_[i]); - origin_odims_[i] = origin_otensors_[i]->dims(); - LOG(INFO) << "[RKNPU] Outputs[" << i << "] name: " << device_onames_[i] + LOG(INFO) << "[RKNPU] Outputs[" << i << "] name: " << output_names_[i] << " precision: " << PrecisionToStr(precision) << " layout: " << DataLayoutToStr(layout); // Prepare the device output tensors @@ -182,11 +97,19 @@ bool SubgraphEngine::BuildDeviceProgram() { origin_otensors_[i]->mutable_data(); break; default: - LOG(FATAL) << "[RKNPU] " << device_onames_[i] + LOG(FATAL) << "[RKNPU] " << output_names_[i] << " can't mutable data with precision type " << PrecisionToStr(precision); break; } + device_otensors_.push_back(node->data()); + } + // Create the RKNPU model and set the input and output nodes + device_program_ = lite::rknpu::Device::Global().Build( + model_name_, graph.GetHandle(), device_itensors_, device_otensors_); + if (device_program_ == nullptr) { + LOG(WARNING) << "[RKNPU] Build model failed!"; + return false; } return true; } @@ -196,8 +119,8 @@ bool SubgraphEngine::LaunchDeviceProgram() { std::vector inputs; std::vector outputs; - inputs.resize(device_itensors_.size()); - for (size_t i = 0; i < device_itensors_.size(); i++) { + inputs.resize(origin_itensors_.size()); + for (size_t i = 0; i < origin_itensors_.size(); i++) { inputs[i].index = i; inputs[i].buf = const_cast(origin_itensors_[i]->raw_data()); inputs[i].size = origin_itensors_[i]->memory_size(); @@ -207,8 +130,8 @@ bool SubgraphEngine::LaunchDeviceProgram() { inputs[i].layout = rk::nn::DataLayoutType::NCHW; } - outputs.resize(device_otensors_.size()); - for (size_t i = 0; i < device_otensors_.size(); i++) { + outputs.resize(origin_otensors_.size()); + for (size_t i = 0; i < origin_otensors_.size(); i++) { outputs[i].index = i; outputs[i].buf = const_cast(origin_otensors_[i]->raw_data()); outputs[i].size = origin_otensors_[i]->memory_size(); @@ -225,11 +148,11 @@ void SubgraphCompute::PrepareForRun() { LOG(INFO) << "[RKNPU]:PrepareForRun"; auto& param = this->Param(); engine_.reset(new SubgraphEngine(ctx_.get(), - param.sub_block_idx, - param.sub_block_desc, + param.block_idx, + param.program_desc, + param.exec_scope, param.input_data_names, - param.output_data_names, - param.scope)); + param.output_data_names)); CHECK(engine_); } diff --git a/lite/kernels/rknpu/subgraph_compute.h b/lite/kernels/rknpu/subgraph_compute.h index a4bdadc658a81decd8107072f7b5948613d0c68a..78162b3d165bde8e33436654bbcd1110ad9afea6 100644 --- a/lite/kernels/rknpu/subgraph_compute.h +++ b/lite/kernels/rknpu/subgraph_compute.h @@ -34,15 +34,18 @@ class SubgraphEngine : public subgraph::Engine { public: SubgraphEngine(KernelContext *ctx, int block_idx, - cpp::BlockDesc *block_desc, + const std::shared_ptr &program_desc, + Scope *exec_scope, const std::vector &input_names, - const std::vector &output_names, - Scope *scope) - : subgraph::Engine( - ctx, block_idx, block_desc, input_names, output_names, scope) {} + const std::vector &output_names) + : subgraph::Engine(ctx, + block_idx, + program_desc, + exec_scope, + input_names, + output_names) {} protected: - bool PrepareWorkspaceForDeviceProgram() override; bool BuildDeviceProgram() override; bool LaunchDeviceProgram() override; diff --git a/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.cc b/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.cc index 376cdd0dc23426ede42ddac60e061727f73322e3..224bfdc130338bc653091400708bc8a7421a9482 100644 --- a/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.cc +++ b/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.cc @@ -31,11 +31,14 @@ void XPUEmbeddingWithEltwiseAddCompute::PrepareForRun() { CHECK_EQ(table_dims.size(), 2); /* shape like [table_len, embed_dim] */ table_lens_cpu_.push_back(table_dims[0]); } - void* lens_ptr = nullptr; + size_t lens_size = table_lens_cpu_.size() * sizeof(int); - xpu_malloc(&lens_ptr, lens_size); - xpu_memcpy(lens_ptr, &table_lens_cpu_[0], lens_size, XPU_HOST_TO_DEVICE); - table_lens_guard_.reset(lens_ptr); + table_lens_guard_ = + TargetWrapperXPU::MallocScratchPad(lens_size, false /* use_l3 */); + XPU_CALL(xpu_memcpy(table_lens_guard_->addr_, + &table_lens_cpu_[0], + lens_size, + XPU_HOST_TO_DEVICE)); } void XPUEmbeddingWithEltwiseAddCompute::Run() { @@ -55,16 +58,16 @@ void XPUEmbeddingWithEltwiseAddCompute::Run() { int embed_dim = table_dims[1]; int emb_layer_num = param.Ids.size(); int r = xdnn::embedding_with_ewadd( - ctx.GetRawContext(), /* context */ - embed_dim, /* embed_dim */ - idx_len, /* idx_len */ - emb_layer_num, /* emb_layer_num */ - param.padding_idx, /* padding_idx */ - &arg_tables_[0], /* tables */ - &arg_ids_[0], /* indices */ - static_cast(table_lens_guard_.get()), /* table_lens */ - nullptr, /* scale_after_emb */ - nullptr, /* scale_after_ewadd */ + ctx.GetRawContext(), /* context */ + embed_dim, /* embed_dim */ + idx_len, /* idx_len */ + emb_layer_num, /* emb_layer_num */ + param.padding_idx, /* padding_idx */ + &arg_tables_[0], /* tables */ + &arg_ids_[0], /* indices */ + static_cast(table_lens_guard_->addr_), /* table_lens */ + nullptr, /* scale_after_emb */ + nullptr, /* scale_after_ewadd */ param.Out->mutable_data(TARGET(kXPU)) /* top */); CHECK_EQ(r, 0); } diff --git a/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.h b/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.h index 10ba6e0b5b76a1dbebfd633732f7c36e6ac7c954..124ed7866f0a52b892e30ae41398d5140064c964 100644 --- a/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.h +++ b/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.h @@ -14,10 +14,9 @@ #pragma once -#include #include +#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard #include "lite/core/kernel.h" -#include "lite/kernels/xpu/utils.h" // XPUFreeDeleter namespace paddle { namespace lite { @@ -36,7 +35,7 @@ class XPUEmbeddingWithEltwiseAddCompute private: std::vector arg_ids_; std::vector arg_tables_; - std::unique_ptr table_lens_guard_; + XPUScratchPadGuard table_lens_guard_; std::vector table_lens_cpu_; }; diff --git a/lite/kernels/xpu/__xpu__mmdnn_compute.cc b/lite/kernels/xpu/__xpu__mmdnn_compute.cc index 39ddecb1139073cb1a0bd8e3c7afc89f1d739da8..09d59fcee37c634a87636ac80e7be15d927f2509 100644 --- a/lite/kernels/xpu/__xpu__mmdnn_compute.cc +++ b/lite/kernels/xpu/__xpu__mmdnn_compute.cc @@ -27,8 +27,8 @@ namespace { void FillMax(float max, float* xpu_ptr) { float maxs[4] = {max, 0.0f, 0.0f, 0.0f}; - xpu_memcpy( - xpu_ptr, maxs, 4 * sizeof(float), XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy( + xpu_ptr, maxs, 4 * sizeof(float), XPUMemcpyKind::XPU_HOST_TO_DEVICE)); } void GrnnLayout(int batch, @@ -156,8 +156,8 @@ class MMDNNIdInfo { idx_sorted.data(), idx_sorted.size() * sizeof(int)); offset += idx_sorted.size() * sizeof(int); - xpu_memcpy( - l3_buffer_, cpu_buffer_, offset, XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy( + l3_buffer_, cpu_buffer_, offset, XPUMemcpyKind::XPU_HOST_TO_DEVICE)); } }; @@ -221,29 +221,32 @@ class MMDNNFcOp { int m, float* out, const float* in_max_by_caller = nullptr) { + int r = 0; if (in_max_by_caller == nullptr) { - xdnn::findmax(ctx, in, m * k_, in_max_); + r = xdnn::findmax(ctx, in, m * k_, in_max_); + CHECK_EQ(r, 0); in_max_by_caller = in_max_; } - xdnn::gemm_int16_maxptr(ctx, - false, - true, - m, - n_, - k_, - 1.0f, - in, - k_, - weight_, - k_, - 0.0f, - out, - n_, - bias_, - act_type_, - in_max_by_caller, - weight_max_, - out_max); + r = xdnn::gemm_int16_maxptr(ctx, + false, + true, + m, + n_, + k_, + 1.0f, + in, + k_, + weight_, + k_, + 0.0f, + out, + n_, + bias_, + act_type_, + in_max_by_caller, + weight_max_, + out_max); + CHECK_EQ(r, 0); } }; @@ -331,44 +334,49 @@ class MMDNNGrnnOp { gru_out = l3_buffer + 4 * slot_size; } - xdnn::search_seq2batch(ctx, - batch, - max_width, - cap_e_, - sentense.idx_sorted_32, - sentense.lod_32, - sentense.new_offset_32, - in, - seq2batch_out); - - xdnn::findmax(ctx, in, cap_l * cap_e_, input_max_); + int r = 0; + r = xdnn::search_seq2batch(ctx, + batch, + max_width, + cap_e_, + sentense.idx_sorted_32, + sentense.lod_32, + sentense.new_offset_32, + in, + seq2batch_out); + CHECK_EQ(r, 0); + + r = xdnn::findmax(ctx, in, cap_l * cap_e_, input_max_); + CHECK_EQ(r, 0); fc_e2h0_.Infer(ctx, seq2batch_out, cap_l, fc_e2h_out, input_max_); fc_e2h1_.Infer( ctx, seq2batch_out, cap_l, fc_e2h_out + cap_l * cap_h_, input_max_); fc_e2h2_.Infer( ctx, seq2batch_out, cap_l, fc_e2h_out + cap_l * cap_h_ * 2, input_max_); - xdnn::search_grnn(ctx, - cap_l, - cap_h_, - cap_e_, - max_width, - sentense.new_offset_32, - fc_e2h_out, - dense_h2h_, - gru_out, - dense_h2h_max_[0], - dense_h2h_max_[1], - dense_h2h_max_[2]); - - xdnn::search_batch2seq(ctx, - batch, - max_width, - cap_h_, - sentense.idx_sorted_32, - sentense.lod_32, - sentense.new_offset_32, - gru_out, - out); + r = xdnn::search_grnn(ctx, + cap_l, + cap_h_, + cap_e_, + max_width, + sentense.new_offset_32, + fc_e2h_out, + dense_h2h_, + gru_out, + dense_h2h_max_[0], + dense_h2h_max_[1], + dense_h2h_max_[2]); + CHECK_EQ(r, 0); + + r = xdnn::search_batch2seq(ctx, + batch, + max_width, + cap_h_, + sentense.idx_sorted_32, + sentense.lod_32, + sentense.new_offset_32, + gru_out, + out); + CHECK_EQ(r, 0); } }; @@ -435,38 +443,43 @@ class MMDNNAttentionOp { } seqfc_.Infer(ctx, input, cap_l, seqfc_out); - xdnn::search_noaligned_mat_mul(ctx, - 0, - 1, - batch, - lod_32, - max_width, - dim_, - alpha0_, - input, - seqfc_out, - batchgemm0_out); - xdnn::search_seq_softmax( + int r = 0; + r = xdnn::search_noaligned_mat_mul(ctx, + 0, + 1, + batch, + lod_32, + max_width, + dim_, + alpha0_, + input, + seqfc_out, + batchgemm0_out); + CHECK_EQ(r, 0); + r = xdnn::search_seq_softmax( ctx, batchgemm0_out, seq_softmax_out, lod_32, batch, max_width); - xdnn::search_noaligned_mat_mul(ctx, - 0, - 0, - batch, - lod_32, - max_width, - dim_, - alpha1_, - seq_softmax_out, - input, - batchgemm1_out); - xdnn::sequence_pooling_forward(ctx, - xdnn::Pooling_t::MAX_WITHOUT_INDEX, - batch, - lod_32, - dim_, - batchgemm1_out, - nullptr, - pool_out); + CHECK_EQ(r, 0); + r = xdnn::search_noaligned_mat_mul(ctx, + 0, + 0, + batch, + lod_32, + max_width, + dim_, + alpha1_, + seq_softmax_out, + input, + batchgemm1_out); + CHECK_EQ(r, 0); + r = xdnn::sequence_pooling_forward(ctx, + xdnn::Pooling_t::MAX_WITHOUT_INDEX, + batch, + lod_32, + dim_, + batchgemm1_out, + nullptr, + pool_out); + CHECK_EQ(r, 0); } }; @@ -510,12 +523,13 @@ class MMDNNMatchConvTopk { float conv_w_max, int dim_t, int dim_in, + int out_channel, int upper_bound_batch, int upper_bound_seqlen, const std::vector& topks) { dim_t_ = dim_t; dim_in_ = dim_in; - out_channel_ = 5; // TODO(miaotianxiang): + out_channel_ = out_channel; topks_ = topks; xw_fc_.Init(input_w, @@ -553,10 +567,10 @@ class MMDNNMatchConvTopk { topks_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(topks_.size() * sizeof(int), false); topks_xpu_ = reinterpret_cast(topks_xpu_guard_->addr_); - xpu_memcpy(topks_xpu_, - topks_.data(), - topks_.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy(topks_xpu_, + topks_.data(), + topks_.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); useless_topk_pos_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(int), false); useless_topk_pos_ = reinterpret_cast(useless_topk_pos_guard_->addr_); @@ -576,18 +590,18 @@ class MMDNNMatchConvTopk { for (auto e : left_lod) { left_lod_32_cpu.push_back(e); } - xpu_memcpy(left_lod_32_, - left_lod_32_cpu.data(), - left_lod_32_cpu.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy(left_lod_32_, + left_lod_32_cpu.data(), + left_lod_32_cpu.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); std::vector right_lod_32_cpu; for (auto e : right_lod) { right_lod_32_cpu.push_back(e); } - xpu_memcpy(right_lod_32_, - right_lod_32_cpu.data(), - right_lod_32_cpu.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy(right_lod_32_, + right_lod_32_cpu.data(), + right_lod_32_cpu.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); std::vector lod_match = {0}; std::vector lod_conv = {0}; @@ -611,18 +625,18 @@ class MMDNNMatchConvTopk { left_seqlen_sum += len_x; right_seqlen_sum += len_y; } - xpu_memcpy(match_lod_32_, - lod_match.data(), - lod_match.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); - xpu_memcpy(conv_lod_32_, - lod_conv.data(), - lod_conv.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); - xpu_memcpy(topk_offset_32_, - lod_topk.data(), - lod_topk.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy(match_lod_32_, + lod_match.data(), + lod_match.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + XPU_CALL(xpu_memcpy(conv_lod_32_, + lod_conv.data(), + lod_conv.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + XPU_CALL(xpu_memcpy(topk_offset_32_, + lod_topk.data(), + lod_topk.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); float* xwy_out = hbm_buffer_; float* conv_out = hbm_buffer_ + x_mul_y_sum * dim_t_; @@ -640,19 +654,21 @@ class MMDNNMatchConvTopk { int max_width = std::max(left_seqlen_max, right_seqlen_max); xw_fc_.Infer(ctx, left->data(), left_seqlen_sum, xw_out); - xdnn::match_matrix_tensor(ctx, - batch, - xw_out, - right->data(), - left_lod_32_, - right_lod_32_, - dim_t_, - dim_in_, - xwy_out, - xw_fc_.out_max, - xdnn::Activation_t::RELU, - max_width); - xdnn::search_varconv( + int r = 0; + r = xdnn::match_matrix_tensor(ctx, + batch, + xw_out, + right->data(), + left_lod_32_, + right_lod_32_, + dim_t_, + dim_in_, + xwy_out, + xw_fc_.out_max, + xdnn::Activation_t::RELU, + max_width); + CHECK_EQ(r, 0); + r = xdnn::search_varconv( ctx, batch, dim_t_, @@ -668,24 +684,27 @@ class MMDNNMatchConvTopk { conv_out, conv_weight_max_, xdnn::Activation_t::RELU); // TODO(miaotianxiang): - xdnn::sequence_concat(ctx, - xwy_out, - match_lod_32_, - conv_out, - conv_lod_32_, - seq_concat_out, - batch); - xdnn::sequence_topk_avg_pooling(ctx, - seq_concat_out, - seq_avg_topk_out, - useless_topk_pos_, - batch, - dim_t_ + out_channel_, - topk_offset_32_, - left_lod_32_, - right_lod_32_, - topks_xpu_, - topks_.size()); + CHECK_EQ(r, 0); + r = xdnn::sequence_concat(ctx, + xwy_out, + match_lod_32_, + conv_out, + conv_lod_32_, + seq_concat_out, + batch); + CHECK_EQ(r, 0); + r = xdnn::sequence_topk_avg_pooling(ctx, + seq_concat_out, + seq_avg_topk_out, + useless_topk_pos_, + batch, + dim_t_ + out_channel_, + topk_offset_32_, + left_lod_32_, + right_lod_32_, + topks_xpu_, + topks_.size()); + CHECK_EQ(r, 0); } }; @@ -802,34 +821,38 @@ class MMDNNBidEmbGrnnAtt { pool_rv = grnn_rv_pool_out->mutable_data(TARGET(kXPU)); att_out = att_pool_out->mutable_data(TARGET(kXPU)); - xdnn::search_bid_emb_ew(ctx, - batch, - sentense.lod_64, - sentense.id0_64, - sentense.id1_64, - table_, - table_len_, - emb_dim_, - emb_fw, - emb_rv, - table_len_ - 2, - 1); + int r = 0; + r = xdnn::search_bid_emb_ew(ctx, + batch, + sentense.lod_64, + sentense.id0_64, + sentense.id1_64, + table_, + table_len_, + emb_dim_, + emb_fw, + emb_rv, + table_len_ - 2, + 1); + CHECK_EQ(r, 0); bi_rv_.Infer(ctx, sentense, emb_rv, grnn_rv, l3_buffer + 2 * slot_len, l3_size - 2 * slot_len * sizeof(float)); - xdnn::sequence_reverse( + r = xdnn::sequence_reverse( ctx, batch, sentense.lod_32, cap_h_, grnn_rv, grnn_rv_rv); - xdnn::sequence_pooling_forward(ctx, - xdnn::Pooling_t::LAST, - batch, - sentense.lod_32, - cap_h_, - grnn_rv, - nullptr, - pool_rv); + CHECK_EQ(r, 0); + r = xdnn::sequence_pooling_forward(ctx, + xdnn::Pooling_t::LAST, + batch, + sentense.lod_32, + cap_h_, + grnn_rv, + nullptr, + pool_rv); + CHECK_EQ(r, 0); bi_fw_.Infer(ctx, sentense, @@ -837,19 +860,23 @@ class MMDNNBidEmbGrnnAtt { grnn_fw, l3_buffer + 2 * slot_len, l3_size - 2 * slot_len * sizeof(float)); - xdnn::sequence_pooling_forward(ctx, - xdnn::Pooling_t::LAST, - batch, - sentense.lod_32, - cap_h_, - grnn_fw, - nullptr, - pool_fw); + r = xdnn::sequence_pooling_forward(ctx, + xdnn::Pooling_t::LAST, + batch, + sentense.lod_32, + cap_h_, + grnn_fw, + nullptr, + pool_fw); + CHECK_EQ(r, 0); const int concat_widths[] = {cap_h_, cap_h_, cap_h_}; const float* concat_ptrs[] = {emb_fw, grnn_fw, grnn_rv_rv}; - xdnn::concat( + r = xdnn::concat( ctx, cap_l, concat_widths + 1, 2, concat_ptrs + 1, concat_2in); - xdnn::concat(ctx, cap_l, concat_widths, 3, concat_ptrs, concat_3in); + CHECK_EQ(r, 0); + r = xdnn::concat( + ctx, cap_l, concat_widths, 3, concat_ptrs, concat_3in); + CHECK_EQ(r, 0); att_.Infer(ctx, sentense, concat_2in, @@ -899,16 +926,18 @@ class MMDNNEmbAtt { int cap_l = sentense.lod.back(); const float* emb_tables[] = {table_, table_}; const int64_t* emb_indices[] = {sentense.id0_64, sentense.id1_64}; - xdnn::embedding_with_ewadd(ctx, - emb_dim_, - cap_l, - 2, - table_len_ - 2, - emb_tables, - emb_indices, - nullptr, - nullptr, - emb_fw); + int r = + xdnn::embedding_with_ewadd(ctx, + emb_dim_, + cap_l, + 2, + table_len_ - 2, + emb_tables, + emb_indices, + nullptr, + nullptr, + emb_fw); + CHECK_EQ(r, 0); att_.Infer(ctx, sentense, emb_fw, att_out, l3_buffer, l3_size); } }; @@ -990,7 +1019,7 @@ class MMDNNMergeAll { fc2_.Init( fc2_w, fc2_w_max, fc2_b, fc2_n_, fc2_k_, xdnn::Activation_t::LINEAR); - int hbm_total_len = max_cap_l * cap_h_ * 4 + + int hbm_total_len = max_cap_l * cap_e_ * 2 + max_cap_l * cap_h_ * 2 + upper_bound_batch * (2 * cap_h_ + fc0_k_ + fc0_n_ + fc1_k_ + fc1_n_ + fc2_n_); hbm_buffer_guard_ = TargetWrapperXPU::MallocScratchPad( @@ -1000,7 +1029,7 @@ class MMDNNMergeAll { void Infer(xdnn::Context* ctx, const MMDNNIdInfo& sentense, - const std::vector concat_2in1_x, + const std::vector concat_topk_x, const std::vector concat_7in1_x, lite::Tensor* out, float* l3_buffer = nullptr, @@ -1010,13 +1039,13 @@ class MMDNNMergeAll { float* topk_concat_out_fw = hbm_buffer_; int hbm_total_len = - cap_l * cap_h_ * 4 + + cap_l * cap_e_ * 2 + cap_l * cap_h_ * 2 + batch * (2 * cap_h_ + fc0_k_ + fc0_n_ + fc1_k_ + fc1_n_ + fc2_n_); if (l3_size > 0 && l3_size >= hbm_total_len * sizeof(float)) { topk_concat_out_fw = l3_buffer; } - float* topk_concat_out_rv = topk_concat_out_fw + cap_l * cap_h_; - float* grnn_fw = topk_concat_out_rv + cap_l * cap_h_; + float* topk_concat_out_rv = topk_concat_out_fw + cap_l * cap_e_; + float* grnn_fw = topk_concat_out_rv + cap_l * cap_e_; float* grnn_rv = grnn_fw + cap_l * cap_h_; float* pool_fw = grnn_rv + cap_l * cap_h_; float* pool_rv = pool_fw + batch * cap_h_; @@ -1027,18 +1056,27 @@ class MMDNNMergeAll { // float* fc2_out = fc1_out + batch * fc1_n_; float* fc2_out = out->mutable_data(TARGET(kXPU)); - const int concat_widths[] = {static_cast(concat_2in1_x[0]->dims()[1]), - static_cast(concat_2in1_x[1]->dims()[1])}; - const float* concat_ptrs[] = {concat_2in1_x[0]->data(), - concat_2in1_x[1]->data()}; - xdnn::concat( - ctx, cap_l, concat_widths, 2, concat_ptrs, topk_concat_out_fw); - xdnn::sequence_reverse(ctx, - batch, - sentense.lod_32, - cap_e_, - topk_concat_out_fw, - topk_concat_out_rv); + std::vector concat_widths; + std::vector concat_ptrs; + for (const auto* t : concat_topk_x) { + concat_widths.push_back(static_cast(t->dims()[1])); + concat_ptrs.push_back(t->data()); + } + int r = 0; + r = xdnn::concat(ctx, + cap_l, + concat_widths.data(), + concat_widths.size(), + concat_ptrs.data(), + topk_concat_out_fw); + CHECK_EQ(r, 0); + r = xdnn::sequence_reverse(ctx, + batch, + sentense.lod_32, + cap_e_, + topk_concat_out_fw, + topk_concat_out_rv); + CHECK_EQ(r, 0); coverage_fw_.Infer(ctx, sentense, topk_concat_out_fw, @@ -1051,22 +1089,24 @@ class MMDNNMergeAll { grnn_rv, l3_buffer + hbm_total_len, l3_size - hbm_total_len * sizeof(float)); - xdnn::sequence_pooling_forward(ctx, - xdnn::Pooling_t::LAST, - batch, - sentense.lod_32, - cap_h_, - grnn_fw, - nullptr, - pool_fw); - xdnn::sequence_pooling_forward(ctx, - xdnn::Pooling_t::LAST, - batch, - sentense.lod_32, - cap_h_, - grnn_rv, - nullptr, - pool_rv); + r = xdnn::sequence_pooling_forward(ctx, + xdnn::Pooling_t::LAST, + batch, + sentense.lod_32, + cap_h_, + grnn_fw, + nullptr, + pool_fw); + CHECK_EQ(r, 0); + r = xdnn::sequence_pooling_forward(ctx, + xdnn::Pooling_t::LAST, + batch, + sentense.lod_32, + cap_h_, + grnn_rv, + nullptr, + pool_rv); + CHECK_EQ(r, 0); const int concat_widths_fc0[] = { static_cast(concat_7in1_x[0]->dims()[1]), @@ -1089,11 +1129,13 @@ class MMDNNMergeAll { const int concat_widths_fc1[] = {cap_h_, cap_h_, fc0_n_}; const float* concat_ptrs_fc1[] = {pool_fw, pool_rv, fc0_out}; - xdnn::concat( + r = xdnn::concat( ctx, batch, concat_widths_fc0, 7, concat_ptrs_fc0, fc0_in); + CHECK_EQ(r, 0); fc0_.Infer(ctx, fc0_in, batch, fc0_out); - xdnn::concat( + r = xdnn::concat( ctx, batch, concat_widths_fc1, 3, concat_ptrs_fc1, fc1_in); + CHECK_EQ(r, 0); fc1_.Infer(ctx, fc1_in, batch, fc1_out); fc2_.Infer(ctx, fc1_out, batch, fc2_out); } @@ -1111,14 +1153,12 @@ class XPUMmdnnBidEmbGrnnAttCompute private: MMDNNIdInfo id_; MMDNNBidEmbGrnnAtt compound_; - int upper_bound_batch_ = 40; - int upper_bound_seqlen_ = 512; }; void XPUMmdnnBidEmbGrnnAttCompute::PrepareForRun() { auto& param = this->Param(); - id_.Init(upper_bound_batch_, upper_bound_seqlen_); + id_.Init(XPU_MAX_LOD_SIZE, XPU_MAX_LOD_SEQ_LEN); compound_.Init(param.emb_tbl, param.grnn_fw_wh, param.grnn_fw_wh_maxs, @@ -1131,8 +1171,8 @@ void XPUMmdnnBidEmbGrnnAttCompute::PrepareForRun() { param.att_fc_w, param.att_fc_w_max, param.att_fc_b, - upper_bound_batch_, - upper_bound_seqlen_); + XPU_MAX_LOD_SIZE, + XPU_MAX_LOD_SEQ_LEN); } void XPUMmdnnBidEmbGrnnAttCompute::Run() { @@ -1157,6 +1197,76 @@ void XPUMmdnnBidEmbGrnnAttCompute::Run() { xpu_ctx->workspace_l3_size - xpu_ctx->used_l3_size); } +class XPUMmdnnBidEmbGrnnAttCompute2 + : public KernelLite { + public: + using param_t = operators::XPUMmdnnBidEmbGrnnAttParam2; + + void PrepareForRun() override; + + void Run() override; + + private: + MMDNNIdInfo id_; + MMDNNBidEmbGrnnAtt compound_; +}; + +void XPUMmdnnBidEmbGrnnAttCompute2::PrepareForRun() { + auto& param = this->Param(); + + id_.Init(XPU_MAX_LOD_SIZE, XPU_MAX_LOD_SEQ_LEN); + compound_.Init(param.emb_tbl, + param.grnn_fw_wh, + param.grnn_fw_wh_maxs, + param.grnn_fw_wi, + param.grnn_fw_wi_maxs, + param.grnn_rv_wh, + param.grnn_rv_wh_maxs, + param.grnn_rv_wi, + param.grnn_rv_wi_maxs, + param.att_fc_w, + param.att_fc_w_max, + param.att_fc_b, + XPU_MAX_LOD_SIZE, + XPU_MAX_LOD_SEQ_LEN); +} + +void XPUMmdnnBidEmbGrnnAttCompute2::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->As(); + + auto* xpu_ctx = ctx.GetRawContext(); + + int batch = param.id0->lod()[0].size() - 1; + id_.Update(param.id0, param.id1); + compound_.Infer(ctx.GetRawContext(), + batch, + id_, + param.grnn_fw_pool_out, + param.grnn_rv_pool_out, + param.att_pool_out, + param.concat_3in1_out, + param.emb_fw_out, + reinterpret_cast( + reinterpret_cast(xpu_ctx->workspace_l3_ptr) + + xpu_ctx->used_l3_size), + xpu_ctx->workspace_l3_size - xpu_ctx->used_l3_size); + + int num = param.id0->numel(); + int embed_dim = param.emb_tbl->dims()[1]; + + // TODO(miaotianxiang): + int r = xdnn::embedding( + ctx.GetRawContext(), /* context */ + num, /* num */ + param.id0->data(), /* indices */ + embed_dim, /* embed_dim */ + param.emb_tbl->data(), /* table */ + param.emb0_out->mutable_data(TARGET(kXPU)), /* top */ + 128000 /* padding_idx */); + CHECK_EQ(r, 0); +} + class XPUMmdnnBidEmbAttCompute : public KernelLite { public: @@ -1169,20 +1279,18 @@ class XPUMmdnnBidEmbAttCompute private: MMDNNIdInfo id_; MMDNNEmbAtt compound_; - int upper_bound_batch_ = 40; - int upper_bound_seqlen_ = 512; }; void XPUMmdnnBidEmbAttCompute::PrepareForRun() { auto& param = this->Param(); - id_.Init(upper_bound_batch_, upper_bound_seqlen_); + id_.Init(XPU_MAX_LOD_SIZE, XPU_MAX_LOD_SEQ_LEN); compound_.Init(param.emb_tbl, param.att_fc_w, param.att_fc_w_max, param.att_fc_b, - upper_bound_batch_, - upper_bound_seqlen_); + XPU_MAX_LOD_SIZE, + XPU_MAX_LOD_SEQ_LEN); } void XPUMmdnnBidEmbAttCompute::Run() { @@ -1215,8 +1323,6 @@ class XPUMmdnnMatchConvTopkCompute private: MMDNNMatchConvTopk compound_; - int upper_bound_batch_ = 40; - int upper_bound_seqlen_ = 512; }; void XPUMmdnnMatchConvTopkCompute::PrepareForRun() { @@ -1228,8 +1334,9 @@ void XPUMmdnnMatchConvTopkCompute::PrepareForRun() { param.conv_w_max, param.dim_t, param.input_w->dims()[0], - upper_bound_batch_, - upper_bound_seqlen_, + param.output_channel, + XPU_MAX_LOD_SIZE, + XPU_MAX_LOD_SEQ_LEN, param.topks); } @@ -1261,14 +1368,12 @@ class XPUMmdnnMergeAllCompute private: MMDNNIdInfo id_; MMDNNMergeAll compound_; - int upper_bound_batch_ = 40; - int upper_bound_seqlen_ = 512; }; void XPUMmdnnMergeAllCompute::PrepareForRun() { auto& param = this->Param(); - id_.Init(upper_bound_batch_, upper_bound_seqlen_); + id_.Init(XPU_MAX_LOD_SIZE, XPU_MAX_LOD_SEQ_LEN); compound_.Init(param.grnn_fw_wh, param.grnn_fw_wh_maxs, param.grnn_fw_wi, @@ -1286,8 +1391,8 @@ void XPUMmdnnMergeAllCompute::PrepareForRun() { param.fc2_w, param.fc2_w_max, param.fc2_b, - upper_bound_batch_, - upper_bound_seqlen_); + XPU_MAX_LOD_SIZE, + XPU_MAX_LOD_SEQ_LEN); } void XPUMmdnnMergeAllCompute::Run() { @@ -1296,10 +1401,10 @@ void XPUMmdnnMergeAllCompute::Run() { auto* xpu_ctx = ctx.GetRawContext(); - id_.Update(param.concat_2in1_x[0], param.concat_2in1_x[1]); + id_.Update(param.concat_topk_x[0], param.concat_topk_x[1]); compound_.Infer(ctx.GetRawContext(), id_, - param.concat_2in1_x, + param.concat_topk_x, param.concat_7in1_x, param.out, reinterpret_cast( @@ -1335,6 +1440,29 @@ REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_grnn_att, .BindOutput("emb_fw_out", {LiteType::GetTensorTy(TARGET(kXPU))}) .Finalize(); +REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_grnn_att2, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::XPUMmdnnBidEmbGrnnAttCompute2, + def) + .BindInput("id0", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))}) + .BindInput("id1", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))}) + .BindInput("emb_tbl", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("grnn_fw_wh", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("grnn_fw_wi", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("grnn_rv_wh", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("grnn_rv_wi", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("att_fc_w", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("att_fc_b", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("emb0_out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("grnn_fw_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("grnn_rv_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("att_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("concat_3in1_out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("emb_fw_out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); + REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_att, kXPU, kFloat, @@ -1371,7 +1499,7 @@ REGISTER_LITE_KERNEL(__xpu__mmdnn_merge_all, paddle::lite::kernels::xpu::XPUMmdnnMergeAllCompute, def) .BindInput("concat_7in1_x", {LiteType::GetTensorTy(TARGET(kXPU))}) - .BindInput("concat_2in1_x", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("concat_topk_x", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindInput("grnn_fw_wh", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindInput("grnn_fw_wi", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindInput("grnn_rv_wh", {LiteType::GetTensorTy(TARGET(kXPU))}) diff --git a/lite/kernels/xpu/__xpu__multi_encoder_compute.h b/lite/kernels/xpu/__xpu__multi_encoder_compute.h index 71db4e6f44f9c36e4acdaf0a440463a61f4e3099..dbc2d785d42ad29dc1cfbe36f744b71662e48315 100644 --- a/lite/kernels/xpu/__xpu__multi_encoder_compute.h +++ b/lite/kernels/xpu/__xpu__multi_encoder_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include #include "lite/backends/xpu/xpu_header_sitter.h" #include "lite/core/kernel.h" diff --git a/lite/kernels/xpu/__xpu__resnet50_compute.h b/lite/kernels/xpu/__xpu__resnet50_compute.h index 3d42f8b6f26edf615dba165b553b633673a4ae66..7ce8b1192ea9e85d83ddbeddc374378692866aa6 100644 --- a/lite/kernels/xpu/__xpu__resnet50_compute.h +++ b/lite/kernels/xpu/__xpu__resnet50_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include #include "lite/backends/xpu/xpu_header_sitter.h" #include "lite/core/kernel.h" diff --git a/lite/kernels/xpu/__xpu__search_attention_compute.cc b/lite/kernels/xpu/__xpu__search_attention_compute.cc index 515be8935637d89d58db830f96f2ea439e7d7e68..7f02f566dfb01f2d8a57302e714f4f2cb3d4b786 100644 --- a/lite/kernels/xpu/__xpu__search_attention_compute.cc +++ b/lite/kernels/xpu/__xpu__search_attention_compute.cc @@ -22,16 +22,19 @@ namespace kernels { namespace xpu { void XPUMmdnnSearchAttentionCompute::PrepareForRun() { - offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); - pad_begin_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); - w_max_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(8 * sizeof(float)); + offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad( + XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */); + pad_begin_xpu_guard_ = TargetWrapperXPU::MallocScratchPad( + XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */); + w_max_xpu_guard_ = + TargetWrapperXPU::MallocScratchPad(8 * sizeof(float), false /* use_l3 */); buffer_at_l3_guard_ = TargetWrapperXPU::MallocScratchPad( 5 * L3_SLOT_SIZE * sizeof(float), false /* use_l3 */); buffer_at_gm_guard_ = TargetWrapperXPU::MallocScratchPad( 5 * GM_SLOT_SIZE * sizeof(float), false /* use_l3 */); - offset_cpu.reset(new int[64]); - pad_begin_cpu.reset(new int[64]); + offset_cpu.reset(new int[XPU_MAX_LOD_SIZE]); + pad_begin_cpu.reset(new int[XPU_MAX_LOD_SIZE]); } void XPUMmdnnSearchAttentionCompute::Run() { @@ -72,18 +75,18 @@ void XPUMmdnnSearchAttentionCompute::Run() { } offset_cpu[batch] = offset[batch]; - xpu_memcpy(offset_xpu_guard_->addr_, - offset_cpu.get(), - offset.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); - xpu_memcpy(pad_begin_xpu_guard_->addr_, - pad_begin_cpu.get(), - batch * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); - xpu_memcpy(w_max_xpu_guard_->addr_, - maxs_cpu, - 8 * sizeof(float), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy(offset_xpu_guard_->addr_, + offset_cpu.get(), + offset.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + XPU_CALL(xpu_memcpy(pad_begin_xpu_guard_->addr_, + pad_begin_cpu.get(), + batch * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + XPU_CALL(xpu_memcpy(w_max_xpu_guard_->addr_, + maxs_cpu, + 8 * sizeof(float), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); int* offset_xpu = reinterpret_cast(offset_xpu_guard_->addr_); int* pad_begin_xpu = reinterpret_cast(pad_begin_xpu_guard_->addr_); @@ -115,90 +118,99 @@ void XPUMmdnnSearchAttentionCompute::Run() { } const auto* bottom_data = X->data(); - xdnn::search_sequence_pad_depad(ctx.GetRawContext(), - const_cast(bottom_data), - group_padding_output, - offset_xpu, - max_seq, - batch, - dim1, - 0); // is_depad = 0 + int r = 0; + r = xdnn::search_sequence_pad_depad(ctx.GetRawContext(), + const_cast(bottom_data), + group_padding_output, + offset_xpu, + max_seq, + batch, + dim1, + 0); // is_depad = 0 + CHECK_EQ(r, 0); // do-findmax - xdnn::findmax(ctx.GetRawContext(), - group_padding_output, - batch * max_seq * dim1, - maxs_xpu); - xdnn::gemm_int16_maxptr( - ctx.GetRawContext(), - false, - true, // trans_a, trans_b - batch * max_seq, - dim1, - dim1, // m, n, k - 1.0f, - group_padding_output, - dim1, // alpha, data_a, lda - w_data, - dim1, - 0.0f, // data_b, ldb, beta - seq_fc_output, - dim1, - b_data, // data_c, ldc, bias - xdnn::Activation_t::LINEAR, - maxs_xpu, - maxs_xpu + 4, - nullptr); // max_a, max_b, max_c - xdnn::search_aligned_mat_mul(ctx.GetRawContext(), - 0, - 1, - batch, - max_seq, - max_seq, - dim1, - alpha0, - group_padding_output, - dim1, - seq_fc_output, - dim1, - batchgemm0_output, - max_seq); - xdnn::search_pad_mask(ctx.GetRawContext(), - batchgemm0_output, - attention_output, - pad_begin_xpu, - batch, - max_seq, - max_seq, - batch, - mask); - xdnn::softmax2d_forward(ctx.GetRawContext(), - attention_output, - seq_softmax_output, - batch * max_seq, - max_seq, - true); - xdnn::search_aligned_mat_mul(ctx.GetRawContext(), - 0, - 0, - batch, - max_seq, - dim1, - max_seq, - alpha1, - seq_softmax_output, - max_seq, - group_padding_output, - dim1, - batchgemm1_output, - dim1); - xdnn::search_sequence_pad_depad(ctx.GetRawContext(), - top_data, - batchgemm1_output, - offset_xpu, - max_seq, - batch, - dim1, - 1); // is_depad = 1 + r = xdnn::findmax(ctx.GetRawContext(), + group_padding_output, + batch * max_seq * dim1, + maxs_xpu); + CHECK_EQ(r, 0); + r = xdnn::gemm_int16_maxptr( + ctx.GetRawContext(), /* ctx */ + false, /* trans_a */ + true, /* trans_b */ + batch * max_seq, /* m */ + dim1, /* n */ + dim1, /* k */ + 1.0f, /* alpha */ + group_padding_output, /* data_a */ + dim1, /* lda */ + w_data, /* data_b */ + dim1, /* ldb */ + 0.0f, /* beta */ + seq_fc_output, /* data_c */ + dim1, /* ldc */ + b_data, /* bias */ + xdnn::Activation_t::LINEAR, /* act */ + maxs_xpu, /* max_a */ + maxs_xpu + 4, /* max_b */ + nullptr /* max_c */); + CHECK_EQ(r, 0); + r = xdnn::search_aligned_mat_mul(ctx.GetRawContext(), + 0, + 1, + batch, + max_seq, + max_seq, + dim1, + alpha0, + group_padding_output, + dim1, + seq_fc_output, + dim1, + batchgemm0_output, + max_seq); + CHECK_EQ(r, 0); + r = xdnn::search_pad_mask(ctx.GetRawContext(), + batchgemm0_output, + attention_output, + pad_begin_xpu, + batch, + max_seq, + max_seq, + batch, + mask); + CHECK_EQ(r, 0); + r = xdnn::softmax2d_forward(ctx.GetRawContext(), + attention_output, + seq_softmax_output, + batch * max_seq, + max_seq, + true); + CHECK_EQ(r, 0); + r = xdnn::search_aligned_mat_mul(ctx.GetRawContext(), + 0, + 0, + batch, + max_seq, + dim1, + max_seq, + alpha1, + seq_softmax_output, + max_seq, + group_padding_output, + dim1, + batchgemm1_output, + dim1); + CHECK_EQ(r, 0); + r = xdnn::search_sequence_pad_depad(ctx.GetRawContext(), + top_data, + batchgemm1_output, + offset_xpu, + max_seq, + batch, + dim1, + 1); // is_depad = 1 + CHECK_EQ(r, 0); } } // namespace xpu diff --git a/lite/kernels/xpu/activation_compute.h b/lite/kernels/xpu/activation_compute.h index e440bde4146a88929c52c20ff1038eb35be91d38..f2ad667886ac33191687b70aa7548050461545e7 100644 --- a/lite/kernels/xpu/activation_compute.h +++ b/lite/kernels/xpu/activation_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include "lite/core/kernel.h" namespace paddle { diff --git a/lite/kernels/xpu/batch_norm_compute.h b/lite/kernels/xpu/batch_norm_compute.h index 7b428476b96ca3b2b60c66df28b7f82e8f57bebc..f5244574cebab6b10bbd81af9c8303ffec9f0965 100644 --- a/lite/kernels/xpu/batch_norm_compute.h +++ b/lite/kernels/xpu/batch_norm_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include "lite/core/kernel.h" namespace paddle { diff --git a/lite/kernels/xpu/cast_compute.h b/lite/kernels/xpu/cast_compute.h index 8992c29732630a5bf0d9c092461569234257e3a9..efd4cbae8d2d708b25729f04f36bc22d1d909e11 100644 --- a/lite/kernels/xpu/cast_compute.h +++ b/lite/kernels/xpu/cast_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include "lite/core/kernel.h" namespace paddle { diff --git a/lite/kernels/xpu/conv_compute.h b/lite/kernels/xpu/conv_compute.h index b7631ce4e5773afe7cdd797a245c806b51d25c56..76159444c1861fad14b6ac4f0d32da626b3a8802 100644 --- a/lite/kernels/xpu/conv_compute.h +++ b/lite/kernels/xpu/conv_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include "lite/core/kernel.h" namespace paddle { diff --git a/lite/kernels/xpu/dropout_compute.h b/lite/kernels/xpu/dropout_compute.h index 0eaafb4f5555a163623402fee82d50bfa095b0b3..360450df537a68b9412d21db4e06dc74d6071ca6 100644 --- a/lite/kernels/xpu/dropout_compute.h +++ b/lite/kernels/xpu/dropout_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include "lite/core/kernel.h" namespace paddle { diff --git a/lite/kernels/xpu/elementwise_compute.h b/lite/kernels/xpu/elementwise_compute.h index 863ee3c643f9c431dacd057e251941914b1dd1c5..d910b9293e74428c426d9505245bc5958fc9df3a 100644 --- a/lite/kernels/xpu/elementwise_compute.h +++ b/lite/kernels/xpu/elementwise_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include "lite/core/kernel.h" namespace paddle { diff --git a/lite/kernels/xpu/layer_norm_compute.h b/lite/kernels/xpu/layer_norm_compute.h index 5d2df37795811ef8027e12b25139f2b7091cceed..9eeb5924c512fcfbf8825a9ff775378dfe4d6d4c 100644 --- a/lite/kernels/xpu/layer_norm_compute.h +++ b/lite/kernels/xpu/layer_norm_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include "lite/core/kernel.h" namespace paddle { diff --git a/lite/kernels/xpu/lookup_table_compute.cc b/lite/kernels/xpu/lookup_table_compute.cc index 568d303adefaa06bb8665b4cc92d4a949419d587..4256687fa8c17c7fe36e91ff727d52eb1047646f 100644 --- a/lite/kernels/xpu/lookup_table_compute.cc +++ b/lite/kernels/xpu/lookup_table_compute.cc @@ -29,12 +29,13 @@ void LookupTableCompute::Run() { int embed_dim = param.W->dims()[1]; int r = xdnn::embedding( - ctx.GetRawContext(), /* context */ - num, /* num */ - param.Ids->data(), /* indices */ - embed_dim, /* embed_dim */ - param.W->data(), /* table */ - param.Out->mutable_data(TARGET(kXPU)) /* top */); + ctx.GetRawContext(), /* context */ + num, /* num */ + param.Ids->data(), /* indices */ + embed_dim, /* embed_dim */ + param.W->data(), /* table */ + param.Out->mutable_data(TARGET(kXPU)), /* top */ + param.padding_idx /* padding_idx */); CHECK_EQ(r, 0); } diff --git a/lite/kernels/xpu/lookup_table_compute.h b/lite/kernels/xpu/lookup_table_compute.h index 2ba1afc869cf9c3a49ab1ad29c66c6c89ba87d19..7a43f5244e5d514a1644aac0437951af35bb7767 100644 --- a/lite/kernels/xpu/lookup_table_compute.h +++ b/lite/kernels/xpu/lookup_table_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include "lite/core/kernel.h" namespace paddle { diff --git a/lite/kernels/xpu/match_matrix_tensor_compute.cc b/lite/kernels/xpu/match_matrix_tensor_compute.cc index 3c4e896d23add6df99a7b66a830dc526dc808e95..c3ee547ccce56cd16401e4aca465e64d99a26185 100644 --- a/lite/kernels/xpu/match_matrix_tensor_compute.cc +++ b/lite/kernels/xpu/match_matrix_tensor_compute.cc @@ -23,12 +23,15 @@ namespace kernels { namespace xpu { void MatchMatrixTensorCompute::PrepareForRun() { - wx_max_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); - offset_l_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); - offset_r_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); - - offset_l_cpu.reset(new int[64]); - offset_r_cpu.reset(new int[64]); + wx_max_xpu_guard_ = TargetWrapperXPU::MallocScratchPad( + XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */); + offset_l_xpu_guard_ = TargetWrapperXPU::MallocScratchPad( + XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */); + offset_r_xpu_guard_ = TargetWrapperXPU::MallocScratchPad( + XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */); + + offset_l_cpu.reset(new int[XPU_MAX_LOD_SIZE]); + offset_r_cpu.reset(new int[XPU_MAX_LOD_SIZE]); } void MatchMatrixTensorCompute::Run() { @@ -76,25 +79,25 @@ void MatchMatrixTensorCompute::Run() { int* offset_r_xpu = reinterpret_cast(offset_r_xpu_guard_->addr_); int r = xdnn::gemm_int16_tmp_api( - ctx.GetRawContext(), /* ctx */ - false, - false, /* trans_a, trans_b */ - x->dims()[0], - dim_t * dim_in, - dim_in, /* m, n, k */ - 1.0f, - bottom_l_data, - dim_in, /* alpha, data_a, lda */ - w_data, - dim_t * dim_in, - 0.0f, /* data_b, ldb, beta */ - bottom_l_trans_data, - dim_t * dim_in, /* data_c, ldc */ - nullptr, /* bias */ - xdnn::Activation_t::LINEAR, - 0.0f, - w_max, - wx_max /* max_a, max_b, max_c */); + ctx.GetRawContext(), /* ctx */ + false, /* trans_a */ + false, /* trans_b */ + x->dims()[0], /* m */ + dim_t * dim_in, /* n */ + dim_in, /* k */ + 1.0f, /* alpha */ + bottom_l_data, /* data_a */ + dim_in, /* lda */ + w_data, /* data_b */ + dim_t * dim_in, /* ldb */ + 0.0f, /* beta */ + bottom_l_trans_data, /* data_c */ + dim_t * dim_in, /* ldc */ + nullptr, /* bias */ + xdnn::Activation_t::LINEAR, /* act */ + 0.0f, /* max_a */ + w_max, /* max_b */ + wx_max /* max_c */); CHECK_EQ(r, 0); int max_width = 0; @@ -110,14 +113,14 @@ void MatchMatrixTensorCompute::Run() { max_width = offset_r_cpu[i] - offset_r_cpu[i - 1]; } } - xpu_memcpy(offset_l_xpu, - offset_l_cpu.get(), - offset_l.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); - xpu_memcpy(offset_r_xpu, - offset_r_cpu.get(), - offset_r.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy(offset_l_xpu, + offset_l_cpu.get(), + offset_l.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + XPU_CALL(xpu_memcpy(offset_r_xpu, + offset_r_cpu.get(), + offset_r.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); r = xdnn::match_matrix_tensor(ctx.GetRawContext(), batch_size, diff --git a/lite/kernels/xpu/matmul_compute.h b/lite/kernels/xpu/matmul_compute.h index aca3cbc603eff490ae19fd2546352adca3c1a7cf..0fef2086e294fa5cd79e49adeb6b136f484a1efd 100644 --- a/lite/kernels/xpu/matmul_compute.h +++ b/lite/kernels/xpu/matmul_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include "lite/core/kernel.h" namespace paddle { diff --git a/lite/kernels/xpu/mul_compute.h b/lite/kernels/xpu/mul_compute.h index bb2778c0e73189b11135395b42655e0250bbfd0a..3c91384b726a4d43c6a38e96d143657c12dadd8a 100644 --- a/lite/kernels/xpu/mul_compute.h +++ b/lite/kernels/xpu/mul_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include "lite/core/kernel.h" namespace paddle { diff --git a/lite/kernels/xpu/pool_compute.h b/lite/kernels/xpu/pool_compute.h index 5648554c41c76396184b7dc536f8c8628cbf23e4..39e14f04a8c41bc057ac5733d881ba713c0883b2 100644 --- a/lite/kernels/xpu/pool_compute.h +++ b/lite/kernels/xpu/pool_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include "lite/core/kernel.h" namespace paddle { diff --git a/lite/kernels/xpu/scale_compute.h b/lite/kernels/xpu/scale_compute.h index 6989b0f0f31e54a63dac2f7c2090dc676e31acfb..5a84fe26a0d409dcd979ca7c26128775a4f64df2 100644 --- a/lite/kernels/xpu/scale_compute.h +++ b/lite/kernels/xpu/scale_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include "lite/core/kernel.h" namespace paddle { diff --git a/lite/kernels/xpu/search_fc_compute.cc b/lite/kernels/xpu/search_fc_compute.cc index 79f4c2d0d809ea9848fb383863d0f9dd2ec5a2ae..52a9999b468564d81288ce494f575a8d1d46e4fc 100644 --- a/lite/kernels/xpu/search_fc_compute.cc +++ b/lite/kernels/xpu/search_fc_compute.cc @@ -23,7 +23,8 @@ namespace kernels { namespace xpu { void SearchFcCompute::PrepareForRun() { - maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(float)); + maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad( + XPU_MAX_LOD_SIZE * sizeof(float), false /* use_l3 */); } void SearchFcCompute::Run() { @@ -59,34 +60,34 @@ void SearchFcCompute::Run() { float* maxs_xpu = reinterpret_cast(maxs_xpu_guard_->addr_); float maxs_cpu[8] = {0.0f, 0.0f, 0.0f, 0.0f, w_max, 0.0f, 0.0f, 0.0f}; - xpu_memcpy(maxs_xpu, - &maxs_cpu[0], - 8 * sizeof(float), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy(maxs_xpu, + &maxs_cpu[0], + 8 * sizeof(float), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); int r = xdnn::findmax( ctx.GetRawContext(), bottom_data, batch * _in, maxs_xpu); CHECK_EQ(r, 0); r = xdnn::gemm_int16_maxptr( ctx.GetRawContext(), /* ctx */ - false, - true, /*trans_a, trans_b*/ - batch, - _out, - _in, /*m, n, k*/ - 1.0f, - bottom_data, - _in, /*alpha, data_a, lda*/ - weights, - _in, - 0.0f, /*data_b, ldb, beta*/ - top_data, - _out, - bias_data, /* data_c, ldc, bias*/ - act, - maxs_xpu, - maxs_xpu + 4, - nullptr /*act, max_a, max_b, max_c*/); + false, /* trans_a */ + true, /* trans_b */ + batch, /* m */ + _out, /* n */ + _in, /* k */ + 1.0f, /* alpha */ + bottom_data, /* data_a */ + _in, /* lda */ + weights, /* data_b */ + _in, /* ldb */ + 0.0f, /* beta */ + top_data, /* data_c */ + _out, /* ldc */ + bias_data, /* bias */ + act, /* act */ + maxs_xpu, /* max_a */ + maxs_xpu + 4, /* max_b */ + nullptr /* max_c */); CHECK_EQ(r, 0); } diff --git a/lite/kernels/xpu/search_grnn_compute.cc b/lite/kernels/xpu/search_grnn_compute.cc index 1c19f58da1b5deaa3d74791561494f13b681cf3a..d4e2e4a9969149b0d2f7f2b75c195d1b3a5fda5c 100644 --- a/lite/kernels/xpu/search_grnn_compute.cc +++ b/lite/kernels/xpu/search_grnn_compute.cc @@ -24,13 +24,16 @@ namespace kernels { namespace xpu { void SearchGrnnCompute::PrepareForRun() { - offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); - new_offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(256 * sizeof(int)); - maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(16 * sizeof(float)); + offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad( + XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */); + new_offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad( + XPU_MAX_LOD_SEQ_LEN * sizeof(int), false /* use_l3 */); + maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(16 * sizeof(float), + false /* use_l3 */); - idx_sorted_by_width_data_cpu.reset(new int[64]); - offset_cpu.reset(new int[64]); - new_offset_cpu.reset(new int[256]); + idx_sorted_by_width_data_cpu.reset(new int[XPU_MAX_LOD_SIZE]); + offset_cpu.reset(new int[XPU_MAX_LOD_SIZE]); + new_offset_cpu.reset(new int[XPU_MAX_LOD_SEQ_LEN]); } void SearchGrnnCompute::prepare_layout(const operators::SearchGrnnParam& param, @@ -96,10 +99,10 @@ void SearchGrnnCompute::prepare_layout(const operators::SearchGrnnParam& param, layout_input->Resize({dim0, dim1}); } - xpu_memcpy(idx_sorted_by_width->mutable_data(TARGET(kXPU)), - idx_sorted_by_width_data_cpu.get(), - idx_sorted_by_width->numel() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy(idx_sorted_by_width->mutable_data(TARGET(kXPU)), + idx_sorted_by_width_data_cpu.get(), + idx_sorted_by_width->numel() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); } void SearchGrnnCompute::Run() { @@ -156,14 +159,14 @@ void SearchGrnnCompute::Run() { for (size_t i = 0; i < new_offset.size(); ++i) { new_offset_cpu[i] = new_offset[i]; } - xpu_memcpy(offset_xpu, - offset_cpu.get(), - offset.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); - xpu_memcpy(new_offset_xpu, - new_offset_cpu.get(), - new_offset.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy(offset_xpu, + offset_cpu.get(), + offset.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + XPU_CALL(xpu_memcpy(new_offset_xpu, + new_offset_cpu.get(), + new_offset.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); int r = xdnn::search_seq2batch(ctx.GetRawContext(), batch, @@ -200,10 +203,10 @@ void SearchGrnnCompute::Run() { 0.0f, 0.0f, 0.0f}; - xpu_memcpy(maxs_xpu, - maxs_cpu, - 16 * sizeof(float), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy(maxs_xpu, + maxs_cpu, + 16 * sizeof(float), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); r = xdnn::findmax( ctx.GetRawContext(), new_emb, cap_l * cap_e, maxs_xpu); CHECK_EQ(r, 0); diff --git a/lite/kernels/xpu/sequence_arithmetic_compute.cc b/lite/kernels/xpu/sequence_arithmetic_compute.cc index 226c615dba57ae381ed2457e588c5df32f25e04b..e1b9866123395b2d7867154c3b398adae670ed97 100644 --- a/lite/kernels/xpu/sequence_arithmetic_compute.cc +++ b/lite/kernels/xpu/sequence_arithmetic_compute.cc @@ -37,44 +37,54 @@ void SequenceArithmeticCompute::Run() { const auto* bottom_data1 = bottom1->data(); auto* top_data = top->mutable_data(TARGET(kXPU)); + int r = 0; switch (op_type) { case 1: // addition: top[0] = bottom[0] + bottom[1] if (len1 > len2) { - xdnn::elementwise_add( + r = xdnn::elementwise_add( ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2); - xdnn::memcpy_device(ctx.GetRawContext(), - &top_data[len2], - &bottom_data0[len2], - (len1 - len2) * sizeof(float)); + CHECK_EQ(r, 0); + r = xdnn::memcpy_device(ctx.GetRawContext(), + &top_data[len2], + &bottom_data0[len2], + (len1 - len2) * sizeof(float)); + CHECK_EQ(r, 0); } else { - xdnn::elementwise_add( + r = xdnn::elementwise_add( ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1); + CHECK_EQ(r, 0); } break; case 2: // substraction: top[0] = bottom[0] - bottom[1] if (len1 > len2) { - xdnn::elementwise_sub( + r = xdnn::elementwise_sub( ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2); - xdnn::memcpy_device(ctx.GetRawContext(), - &top_data[len2], - &bottom_data0[len2], - (len1 - len2) * sizeof(float)); + CHECK_EQ(r, 0); + r = xdnn::memcpy_device(ctx.GetRawContext(), + &top_data[len2], + &bottom_data0[len2], + (len1 - len2) * sizeof(float)); + CHECK_EQ(r, 0); } else { - xdnn::elementwise_sub( + r = xdnn::elementwise_sub( ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1); + CHECK_EQ(r, 0); } break; case 3: // multiplication: top[0] = bottom[0] * bottom[1] if (len1 > len2) { - xdnn::elementwise_mul( + r = xdnn::elementwise_mul( ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2); - xdnn::memcpy_device(ctx.GetRawContext(), - &top_data[len2], - &bottom_data0[len2], - (len1 - len2) * sizeof(float)); + CHECK_EQ(r, 0); + r = xdnn::memcpy_device(ctx.GetRawContext(), + &top_data[len2], + &bottom_data0[len2], + (len1 - len2) * sizeof(float)); + CHECK_EQ(r, 0); } else { - xdnn::elementwise_mul( + r = xdnn::elementwise_mul( ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1); + CHECK_EQ(r, 0); } break; default: diff --git a/lite/kernels/xpu/sequence_concat_compute.cc b/lite/kernels/xpu/sequence_concat_compute.cc index fd7f5999a6ccb18efbcb0e96b50f2b31884fc21c..349fdbad2a89300703c820588b4647bfba77ece5 100644 --- a/lite/kernels/xpu/sequence_concat_compute.cc +++ b/lite/kernels/xpu/sequence_concat_compute.cc @@ -23,11 +23,13 @@ namespace kernels { namespace xpu { void SequenceConcatCompute::PrepareForRun() { - lod0_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); - lod1_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); + lod0_xpu_guard_ = TargetWrapperXPU::MallocScratchPad( + XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */); + lod1_xpu_guard_ = TargetWrapperXPU::MallocScratchPad( + XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */); - lod0_cpu.reset(new int[64]); - lod1_cpu.reset(new int[64]); + lod0_cpu.reset(new int[XPU_MAX_LOD_SIZE]); + lod1_cpu.reset(new int[XPU_MAX_LOD_SIZE]); } template @@ -106,14 +108,14 @@ void SequenceConcatCompute::Run() { for (int i = 0; i < lod1.size(); ++i) { lod1_cpu[i] = lod1[i]; } - xpu_memcpy(lod0_xpu, - lod0_cpu.get(), - lod0.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); - xpu_memcpy(lod1_xpu, - lod1_cpu.get(), - lod1.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy(lod0_xpu, + lod0_cpu.get(), + lod0.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + XPU_CALL(xpu_memcpy(lod1_xpu, + lod1_cpu.get(), + lod1.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); int r = xdnn::sequence_concat(ctx.GetRawContext(), xs[0]->data(), diff --git a/lite/kernels/xpu/sequence_pool_compute.cc b/lite/kernels/xpu/sequence_pool_compute.cc index 81d9b5873c3c42afe94acdd8eb5a292326b7a7b6..f8e71639b7f4c67f7e60103a42766a4d32026bc1 100644 --- a/lite/kernels/xpu/sequence_pool_compute.cc +++ b/lite/kernels/xpu/sequence_pool_compute.cc @@ -23,8 +23,9 @@ namespace kernels { namespace xpu { void XPUSequencePoolCompute::PrepareForRun() { - lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); - lod_cpu.reset(new int[64]); + lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad( + XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */); + lod_cpu.reset(new int[XPU_MAX_LOD_SIZE]); } void XPUSequencePoolCompute::Run() { @@ -55,10 +56,10 @@ void XPUSequencePoolCompute::Run() { lod_cpu[i] = in_lod[i]; } int* lod_xpu = reinterpret_cast(lod_xpu_guard_->addr_); - xpu_memcpy(lod_xpu, - lod_cpu.get(), - in_lod.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy(lod_xpu, + lod_cpu.get(), + in_lod.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); int r = xdnn::sequence_pooling_forward(ctx.GetRawContext(), diff --git a/lite/kernels/xpu/sequence_reverse_compute.cc b/lite/kernels/xpu/sequence_reverse_compute.cc index 11e4b80570c19fa90e7846d18a88f966f9a003b7..bb3f37890b644a660c594fb0fd6eea332b90b8d6 100644 --- a/lite/kernels/xpu/sequence_reverse_compute.cc +++ b/lite/kernels/xpu/sequence_reverse_compute.cc @@ -23,8 +23,9 @@ namespace xpu { template void SequenceReverseCompute::PrepareForRun() { - lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); - lod_cpu.reset(new int[64]); + lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad( + XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */); + lod_cpu.reset(new int[XPU_MAX_LOD_SIZE]); } template @@ -58,10 +59,10 @@ void SequenceReverseCompute::Run() { lod_cpu[i] = lod[i]; } int* lod_xpu = reinterpret_cast(lod_xpu_guard_->addr_); - xpu_memcpy(lod_xpu, - lod_cpu.get(), - lod.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy(lod_xpu, + lod_cpu.get(), + lod.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); int r = xdnn::sequence_reverse(ctx.GetRawContext(), batch_size, diff --git a/lite/kernels/xpu/sequence_topk_avg_pooling_compute.cc b/lite/kernels/xpu/sequence_topk_avg_pooling_compute.cc index 54c74211f9738995a8191c77e879a85762d71b3b..4e8485e2999b29dfb487d0c7c632fcfa7a9a3d00 100644 --- a/lite/kernels/xpu/sequence_topk_avg_pooling_compute.cc +++ b/lite/kernels/xpu/sequence_topk_avg_pooling_compute.cc @@ -23,10 +23,11 @@ namespace kernels { namespace xpu { void SequenceTopkAvgPoolingCompute::PrepareForRun() { - lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(256 * sizeof(int)); - in_lod_cpu.reset(new int[64]); - row_lod_cpu.reset(new int[64]); - col_lod_cpu.reset(new int[64]); + lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad( + 4 * XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */); + in_lod_cpu.reset(new int[XPU_MAX_LOD_SIZE]); + row_lod_cpu.reset(new int[XPU_MAX_LOD_SIZE]); + col_lod_cpu.reset(new int[XPU_MAX_LOD_SIZE]); } void SequenceTopkAvgPoolingCompute::Run() { @@ -81,22 +82,22 @@ void SequenceTopkAvgPoolingCompute::Run() { for (int i = 0; i < col_lod.size(); ++i) { col_lod_cpu[i] = col_lod[i]; } - xpu_memcpy(in_lod_xpu, - in_lod_cpu.get(), - in_lod.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); - xpu_memcpy(row_lod_xpu, - row_lod_cpu.get(), - row_lod.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); - xpu_memcpy(col_lod_xpu, - col_lod_cpu.get(), - col_lod.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); - xpu_memcpy(topks_xpu, - topks.data(), - topks.size() * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy(in_lod_xpu, + in_lod_cpu.get(), + in_lod.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + XPU_CALL(xpu_memcpy(row_lod_xpu, + row_lod_cpu.get(), + row_lod.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + XPU_CALL(xpu_memcpy(col_lod_xpu, + col_lod_cpu.get(), + col_lod.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + XPU_CALL(xpu_memcpy(topks_xpu, + topks.data(), + topks.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); int r = xdnn::sequence_topk_avg_pooling(ctx.GetRawContext(), in_data, diff --git a/lite/kernels/xpu/softmax_compute.h b/lite/kernels/xpu/softmax_compute.h index e807f38a2ea3c9645b78340ac4dc87d1984c40f7..a3d282588776b7d64bc856adf92685c8524af035 100644 --- a/lite/kernels/xpu/softmax_compute.h +++ b/lite/kernels/xpu/softmax_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include "lite/core/kernel.h" namespace paddle { diff --git a/lite/kernels/xpu/stack_compute.cc b/lite/kernels/xpu/stack_compute.cc index 90a6c70b49f39ce744f2a03eec41d79ddc768a19..156162923ceeb4abed466164b11672715f813fd7 100644 --- a/lite/kernels/xpu/stack_compute.cc +++ b/lite/kernels/xpu/stack_compute.cc @@ -25,9 +25,8 @@ void StackCompute::PrepareForRun() { auto& param = this->Param(); int n = param.X.size(); - void* x_ptr = nullptr; - xpu_malloc(&x_ptr, n * 8 /* sizeof(__global__ float*) */); - x_ptr_guard_.reset(x_ptr); + x_ptr_guard_ = TargetWrapperXPU::MallocScratchPad( + n * 8 /* sizeof(__global__ float*) */, false /* use_l3 */); x_ptr_cpu_.reserve(n); } @@ -47,14 +46,15 @@ void StackCompute::Run() { for (int i = 0; i < n; ++i) { x_ptr_cpu_[i] = param.X[i]->data(); } - xpu_memcpy(x_ptr_guard_.get(), &x_ptr_cpu_[0], n * 8, XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy( + x_ptr_guard_->addr_, &x_ptr_cpu_[0], n * 8, XPU_HOST_TO_DEVICE)); int r = xdnn::stack_forward( ctx.GetRawContext(), /* context */ height, /* height */ width, /* width */ n, /* n */ - x_ptr_guard_.get(), /* x_ptr */ + x_ptr_guard_->addr_, /* x_ptr */ param.Out->mutable_data(TARGET(kXPU)) /* out */); CHECK_EQ(r, 0); } diff --git a/lite/kernels/xpu/stack_compute.h b/lite/kernels/xpu/stack_compute.h index 1ba1d92dc9479cfd00c5e154df7b5476ffd9976c..7618e2a147b862aee097a42b36721d520ad6012c 100644 --- a/lite/kernels/xpu/stack_compute.h +++ b/lite/kernels/xpu/stack_compute.h @@ -14,10 +14,9 @@ #pragma once -#include #include +#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard #include "lite/core/kernel.h" -#include "lite/kernels/xpu/utils.h" // XPUFreeDeleter namespace paddle { namespace lite { @@ -35,7 +34,7 @@ class StackCompute : public KernelLite { virtual ~StackCompute() = default; private: - std::unique_ptr x_ptr_guard_; + XPUScratchPadGuard x_ptr_guard_; std::vector x_ptr_cpu_; }; diff --git a/lite/kernels/xpu/subgraph_compute.cc b/lite/kernels/xpu/subgraph_compute.cc index 981922f8eacab57da4638e1fdcdd3df72465b379..ac301108386e2da43b2efc372b96531df8d55523 100644 --- a/lite/kernels/xpu/subgraph_compute.cc +++ b/lite/kernels/xpu/subgraph_compute.cc @@ -53,10 +53,11 @@ bool SubgraphEngine::BuildDeviceProgram() { // IR graph subgraph::xpu::Graph graph; const auto& bridges = subgraph::Registry::Instance(); - if (origin_program_.empty()) { + if (!origin_program_) { BuildOriginProgram(); } - for (auto& inst : origin_program_) { + const auto& insts = origin_program_->instructions(kRootBlockIdx); + for (auto& inst : insts) { auto op = const_cast(inst.op()); CHECK(op); op->CheckShape(); @@ -123,7 +124,7 @@ bool SubgraphEngine::BuildDeviceProgram() { auto node = graph.Get(device_inames_[i]); auto precision = node->precision(); auto layout = node->layout(); - origin_itensors_[i] = scope_->FindMutableTensor(device_inames_[i]); + origin_itensors_[i] = exec_scope_->FindMutableTensor(device_inames_[i]); CHECK(origin_itensors_[i]); origin_idims_[i] = origin_itensors_[i]->dims(); VLOG(3) << "[XPU] Inputs[" << i << "] name: " << device_inames_[i] @@ -147,7 +148,7 @@ bool SubgraphEngine::BuildDeviceProgram() { auto node = graph.Get(device_onames_[i]); auto precision = node->precision(); auto layout = node->layout(); - origin_otensors_[i] = scope_->FindMutableTensor(device_onames_[i]); + origin_otensors_[i] = exec_scope_->FindMutableTensor(device_onames_[i]); CHECK(origin_otensors_[i]); origin_odims_[i] = origin_otensors_[i]->dims(); VLOG(3) << "[XPU] Outputs[" << i << "] name: " << device_onames_[i] @@ -220,11 +221,11 @@ bool SubgraphEngine::LaunchDeviceProgram() { void SubgraphCompute::PrepareForRun() { auto& param = this->Param(); engine_.reset(new SubgraphEngine(ctx_.get(), - param.sub_block_idx, - param.sub_block_desc, + param.block_idx, + param.program_desc, + param.exec_scope, param.input_data_names, - param.output_data_names, - param.scope)); + param.output_data_names)); CHECK(engine_); } diff --git a/lite/kernels/xpu/subgraph_compute.h b/lite/kernels/xpu/subgraph_compute.h index f09a06a85d5382c72e9efb20cede8bea1922f2da..25ffa721572ce05b0652d56659f3db12903c589b 100644 --- a/lite/kernels/xpu/subgraph_compute.h +++ b/lite/kernels/xpu/subgraph_compute.h @@ -31,12 +31,16 @@ class SubgraphEngine : public subgraph::Engine { public: SubgraphEngine(KernelContext *ctx, int block_idx, - cpp::BlockDesc *block_desc, + const std::shared_ptr &program_desc, + Scope *exec_scope, const std::vector &input_names, - const std::vector &output_names, - Scope *scope) - : subgraph::Engine( - ctx, block_idx, block_desc, input_names, output_names, scope) {} + const std::vector &output_names) + : subgraph::Engine(ctx, + block_idx, + program_desc, + exec_scope, + input_names, + output_names) {} protected: bool PrepareWorkspaceForDeviceProgram() override; diff --git a/lite/kernels/xpu/var_conv_2d_compute.cc b/lite/kernels/xpu/var_conv_2d_compute.cc index b573c810922db98e901c9f9a1953116f3fdfc657..b73581951f46a5f3cdbaf64cf732b1909805d27d 100644 --- a/lite/kernels/xpu/var_conv_2d_compute.cc +++ b/lite/kernels/xpu/var_conv_2d_compute.cc @@ -23,10 +23,12 @@ namespace kernels { namespace xpu { void VarConv2DCompute::PrepareForRun() { - offset_x_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); - offset_y_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); - offset_x_cpu.reset(new int[64]); - offset_y_cpu.reset(new int[64]); + offset_x_xpu_guard_ = TargetWrapperXPU::MallocScratchPad( + XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */); + offset_y_xpu_guard_ = TargetWrapperXPU::MallocScratchPad( + XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */); + offset_x_cpu.reset(new int[XPU_MAX_LOD_SIZE]); + offset_y_cpu.reset(new int[XPU_MAX_LOD_SIZE]); } void VarConv2DCompute::Run() { @@ -94,14 +96,14 @@ void VarConv2DCompute::Run() { offset_x_cpu[i] = offset_x[i]; offset_y_cpu[i] = offset_y[i]; } - xpu_memcpy(offset_x_xpu, - offset_x_cpu.get(), - (batch + 1) * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); - xpu_memcpy(offset_y_xpu, - offset_y_cpu.get(), - (batch + 1) * sizeof(int), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); + XPU_CALL(xpu_memcpy(offset_x_xpu, + offset_x_cpu.get(), + (batch + 1) * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + XPU_CALL(xpu_memcpy(offset_y_xpu, + offset_y_cpu.get(), + (batch + 1) * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); int r = xdnn::search_varconv(ctx.GetRawContext(), batch, diff --git a/lite/model_parser/base/apis.h b/lite/model_parser/base/apis.h index 2ad6ff47ee17fcdfab335b3a6f87229811d971ae..fa3449017c902479a7f6ad37ef73b3a316f585cc 100644 --- a/lite/model_parser/base/apis.h +++ b/lite/model_parser/base/apis.h @@ -17,6 +17,7 @@ #include "lite/model_parser/base/block_desc.h" #include "lite/model_parser/base/op_desc.h" #include "lite/model_parser/base/program_desc.h" +#include "lite/model_parser/base/proto_desc.h" #include "lite/model_parser/base/traits.h" #include "lite/model_parser/base/var_desc.h" #include "lite/utils/all.h" diff --git a/lite/model_parser/base/proto_desc.h b/lite/model_parser/base/proto_desc.h new file mode 100644 index 0000000000000000000000000000000000000000..0f62ef6e43883fd41c509795d1e4f695fdbb8910 --- /dev/null +++ b/lite/model_parser/base/proto_desc.h @@ -0,0 +1,26 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace paddle { +namespace lite { + +// The Index of first Block in Program. also called root block. +constexpr int kRootBlockIdx = 0; +// The Parent Index of root Block, this block does not exist. +constexpr int kNoneBlockIdx = -1; + +} // namespace lite +} // namespace paddle diff --git a/lite/model_parser/base/vector_view.h b/lite/model_parser/base/vector_view.h index adec1933a2f40face415f610c9ccf2e9f275020c..e4149d9c5acae83472904a86c47659355972855e 100644 --- a/lite/model_parser/base/vector_view.h +++ b/lite/model_parser/base/vector_view.h @@ -57,21 +57,35 @@ class VectorView { public: typedef vector_view::VectorTraits Traits; explicit VectorView(typename Traits::vector_type const* cvec) { - CHECK(cvec); cvec_ = cvec; } typename Traits::subscript_return_type operator[](size_t i) const { return cvec_->operator[](i); } - typename Traits::const_iterator begin() const { return cvec_->begin(); } - typename Traits::const_iterator end() const { return cvec_->end(); } - size_t size() const { return cvec_->size(); } + typename Traits::const_iterator begin() const { + if (!cvec_) { + return typename Traits::const_iterator(); + } + return cvec_->begin(); + } + typename Traits::const_iterator end() const { + if (!cvec_) { + return typename Traits::const_iterator(); + } + return cvec_->end(); + } + size_t size() const { + if (!cvec_) { + return 0; + } + return cvec_->size(); + } operator std::vector() const { VLOG(5) << "Copying elements out of VectorView will damage performance."; std::vector tmp; - tmp.reserve(cvec_->size()); - for (auto val : *cvec_) { - tmp.push_back(val); + tmp.reserve(size()); + for (size_t i = 0; i < size(); ++i) { + tmp.push_back(cvec_->operator[](i)); } return tmp; } diff --git a/lite/model_parser/compatible_pb.cc b/lite/model_parser/compatible_pb.cc index b8db89230d56e22a361cc4972382d74b8d6f08fd..8bfeb419e51b01ae008959ac5af3e9752834b1ab 100644 --- a/lite/model_parser/compatible_pb.cc +++ b/lite/model_parser/compatible_pb.cc @@ -234,7 +234,7 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) { template <> \ void TransformBlockDescCppToAny(const cpp::T &cpp_desc, \ NT::T *any_desc) { \ - auto desc = cpp_desc; \ + const cpp::T &desc = cpp_desc; \ any_desc->SetIdx(desc.Idx()); \ any_desc->SetParentIdx(desc.ParentIdx()); \ any_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \ diff --git a/lite/model_parser/flatbuffers/io.cc b/lite/model_parser/flatbuffers/io.cc index 28fa32398cfe76075c1a429f9f1d348842465dfc..ef8e9afaefe94d72113299050f16077a09f6c6cf 100644 --- a/lite/model_parser/flatbuffers/io.cc +++ b/lite/model_parser/flatbuffers/io.cc @@ -15,20 +15,21 @@ #include "lite/model_parser/flatbuffers/io.h" #include #include +#include namespace paddle { namespace lite { namespace fbs { void LoadModel(const std::string& path, ProgramDesc* prog) { + CHECK(prog); FILE* file = fopen(path.c_str(), "rb"); fseek(file, 0, SEEK_END); - int64_t size = ftell(file); + int64_t length = ftell(file); rewind(file); - char* data = new char[size]; - size = fread(data, 1, size, file); + std::vector buf(length); + CHECK(fread(buf.data(), 1, length, file)); fclose(file); - std::unique_ptr buf(data); prog->Init(std::move(buf)); } diff --git a/lite/model_parser/flatbuffers/op_desc.h b/lite/model_parser/flatbuffers/op_desc.h index e133ffbc27dce1a8c00eed82cc6d4fca76a8564d..450aa49fa13b676b33bef8490c65061dc504431d 100644 --- a/lite/model_parser/flatbuffers/op_desc.h +++ b/lite/model_parser/flatbuffers/op_desc.h @@ -62,7 +62,7 @@ class OpDesc : public OpDescAPI { std::vector Output(const std::string& param) const override { const auto& var = desc_->outputs()->LookupByKey(param.c_str()); std::vector args_vec; - if (var->arguments()) { + if (var && var->arguments()) { args_vec.reserve(var->arguments()->size()); for (const auto& out : *var->arguments()) { args_vec.push_back(out->str()); @@ -169,8 +169,7 @@ class OpDesc : public OpDescAPI { } bool HasOutput(const std::string& param) const { - NotImplemented(); - return false; + return !Output(param).empty(); } const std::map& attrs() const { diff --git a/lite/model_parser/flatbuffers/program_desc.h b/lite/model_parser/flatbuffers/program_desc.h index c651d9dc0671aced942bb28466e829a40226c2ba..55218eef5b4037d13b2f45db6de6b94cb39d994e 100644 --- a/lite/model_parser/flatbuffers/program_desc.h +++ b/lite/model_parser/flatbuffers/program_desc.h @@ -29,16 +29,25 @@ namespace fbs { class ProgramDesc : public ProgramDescAPI { public: ProgramDesc() = default; - explicit ProgramDesc(std::unique_ptr buf) { - Init(std::move(buf)); + explicit ProgramDesc(const std::vector& buf) { Init(buf); } + explicit ProgramDesc(std::vector&& buf) { + Init(std::forward>(buf)); } - size_t BlocksSize() const override { return desc_->blocks()->size(); } + void Init(const std::vector& buf) { + CHECK(buf.data()); + buf_ = buf; + InitProgramDesc(); + } - void Init(std::unique_ptr buf) { - CHECK(buf.get() != nullptr); + void Init(std::vector&& buf) { + CHECK(buf.data()); buf_ = std::move(buf); - desc_ = proto::GetProgramDesc(buf_.get()); + InitProgramDesc(); + } + + void InitProgramDesc() { + desc_ = proto::GetProgramDesc(buf_.data()); blocks_.reserve(BlocksSize()); for (size_t idx = 0; idx < BlocksSize(); ++idx) { blocks_.push_back(BlockDesc(desc_->blocks()->Get(idx))); @@ -46,12 +55,12 @@ class ProgramDesc : public ProgramDescAPI { } void CopyFrom(const ProgramDesc& other) { - size_t length = strlen(static_cast(other.raw_buf())); - std::unique_ptr buf(new char[length]); - memcpy(buf.get(), other.raw_buf(), length); - Init(std::move(buf)); + buf_ = other.buf(); + Init(buf_); } + size_t BlocksSize() const override { return desc_->blocks()->size(); } + template T const* GetBlock(int32_t idx) const; @@ -72,11 +81,11 @@ class ProgramDesc : public ProgramDescAPI { proto::ProgramDesc const* raw_desc() const { return desc_; } - const void* raw_buf() const { return buf_.get(); } + const std::vector& buf() const { return buf_; } private: proto::ProgramDesc const* desc_; - std::unique_ptr buf_; + std::vector buf_; std::vector blocks_; private: diff --git a/lite/model_parser/flatbuffers/vector_view.h b/lite/model_parser/flatbuffers/vector_view.h index 1cc890e98d2a85b3113fcf49a68701595e63964e..bb1331823a2dce79d2b3a6784f1f2d5b5864281d 100644 --- a/lite/model_parser/flatbuffers/vector_view.h +++ b/lite/model_parser/flatbuffers/vector_view.h @@ -51,6 +51,7 @@ struct FBSStrIterator { flatbuffers::Offset>::return_type> VI; + FBSStrIterator() = default; explicit FBSStrIterator(const VI& iter) { iter_ = iter; } const VI& raw_iter() const { return iter_; } @@ -104,20 +105,21 @@ class VectorView { explicit VectorView(typename Traits::vector_type const* cvec) { cvec_ = cvec; } - std::string operator[](size_t i) const { - CHECK(cvec_); - return cvec_->operator[](i)->str(); - } + std::string operator[](size_t i) const { return cvec_->operator[](i)->str(); } vector_view::FBSStrIterator begin() const { - CHECK(cvec_); + if (!cvec_) { + return vector_view::FBSStrIterator(); + } return vector_view::FBSStrIterator(cvec_->begin()); } vector_view::FBSStrIterator end() const { - CHECK(cvec_); + if (!cvec_) { + return vector_view::FBSStrIterator(); + } return vector_view::FBSStrIterator(cvec_->end()); } size_t size() const { - if (cvec_ == nullptr) { + if (!cvec_) { return 0; } return cvec_->size(); @@ -126,10 +128,8 @@ class VectorView { VLOG(5) << "Copying elements out of VectorView will damage performance."; std::vector tmp; tmp.reserve(size()); - if (cvec_ != nullptr) { - for (auto val : *cvec_) { - tmp.push_back(val->str()); - } + for (size_t i = 0; i < size(); ++i) { + tmp.push_back(cvec_->operator[](i)->str()); } return tmp; } diff --git a/lite/model_parser/general/CMakeLists.txt b/lite/model_parser/general/CMakeLists.txt index fe3b2f848e404385b8d948db676865b8039f4ba2..ed53678dfac4cc58b208c2faa8573bcd06943aaa 100644 --- a/lite/model_parser/general/CMakeLists.txt +++ b/lite/model_parser/general/CMakeLists.txt @@ -3,4 +3,4 @@ lite_cc_library(cpp_var_desc SRCS var_desc.cc) lite_cc_library(cpp_block_desc SRCS block_desc.cc) lite_cc_library(cpp_program_desc SRCS program_desc.cc) -set(cpp_wrapper cpp_op_desc cpp_var_desc cpp_block_desc cpp_program_desc PARENT_SCOPE) +set(cpp_wrapper cpp_program_desc cpp_block_desc cpp_var_desc cpp_op_desc PARENT_SCOPE) diff --git a/lite/model_parser/pb/var_desc.cc b/lite/model_parser/pb/var_desc.cc index f849b8dd0ed103f789aec41e5c88f3e4f3cdf878..42625ee6190fb98c50de2b88a08b9910d91ed014 100644 --- a/lite/model_parser/pb/var_desc.cc +++ b/lite/model_parser/pb/var_desc.cc @@ -294,9 +294,9 @@ const proto::VarType::TensorDesc &VarDesc::tensor_desc() const { case proto::VarType::LOD_TENSOR_ARRAY: return desc_->type().tensor_array().tensor(); default: - LOG(FATAL) - << "Getting 'tensor_desc' is not supported by the type of var %s." - << this->Name(); + LOG(WARNING) << "Getting 'tensor_desc' is not supported by the type(" + << static_cast(desc_->type().type()) << ") of var " + << this->Name(); } return framework::proto::VarDesc().type().lod_tensor().tensor(); } @@ -312,10 +312,9 @@ std::vector VarDesc::tensor_descs() const { } return res; default: - LOG(FATAL) - << "Getting 'tensor_descs' is not supported by the type of var " - "%s." - << this->Name(); + LOG(WARNING) << "Getting 'tensor_descs' is not supported by the type(" + << static_cast(desc_->type().type()) << ") of var " + << this->Name(); } return std::vector(); } diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 45b49f91ace12da5934471e01afd91c2832f1d6d..4e67acdb228502736f8509fe87556c65e253b82a 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -115,6 +115,7 @@ add_operator(ctc_align_op_lite extra SRCS ctc_align_op.cc DEPS ${op_DEPS}) add_operator(max_pool_with_index_op extra SRCS max_pool_with_index_op.cc DEPS ${op_DEPS}) add_operator(pixel_shuffle_op extra SRCS pixel_shuffle_op.cc DEPS ${op_DEPS}) add_operator(clip_op extra SRCS clip_op.cc DEPS ${op_DEPS}) +add_operator(print_op extra SRCS print_op.cc DEPS ${op_DEPS}) # for OCR specific add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/__xpu__mmdnn_op.cc b/lite/operators/__xpu__mmdnn_op.cc index 35024da911ba0659c5005a1adc641fa3adc2f282..b898c0b132dc0767c8ba28c29098ac998c2cab21 100644 --- a/lite/operators/__xpu__mmdnn_op.cc +++ b/lite/operators/__xpu__mmdnn_op.cc @@ -88,6 +88,78 @@ bool XPUMmdnnBidEmbGrnnAttOp::AttachImpl(const cpp::OpDesc& op_desc, return true; } +bool XPUMmdnnBidEmbGrnnAttOp2::CheckShape() const { return true; } + +bool XPUMmdnnBidEmbGrnnAttOp2::InferShapeImpl() const { + auto& id_dims = param_.id0->dims(); + auto& id_lod = param_.id0->lod()[0]; + auto& emb_tbl_dims = param_.emb_tbl->dims(); + auto& grnn_wh_dims = param_.grnn_rv_wh->dims(); + + param_.emb0_out->Resize({id_dims[0], emb_tbl_dims[1]}); + param_.emb0_out->set_lod({id_lod}); + param_.grnn_fw_pool_out->Resize( + {(int64_t)id_lod.size() - 1, grnn_wh_dims[2]}); + param_.grnn_rv_pool_out->Resize( + {(int64_t)id_lod.size() - 1, grnn_wh_dims[2]}); + param_.att_pool_out->Resize( + {(int64_t)id_lod.size() - 1, 2 * grnn_wh_dims[2]}); + param_.concat_3in1_out->Resize({id_dims[0], 3 * grnn_wh_dims[2]}); + param_.concat_3in1_out->set_lod({id_lod}); + param_.emb_fw_out->Resize({id_dims[0], emb_tbl_dims[1]}); + param_.emb_fw_out->set_lod({id_lod}); + return true; +} + +bool XPUMmdnnBidEmbGrnnAttOp2::AttachImpl(const cpp::OpDesc& op_desc, + lite::Scope* scope) { + param_.id0 = + scope->FindVar(op_desc.Input("id0").front())->GetMutable(); + param_.id1 = + scope->FindVar(op_desc.Input("id1").front())->GetMutable(); + param_.emb_tbl = scope->FindVar(op_desc.Input("emb_tbl").front()) + ->GetMutable(); + param_.grnn_fw_wh = scope->FindVar(op_desc.Input("grnn_fw_wh").front()) + ->GetMutable(); + param_.grnn_fw_wi = scope->FindVar(op_desc.Input("grnn_fw_wi").front()) + ->GetMutable(); + param_.grnn_rv_wh = scope->FindVar(op_desc.Input("grnn_rv_wh").front()) + ->GetMutable(); + param_.grnn_rv_wi = scope->FindVar(op_desc.Input("grnn_rv_wi").front()) + ->GetMutable(); + param_.att_fc_w = scope->FindVar(op_desc.Input("att_fc_w").front()) + ->GetMutable(); + param_.att_fc_b = scope->FindVar(op_desc.Input("att_fc_b").front()) + ->GetMutable(); + + param_.emb0_out = scope->FindVar(op_desc.Output("emb0_out").front()) + ->GetMutable(); + param_.grnn_fw_pool_out = + scope->FindVar(op_desc.Output("grnn_fw_pool_out").front()) + ->GetMutable(); + param_.grnn_rv_pool_out = + scope->FindVar(op_desc.Output("grnn_rv_pool_out").front()) + ->GetMutable(); + param_.att_pool_out = scope->FindVar(op_desc.Output("att_pool_out").front()) + ->GetMutable(); + param_.concat_3in1_out = + scope->FindVar(op_desc.Output("concat_3in1_out").front()) + ->GetMutable(); + param_.emb_fw_out = scope->FindVar(op_desc.Output("emb_fw_out").front()) + ->GetMutable(); + + param_.grnn_fw_wh_maxs = + op_desc.GetAttr>("grnn_fw_wh_maxs"); + param_.grnn_fw_wi_maxs = + op_desc.GetAttr>("grnn_fw_wi_maxs"); + param_.grnn_rv_wh_maxs = + op_desc.GetAttr>("grnn_rv_wh_maxs"); + param_.grnn_rv_wi_maxs = + op_desc.GetAttr>("grnn_rv_wi_maxs"); + param_.att_fc_w_max = op_desc.GetAttr("att_fc_w_max"); + return true; +} + bool XPUMmdnnBidEmbAttOp::CheckShape() const { return true; } bool XPUMmdnnBidEmbAttOp::InferShapeImpl() const { @@ -157,6 +229,7 @@ bool XPUMmdnnMatchConvTopkOp::AttachImpl(const cpp::OpDesc& op_desc, param_.input_w_max = op_desc.GetAttr("input_w_max"); param_.conv_w_max = op_desc.GetAttr("conv_w_max"); param_.topks = op_desc.GetAttr>("topks"); + param_.output_channel = op_desc.GetAttr("output_channel"); param_.channel_num = op_desc.GetAttr("channel_num"); param_.dim_t = op_desc.GetAttr("dim_t"); return true; @@ -182,10 +255,10 @@ bool XPUMmdnnMergeAllOp::AttachImpl(const cpp::OpDesc& op_desc, auto t = scope->FindVar(name)->GetMutable(); param_.concat_7in1_x.push_back(t); } - param_.concat_2in1_x.clear(); - for (auto& name : op_desc.Input("concat_2in1_x")) { + param_.concat_topk_x.clear(); + for (auto& name : op_desc.Input("concat_topk_x")) { auto t = scope->FindVar(name)->GetMutable(); - param_.concat_2in1_x.push_back(t); + param_.concat_topk_x.push_back(t); } param_.grnn_fw_wh = scope->FindVar(op_desc.Input("grnn_fw_wh").front()) ->GetMutable(); @@ -231,6 +304,8 @@ bool XPUMmdnnMergeAllOp::AttachImpl(const cpp::OpDesc& op_desc, REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_grnn_att, paddle::lite::operators::XPUMmdnnBidEmbGrnnAttOp); +REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_grnn_att2, + paddle::lite::operators::XPUMmdnnBidEmbGrnnAttOp2); REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_att, paddle::lite::operators::XPUMmdnnBidEmbAttOp); REGISTER_LITE_OP(__xpu__mmdnn_match_conv_topk, diff --git a/lite/operators/__xpu__mmdnn_op.h b/lite/operators/__xpu__mmdnn_op.h index 7038898cad0823746f905e4e60c06885b57a737c..ba815a1eec7d0913bc08b4f8fa520de73a4bb835 100644 --- a/lite/operators/__xpu__mmdnn_op.h +++ b/lite/operators/__xpu__mmdnn_op.h @@ -41,6 +41,29 @@ class XPUMmdnnBidEmbGrnnAttOp : public OpLite { mutable XPUMmdnnBidEmbGrnnAttParam param_; }; +class XPUMmdnnBidEmbGrnnAttOp2 : public OpLite { + public: + XPUMmdnnBidEmbGrnnAttOp2() {} + + explicit XPUMmdnnBidEmbGrnnAttOp2(const std::string &op_type) + : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShapeImpl() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { + return "XPUMmdnnBidEmbGrnnAttOp2"; + } + + private: + mutable XPUMmdnnBidEmbGrnnAttParam2 param_; +}; + class XPUMmdnnBidEmbAttOp : public OpLite { public: XPUMmdnnBidEmbAttOp() {} diff --git a/lite/operators/assign_op.cc b/lite/operators/assign_op.cc index fe1e8db1f954af38041621d1d676cf16833357da..f2237230dceda55c89a423e0ee9504ee1e3c1de8 100644 --- a/lite/operators/assign_op.cc +++ b/lite/operators/assign_op.cc @@ -21,15 +21,15 @@ namespace lite { namespace operators { bool AssignOpLite::CheckShape() const { - CHECK_OR_FALSE(param_.X); - CHECK_OR_FALSE(param_.Out); + CHECK_OR_FALSE(param_.X || param_.X_array); + CHECK_OR_FALSE(param_.Out || param_.Out_array); return true; } bool AssignOpLite::InferShapeImpl() const { - if (param_.X != nullptr) { + if (param_.X) { param_.Out->Resize(param_.X->dims()); - } else if (param_.X_array != nullptr) { + } else if (param_.X_array) { param_.Out_array->resize(param_.Out_array->size()); } else { LOG(FATAL) << "x or x_array must be set."; diff --git a/lite/operators/conditional_block_op.cc b/lite/operators/conditional_block_op.cc index e3678e92c9d33be5428c82331ce963f4c6067369..de8bea345fe8da1e157665b93f9d50c6f6bbffa3 100644 --- a/lite/operators/conditional_block_op.cc +++ b/lite/operators/conditional_block_op.cc @@ -20,35 +20,37 @@ namespace paddle { namespace lite { namespace operators { -bool ConditionalBlockOpLite::CheckShape() const { +bool ConditionalBlockOp::CheckShape() const { CHECK_OR_FALSE(param_.cond); - CHECK_OR_FALSE(param_.sub_block); - CHECK_OR_FALSE(param_.scope); + CHECK_OR_FALSE(param_.program_desc); + CHECK_OR_FALSE(param_.exec_scope); return true; } -bool ConditionalBlockOpLite::InferShapeImpl() const { return true; } +bool ConditionalBlockOp::InferShapeImpl() const { return true; } -bool ConditionalBlockOpLite::AttachImpl(const cpp::OpDesc &op_desc, - lite::Scope *scope) { +bool ConditionalBlockOp::AttachImpl(const cpp::OpDesc& op_desc, Scope* scope) { auto condition = op_desc.Input("Cond").front(); param_.cond = scope->FindVar(condition)->GetMutable(); - auto inputs = op_desc.Input("Input"); - for (auto var : inputs) { - param_.x.push_back(scope->FindVar(var)->GetMutable()); + for (const auto& input : inputs) { + auto* var = scope->FindVar(input); + CHECK(var); + param_.inputs.push_back(var->GetMutable()); } - auto outs = op_desc.Output("Out"); - for (auto var : outs) { - param_.outs.push_back(scope->FindVar(var)->GetMutable()); + for (const auto& out : outs) { + auto* var = scope->FindVar(out); + CHECK(var); + param_.outs.push_back(var->GetMutable()); } - param_.is_scalar_condition = op_desc.GetAttr("is_scalar_condition"); // obtain sub_block in core program.cc - param_.sub_block = sub_block_; - param_.scope = scope; - + CHECK(param_.program_desc); + param_.block_idx = op_desc.GetAttr("sub_block"); + CHECK_GE(param_.block_idx, 0); + param_.exec_scope = scope; + CHECK(param_.exec_scope); return true; } @@ -57,4 +59,4 @@ bool ConditionalBlockOpLite::AttachImpl(const cpp::OpDesc &op_desc, } // namespace paddle REGISTER_LITE_OP(conditional_block, - paddle::lite::operators::ConditionalBlockOpLite); + paddle::lite::operators::ConditionalBlockOp); diff --git a/lite/operators/conditional_block_op.h b/lite/operators/conditional_block_op.h index 1815731c8df3ac07bee80aa8e0cc658e752b5c4f..adcd8acdff391e2ae3ece9ec21669d853250dcf4 100644 --- a/lite/operators/conditional_block_op.h +++ b/lite/operators/conditional_block_op.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include #include #include "lite/core/op_lite.h" @@ -23,27 +24,30 @@ namespace paddle { namespace lite { namespace operators { -class ConditionalBlockOpLite : public OpLite { +class ConditionalBlockOp : public OpLite { public: - ConditionalBlockOpLite() {} - explicit ConditionalBlockOpLite(const std::string &op_type) - : OpLite(op_type) {} + ConditionalBlockOp() {} + explicit ConditionalBlockOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; bool InferShapeImpl() const override; - bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + bool AttachImpl(const cpp::OpDesc &opdesc, Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "conditional_block"; } - void SetSubBlock(cpp::BlockDesc *desc) { sub_block_ = desc; } + void SetProgramDesc(std::shared_ptr program_desc) { + param_.program_desc = program_desc; + } + std::shared_ptr GetProgramDesc() { + return param_.program_desc; + } private: mutable ConditionalBlockParam param_; - cpp::BlockDesc *sub_block_; }; } // namespace operators diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index f351e8e5344424d80fa79f8d7c83be3bf367441f..240cf65d26e9edde0eb5f9f4efc3b6f7f6a149a6 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -90,9 +90,9 @@ struct SubgraphParam : ParamBase { std::vector output_names{}; std::vector input_data_names{}; std::vector output_data_names{}; - int sub_block_idx{-1}; - cpp::BlockDesc* sub_block_desc{nullptr}; - Scope* scope{nullptr}; + int block_idx{-1}; + std::shared_ptr program_desc{nullptr}; + Scope* exec_scope{nullptr}; }; /// -------------------------- NN operators ------------------------------------ @@ -939,11 +939,10 @@ struct CompareParam : ParamBase { }; struct WhileParam : ParamBase { - Scope* scope{}; Tensor* cond{}; - cpp::BlockDesc* sub_block{}; - std::vector x{}; - std::vector outs{}; + int block_idx{-1}; + std::shared_ptr program_desc{nullptr}; + Scope* exec_scope{nullptr}; }; struct TopkParam : ParamBase { @@ -1454,10 +1453,11 @@ struct MergeLodTensorParam : ParamBase { struct ConditionalBlockParam : ParamBase { const lite::Tensor* cond{}; - std::vector x{}; + std::vector inputs{}; std::vector outs{}; - cpp::BlockDesc* sub_block{}; - Scope* scope{}; + int block_idx{-1}; + std::shared_ptr program_desc{nullptr}; + Scope* exec_scope{nullptr}; bool is_scalar_condition{}; }; @@ -1627,11 +1627,36 @@ struct XPUMmdnnBidEmbGrnnAttParam : ParamBase { std::vector grnn_rv_wi_maxs; float att_fc_w_max{0.0f}; - lite::Tensor* grnn_fw_pool_out{}; // 1 - lite::Tensor* grnn_rv_pool_out{}; // 2 - lite::Tensor* att_pool_out{}; // 3 - lite::Tensor* concat_3in1_out{}; // 4 - lite::Tensor* emb_fw_out{}; // 5 + lite::Tensor* grnn_fw_pool_out{}; + lite::Tensor* grnn_rv_pool_out{}; + lite::Tensor* att_pool_out{}; + lite::Tensor* concat_3in1_out{}; + lite::Tensor* emb_fw_out{}; +}; + +struct XPUMmdnnBidEmbGrnnAttParam2 : ParamBase { + lite::Tensor* id0{}; + lite::Tensor* id1{}; + lite::Tensor* emb_tbl{}; + lite::Tensor* grnn_fw_wh{}; + lite::Tensor* grnn_fw_wi{}; + lite::Tensor* grnn_rv_wh{}; + lite::Tensor* grnn_rv_wi{}; + lite::Tensor* att_fc_w{}; + lite::Tensor* att_fc_b{}; + + std::vector grnn_fw_wh_maxs; + std::vector grnn_fw_wi_maxs; + std::vector grnn_rv_wh_maxs; + std::vector grnn_rv_wi_maxs; + float att_fc_w_max{0.0f}; + + lite::Tensor* emb0_out{}; + lite::Tensor* grnn_fw_pool_out{}; + lite::Tensor* grnn_rv_pool_out{}; + lite::Tensor* att_pool_out{}; + lite::Tensor* concat_3in1_out{}; + lite::Tensor* emb_fw_out{}; }; struct XPUMmdnnBidEmbAttParam : ParamBase { @@ -1643,8 +1668,8 @@ struct XPUMmdnnBidEmbAttParam : ParamBase { float att_fc_w_max{0.0f}; - lite::Tensor* att_pool_out{}; // 1 - lite::Tensor* emb_fw_out{}; // 2 + lite::Tensor* att_pool_out{}; + lite::Tensor* emb_fw_out{}; }; struct XPUMmdnnMatchConvTopkParam : ParamBase { @@ -1656,6 +1681,7 @@ struct XPUMmdnnMatchConvTopkParam : ParamBase { float input_w_max{0.0f}; float conv_w_max{0.0f}; std::vector topks; + int output_channel{0}; int channel_num{0}; int dim_t{0}; @@ -1664,7 +1690,7 @@ struct XPUMmdnnMatchConvTopkParam : ParamBase { struct XPUMmdnnMergeAllParam : ParamBase { std::vector concat_7in1_x; - std::vector concat_2in1_x; + std::vector concat_topk_x; lite::Tensor* grnn_fw_wh{}; lite::Tensor* grnn_fw_wi{}; lite::Tensor* grnn_rv_wh{}; @@ -1753,6 +1779,22 @@ struct ClipParam : ParamBase { float max{}; }; +struct PrintParam : ParamBase { + const lite::Tensor* in{}; + lite::Tensor* out{}; + std::string name; + int first_n{-1}; + std::string message; + int summarize{20}; + bool print_tensor_name{true}; + bool print_tensor_type{true}; + bool print_tensor_shape{true}; + bool print_tensor_lod{true}; + bool print_tensor_layout{true}; + std::string print_phase; + bool is_forward{true}; +}; + } // namespace operators } // namespace lite } // namespace paddle diff --git a/lite/operators/print_op.cc b/lite/operators/print_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f4299aed06f17d7bf3bd30b9fec34c587168884 --- /dev/null +++ b/lite/operators/print_op.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/print_op.h" +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace operators { + +bool PrintOp::CheckShape() const { + CHECK_OR_FALSE(param_.in); + CHECK_OR_FALSE(param_.out); + return true; +} + +bool PrintOp::InferShapeImpl() const { + param_.out->set_lod(param_.in->lod()); + param_.out->Resize(param_.in->dims()); + return true; +} + +bool PrintOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { + AttachParam(¶m_); + + param_.name = op_desc.Input("In").front(); + param_.in = scope->FindTensor(param_.name); + param_.out = scope->FindMutableTensor(op_desc.Output("Out").front()); + param_.first_n = op_desc.GetAttr("first_n"); + param_.message = op_desc.GetAttr("message"); + param_.summarize = op_desc.GetAttr("summarize"); + param_.print_tensor_name = op_desc.GetAttr("print_tensor_name"); + param_.print_tensor_type = op_desc.GetAttr("print_tensor_type"); + param_.print_tensor_shape = op_desc.GetAttr("print_tensor_shape"); + param_.print_tensor_lod = op_desc.GetAttr("print_tensor_lod"); + param_.print_tensor_layout = op_desc.GetAttr("print_tensor_layout"); + param_.print_phase = op_desc.GetAttr("print_phase"); + param_.is_forward = op_desc.GetAttr("is_forward"); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(print, paddle::lite::operators::PrintOp); diff --git a/lite/operators/print_op.h b/lite/operators/print_op.h new file mode 100644 index 0000000000000000000000000000000000000000..cd8e777b59c3aac92771442402cf16623b75fbef --- /dev/null +++ b/lite/operators/print_op.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class PrintOp : public OpLite { + public: + PrintOp() {} + explicit PrintOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShapeImpl() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "print"; } + + private: + mutable PrintParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/subgraph_op.cc b/lite/operators/subgraph_op.cc index 9ac07e96334eda9f0001d33e0789f9de15c4ca67..fec5a0e3254328220508f28a16b110beb01fb613 100644 --- a/lite/operators/subgraph_op.cc +++ b/lite/operators/subgraph_op.cc @@ -39,10 +39,11 @@ bool SubgraphOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { op_desc.GetAttr>("input_data_names"); param_.output_data_names = op_desc.GetAttr>("output_data_names"); - CHECK(param_.sub_block_desc); - param_.sub_block_idx = op_desc.GetAttr("sub_block"); - param_.scope = scope; - CHECK(param_.scope); + CHECK(param_.program_desc); + param_.block_idx = op_desc.GetAttr("sub_block"); + CHECK_GE(param_.block_idx, 0); + param_.exec_scope = scope; + CHECK(param_.exec_scope); return true; } diff --git a/lite/operators/subgraph_op.h b/lite/operators/subgraph_op.h index edbfb922044d60165e589d389cd8cfb3b2547796..df6448f2f78a08f41ac037a13d14cbca1725cfb5 100644 --- a/lite/operators/subgraph_op.h +++ b/lite/operators/subgraph_op.h @@ -13,14 +13,11 @@ // limitations under the License. #pragma once - +#include #include #include -#include "lite/core/kernel.h" #include "lite/core/op_lite.h" #include "lite/core/scope.h" -#include "lite/core/tensor.h" -#include "lite/operators/op_params.h" #include "lite/utils/all.h" namespace paddle { @@ -37,14 +34,18 @@ class SubgraphOp : public OpLite { bool InferShapeImpl() const override; - bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; + bool AttachImpl(const cpp::OpDesc &op_desc, Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "subgraph"; } - void SetSubBlock(cpp::BlockDesc *desc) { param_.sub_block_desc = desc; } - cpp::BlockDesc *GetSubBlock() { return param_.sub_block_desc; } + void SetProgramDesc(std::shared_ptr program_desc) { + param_.program_desc = program_desc; + } + std::shared_ptr GetProgramDesc() { + return param_.program_desc; + } private: mutable SubgraphParam param_; diff --git a/lite/operators/while_op.cc b/lite/operators/while_op.cc index 1dcf9553f331ee6646ad6d93de048728a0886116..ab8e4a5489c13e042bf0d07da1228f33626a1d43 100644 --- a/lite/operators/while_op.cc +++ b/lite/operators/while_op.cc @@ -20,31 +20,23 @@ namespace paddle { namespace lite { namespace operators { -bool WhileOpLite::CheckShape() const { - CHECK_OR_FALSE(param_.sub_block); - CHECK_OR_FALSE(param_.scope); +bool WhileOp::CheckShape() const { CHECK_OR_FALSE(param_.cond); + CHECK_OR_FALSE(param_.program_desc); + CHECK_OR_FALSE(param_.exec_scope); return true; } -bool WhileOpLite::InferShapeImpl() const { return true; } - -bool WhileOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { - auto inputs = op_desc.Input("X"); - auto outs = op_desc.Output("Out"); - - for (auto var : inputs) { - // param_.x.push_back(scope->FindVar(var)->GetMutable()); - } - for (auto var : outs) { - // param_.outs.push_back(scope->FindVar(var)->GetMutable()); - } - param_.sub_block = sub_block_; +bool WhileOp::InferShapeImpl() const { return true; } +bool WhileOp::AttachImpl(const cpp::OpDesc &op_desc, Scope *scope) { auto condition = op_desc.Input("Condition"); param_.cond = scope->FindVar(condition[0])->GetMutable(); - param_.scope = scope; - + CHECK(param_.program_desc); + param_.block_idx = op_desc.GetAttr("sub_block"); + CHECK_GE(param_.block_idx, 0); + param_.exec_scope = scope; + CHECK(param_.exec_scope); return true; } @@ -52,4 +44,4 @@ bool WhileOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { } // namespace lite } // namespace paddle -REGISTER_LITE_OP(while, paddle::lite::operators::WhileOpLite); +REGISTER_LITE_OP(while, paddle::lite::operators::WhileOp); diff --git a/lite/operators/while_op.h b/lite/operators/while_op.h index 94aec15a6d3eb60036bf9c2168fdbd855b84a396..e448ee568723b24a241c5bb127ac61458385337e 100644 --- a/lite/operators/while_op.h +++ b/lite/operators/while_op.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include #include #include "lite/core/op_lite.h" @@ -23,24 +24,30 @@ namespace paddle { namespace lite { namespace operators { -class WhileOpLite : public OpLite { +class WhileOp : public OpLite { public: - WhileOpLite() {} - explicit WhileOpLite(const std::string &op_type) : OpLite(op_type) {} + WhileOp() {} + explicit WhileOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; bool InferShapeImpl() const override; - bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + bool AttachImpl(const cpp::OpDesc &opdesc, Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "while"; } - void SetSubBlock(cpp::BlockDesc *desc) { sub_block_ = desc; } + + void SetProgramDesc(std::shared_ptr program_desc) { + param_.program_desc = program_desc; + } + std::shared_ptr GetProgramDesc() { + return param_.program_desc; + } private: mutable WhileParam param_; - cpp::BlockDesc *sub_block_; }; } // namespace operators diff --git a/lite/tests/api/CMakeLists.txt b/lite/tests/api/CMakeLists.txt index 844c3f2ac7146e05b2d93eac76279df022e06652..e9c6574c19bcb6a238503d7b5fc955db9b96d689 100644 --- a/lite/tests/api/CMakeLists.txt +++ b/lite/tests/api/CMakeLists.txt @@ -1,3 +1,13 @@ +if(LITE_WITH_ARM) + lite_cc_test(test_transformer_with_mask_fp32_arm SRCS test_transformer_with_mask_fp32_arm.cc + DEPS ${lite_model_test_DEPS} paddle_api_full + ARM_DEPS ${arm_kernels} + ARGS --model_dir=${LITE_MODEL_DIR}/transformer_with_mask_fp32 SERIAL) + if(WITH_TESTING) + add_dependencies(test_transformer_with_mask_fp32_arm extern_lite_download_transformer_with_mask_fp32_tar_gz) + endif() +endif() + if(LITE_WITH_XPU) lite_cc_test(test_resnet50_lite_xpu SRCS test_resnet50_lite_xpu.cc DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils diff --git a/lite/tests/api/test_mmdnn_lite_xpu.cc b/lite/tests/api/test_mmdnn_lite_xpu.cc index a2a98821e70cb462b23887f851cfc4bce6b463ca..72d774db14d955f17caee217f13fddb32acb93c3 100644 --- a/lite/tests/api/test_mmdnn_lite_xpu.cc +++ b/lite/tests/api/test_mmdnn_lite_xpu.cc @@ -26,156 +26,171 @@ DEFINE_bool(perf, false, "perf?"); DEFINE_string(perf_input, "perf_input", "perf_input"); +DEFINE_int32(perf_batch_size, 40, "perf_batch_size"); +DEFINE_bool(use_xpu, true, "use_xpu?"); +DEFINE_int32(perf_dev, 0, "perf_dev"); namespace paddle { namespace lite { -std::vector input0; -std::vector input0_lod = {0}; -std::vector input1; -std::vector input1_lod = {0}; -std::vector input2; -std::vector input2_lod = {0}; -std::vector input3; -std::vector input3_lod = {0}; -std::vector input4; -std::vector input4_lod = {0}; -std::vector input5; -std::vector input5_lod = {0}; +class SampleReader { + public: + std::vector> data; + std::vector> lod; -void ParseInput() { - std::string raw_input = - "0 1;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 " - "760166;3719 428 52 18 1102 10327 252 20 153 2897 1146 70 156 6 145 " - "10251 839 5 1779 1729 1779 1729 18 2707 6 2707 20 4742 4937 432 6 " - "3869;3719 760166 760166 18 1035176 1035176 764393 764393 1259006 767614 " - "767614 1020808 769579 793958 793958 1050488 911898 751332 751332 750336 " - "750799 750336 751575 751575 751544 751735 751397 751365 751512 751512 " - "753011 751562;3719 428 52 18 1102 10327 252 20 153 2897 1146 70 156 6 " - "145 10251 839 2 1211 3 3719 720 1540 145 10251 839 9405 4315 5998 4 2 " - "600 373 41 3719 428 52 44 10251 4302 1319 7 12 2 768 6 918 6 841 870 8 " - "843 8 271;3719 760166 760166 18 1035176 1035176 764393 764393 1259006 " - "767614 767614 1020808 769579 793958 793958 1050488 911898 2 773899 " - "773899 3719 1118420 1118420 1050488 1050488 911898 9405 4315 5998 4 2 " - "785435 785435 41 3719 760166 760166 44 10251 4302 1319 750118 750118 2 " - "750465 750465 750274 750398 750233 751252 751252 753447 752830 753112;\n" - "0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 " - "760166;2109 2467 1805 227 3719 428 52 18 1102 10327 252 20 6 242 78 6 " - "532 78;2109 2467 1805 1245431 1245431 760166 760166 18 1035176 1035176 " - "764393 764393 752116 242 750370 750370 752081 751247;2109 2467 1805 227 " - "3719 428 52 18 1102 10327 252 20 2 145 242 1050 252 3582 2212;2109 2467 " - "1805 1245431 1245431 760166 760166 18 1035176 1035176 764393 764393 2 " - "871717 871717 757921 757921 3582 2212;\n" - "0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 " - "760166;145 10251 839 76 31 1337 823 7506 567 65 170 8 21293 3719 5 43 " - "394 743 42;1050488 1050488 911898 750016 750016 1337 823 7506 762617 " - "762617 866652 8 21293 3719 5 43 914758 914758 757202;145 10251 839 76 " - "31 1337 823 7506 567 65 170 8 21293 3719 2 17580 30 523324 3 10251 4104 " - "281 3 8511 3719 2217 3 13 226 3083 4 11251 1606 357 9 2 145 10251 839 " - "76 31 1337 823 7506 567 65 170 2 7506 2445 8 145 10251 839 528 839 " - "19670 6538;1050488 1050488 911898 750016 750016 1337 823 7506 762617 " - "762617 866652 8 21293 3719 2 816626 816626 523324 3 1181698 1181698 " - "751656 780821 1063148 3719 2217 3 752498 752498 831323 753602 11251 " - "1606 357 9 2 1050488 1050488 911898 750016 750016 1337 823 7506 762617 " - "762617 866652 2 7506 753045 753045 756756 1050488 911898 528 839 19670 " - "6538;\n" - "0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 " - "760166;145 10251 839 99 4 1102 10327 2196 41 3719 428 52 44 99 4 2899 " - "229 10 10 10;1050488 1050488 911898 807966 750273 1035176 1035176 " - "1237875 41 3719 760166 760166 753645 753645 750273 2899 229 750001 " - "750001 750001;145 10251 839 99 4 1102 10327 2196 41 3719 428 52 44 99 4 " - "2899 229 10 10 10 2 1177 8 145 10251 839 99 4 1102 10327 2196 41 3719 " - "428 52 44 99 4 2 101 8 1922 17 2184 2 1154 1922 72 1198 1266 " - "4516;1050488 1050488 911898 807966 750273 1035176 1035176 1237875 41 " - "3719 760166 760166 753645 753645 750273 2899 229 750001 750001 750001 2 " - "750257 750257 756756 1050488 911898 807966 750273 1035176 1035176 " - "1237875 41 3719 760166 760166 753645 753645 750273 2 764513 764513 " - "851213 851213 854628 2 753018 753018 754317 753328 754085 754070;\n" - "0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 " - "760166;73 5347 112 8 145 10251 839 262 169 22729 3719 6 743 6 339 1156 " - "78 136 399 693 128 571;776150 776150 112 756756 756756 1050488 911898 " - "791355 791355 22729 3719 6 758277 758277 750137 750234 750241 750178 " - "750055 750216 750212 750049;73 5347 112 8 145 10251 839 262 169 22729 " - "3719 2 588 415 549 415 115 23;776150 776150 112 756756 756756 1050488 " - "911898 791355 791355 22729 3719 2 750221 750221 750262 750277 750277 " - "750261;"; - auto raw_lines = Split(raw_input, "\n"); - for (auto& raw_line : raw_lines) { - auto inputx = Split(raw_line, ";"); - for (size_t i = 1; i < inputx.size(); ++i) { - auto tokens = Split(inputx[i], " "); - static std::vector* const input_array[] = { - &input0, &input0, &input1, &input2, &input3, &input4, &input5}; - static std::vector* const lod_array[] = {&input0_lod, - &input0_lod, - &input1_lod, - &input2_lod, - &input3_lod, - &input4_lod, - &input5_lod}; - for (auto token : tokens) { - input_array[i]->push_back((int64_t)atoi(token.c_str())); - } - lod_array[i]->push_back((uint64_t)tokens.size() + - (*lod_array[i])[lod_array[i]->size() - 1]); - } - } - return; -} + void Read() { + std::string raw_input = + "0 1;125 584 142 2114 197;125 756226 756913 855693 760836;125 584 142 " + "2114 197 10 2899;125 756226 756913 855693 760836 10 750793;125 584 " + "142 2114 197 10 2899 2 825 32 18499 125 584 295 2114 197 2114 2730 6 " + "15 32 18499 125 584 142 295 2114 1423 21 2 334 863 5122 197 974 21 " + "295 619 25 2114 1755 2701 197 15 216 23 18499 125 584 142 599 3228 23 " + "2 5122 1917 804 5 2114 197 1236 3 2114 1403 15 3886 1080 23 1150 125 " + "475 23 2998 23;125 756226 756913 855693 760836 10 750793 2 825 750355 " + "18499 881680 756226 295 765124 760836 2114 872813 754265 15 32 18499 " + "881680 756226 756913 761251 765124 752843 766823 2 334 759834 5122 " + "774643 758458 21 295 755114 25 1148365 1755 2701 197 15 216 23 18499 " + "881680 756226 756913 826848 3228 23 2 5122 831009 804 752371 2114 " + "760836 1236 3 2114 910393 15 3886 1080 23 877375 752137 761034 792123 " + "2998 23;1;1;\n" + "0 0;125 584 142 2114 197;125 756226 756913 855693 760836;121 28 1054 " + "1459 125 72 32 2321 531 125 295 584 142 2114 197 14 477 30 121;121 28 " + "764114 1459 753052 750694 750001 886192 750435 752179 295 584 756913 " + "855693 760836 14 477 30 753504;121 28 1054 1459 125 72 32 2321 531 " + "125 295 584 142 2114 197 2 121 28 1054 1459 125 72 32 2321 531 125 " + "295 584 142 4 263 2114 197 43 95 863 2114 323 20 142 626 11 2 45 10 " + "45 58 142 65 918 741 2114 197 764 3 5122 26 51 1266 2037 295 222 1121 " + "4491 3 545 4338 11 2 5122 26 495 3 142 3444 3249 2114 197 3 626 4 " + "2794;121 28 764114 1459 753052 750694 750001 886192 750435 752179 295 " + "584 756913 855693 760836 2 121 28 764114 1459 753052 750694 750001 " + "886192 750435 752179 295 584 756913 4 750885 2114 760836 43 750030 " + "754302 2114 323 822131 142 626 769001 2 45 750128 750324 58 142 " + "1147454 918 910829 2114 760836 841946 767340 5122 779102 51 1266 2037 " + "756461 222 752031 942669 1139389 780275 4338 830597 2 5122 779102 495 " + "761418 142 3444 852932 2114 760836 3 760162 757966 751127;121 295 " + "5593 142 2114 197;121 295 5593 925208 2114 760836;\n" + "0 0;125 584 142 2114 197;125 756226 756913 855693 760836;207 125 584 " + "142 2114 1423 14 5283 1745 73;207 752276 756226 756913 855693 752843 " + "14 5283 781651 786597;6109 18807 142 5 64 5283 1745 73 3690 1060 3626 " + "4 716 51 1030 2114 197 4 428 936 9066 10 10 10 2 207 125 584 142 2114 " + "1423 2 15329 2114 197 5669 401 318 285 953 4 2114 197 2285 7 1783 11 " + "2 5122 197 14017 584;6109 18807 142 5 755319 5283 781651 786597 3690 " + "1060 3626 4 716 910478 1030 2114 760836 4 750323 936 9066 10 750002 " + "750002 2 207 752276 756226 756913 855693 752843 2 15329 2114 760836 " + "5669 401 318 757541 750261 4 2114 760836 2285 7 757639 11 2 5122 " + "774643 14017 584;125 584 142 1745 5122;125 756226 756913 1745 " + "755836;\n" + "0 0;125 584 142 2114 197;125 756226 756913 855693 760836;149 396 778 " + "584 142 295 2114 1423 14 64 125 584 73 21 36670 5834 10 211 25;149 " + "751876 1048872 584 756913 761251 765124 752843 14 64 125 756226 73 " + "944567 36670 5834 10 750012 753240;101 10 2114 197 3 946 2 149 396 " + "778 584 142 295 2114 1423 2 2610 6 1444 111 2114 948 72 32 21 15 494 " + "25 4 2114 197 5669 1145 2 148 295 149 396 778 584 142 295 21 22853 41 " + "348 619 25 366 5305 2114 807 4 1115 381 1955 2114 11;101 751178 2114 " + "760836 3 946 2 149 751876 1048872 584 756913 761251 765124 752843 2 " + "2610 753567 775165 750899 972788 948 750125 750001 751875 15 494 25 4 " + "2114 760836 5669 1145 2 148 808886 982157 751876 1048872 584 756913 " + "761251 790772 22853 41 348 619 25 366 894206 2114 1008440 4 753953 " + "381 851474 765868 11;149 396 778 584 142 295 2 149 396 354 778 584 " + "142 1333 2 584 778 295 5122 2 149 396 778 584 3609 2 149 396 64478 " + "816 14246 1423 2 149 396 584 32 127 19 3609 2 149 396 584 73 2 149 " + "396 584 778 295 2285 142 4922 323 2 149 396 584 2114 2 149 396 253 " + "584 2114 197;149 751876 1048872 584 756913 761251 2 149 751876 756286 " + "767182 584 756913 1333 2 584 778 897778 941364 2 149 751876 1048872 " + "584 1102835 2 149 751876 64478 816 14246 912094 2 149 751876 584 " + "773547 127 750771 791456 2 149 751876 584 73 2 149 751876 584 778 " + "897778 2285 751493 791984 323 2 149 751876 584 2114 2 149 751876 " + "808443 835481 2114 760836;\n" + "0 0;125 584 142 2114 197;125 756226 756913 855693 760836;125 584 545 " + "149 14 125 584;125 756226 545 874302 14 125 756226;2204 25 30 1692 " + "1770 6534 295 125 584 72 32 1346 4 2698 2114 197 11 2 4235 4301 240 " + "295 125 584 72 32 21 6708 15 56974 494 25 1030 2114 197 110 804 495 " + "611 2 221 759 341 6 5283 1745 73 71 2114 1423 71 125 584 545 149 149 " + "2 505 345 58 125 584 65 3486 2114 295 4 45 786 196 6604 6086;2204 25 " + "30 797189 1770 1191824 295 752782 756226 751697 750001 1346 4 2698 " + "2114 760836 765158 2 4235 4301 240 753859 752782 756226 751697 750001 " + "751875 6708 15 56974 494 25 1030 2114 760836 777607 762850 966521 611 " + "2 221 752565 750130 750084 910219 781651 786597 71 2114 752843 71 125 " + "756226 545 874302 149 2 505 825657 782848 125 756226 65 3486 2114 " + "760669 4 45 755747 758903 6604 6086;125 584 2114 2 125 584 2114 1423 " + "2 125 584 2114 149 2 149 584 1745 5122 725 2 2114 125 584 2 125 584 " + "2114 2 2621 584 2114 2 527 37 2754 130 170 1013 494 887 240 2 4521 " + "11111 586 2321 531 125 584 142 1360 816 2842 1423 2 125 584 2114;125 " + "756226 2114 2 125 756226 2114 752843 2 125 756226 2114 783644 2 149 " + "760183 1745 755836 725 2 2114 125 756226 2 125 756226 2114 2 2621 " + "932600 2114 2 527 751304 869964 754462 170 1013 750719 778287 774620 " + "2 4521 11111 586 2321 750435 752179 756226 756913 1360 764399 2842 " + "1423 2 125 756226 2114;\n" + "0 0;125 584 142 2114 197;125 756226 756913 855693 760836;207 584 142 " + "2114 197 4 207 584 142 2114 197 674 14 240 4328 14 4328 767;207 " + "1237071 756913 855693 760836 4 207 1237071 756913 855693 760836 674 " + "14 240 755573 14 4328 795065;207 584 142 2114 197 2 325 71 71 207 584 " + "142 2114 197 2 876 125 140 2114 197 2 207 584 142 2114 197 674 1210 " + "239 4328 767 268 1349 485 28 4389 504 3 941 57 1419 1978 11;207 " + "1237071 756913 855693 760836 2 325 71 71 207 1237071 756913 855693 " + "760836 2 876 125 750977 1250790 760836 2 207 1237071 756913 855693 " + "760836 674 814792 755820 812174 795065 818859 817155 816597 761001 " + "774461 780904 820475 1109800 790141 790459 780324 770390;584 142 295 " + "2114 232 2 207 584 2114 197 2 584 142 295 2114 232 2 584 142 512 2114 " + "197;584 756913 761251 765124 1006359 2 207 1237071 2114 760836 2 584 " + "756913 761251 765124 1006359 2 584 756913 879930 2114 760836;"; -class MmdnnReader { - std::ifstream ifs; - std::vector StringSplit(const std::string& in, - const std::string& delim) { - std::vector ret; - if (in == "") { - return ret; - } - auto begpos = in.find_first_not_of(delim); - while (begpos != std::string::npos) { - auto endpos = in.find_first_of(delim, begpos); - if (endpos == std::string::npos) { - endpos = in.size(); + auto lines = Split(raw_input, "\n"); + for (auto& line : lines) { + auto split1 = Split(line, ";"); + if (data.size() == 0) { + for (size_t i = 1; i < split1.size(); ++i) { + data.push_back(std::vector()); + lod.push_back({0}); + } } - std::string ssubstr = in.substr(begpos, endpos - begpos); - ret.push_back(ssubstr); - begpos = endpos + 1; - if (endpos >= (in.size() - 1)) { - break; + + for (size_t i = 1; i < split1.size(); ++i) { + auto split2 = Split(split1[i], " "); + if (split2.size() == 0) { + split2.push_back("1280000"); + } + for (auto e : split2) { + data[i - 1].push_back(std::stoi(e.c_str(), nullptr, 0)); + } + lod[i - 1].push_back(lod[i - 1].back() + split2.size()); } } - return ret; } +}; + +class FileReader { + std::ifstream ifs; public: - std::vector data[6]; - std::vector lod[6]; + std::vector> data; + std::vector> lod; void Init(std::string file_name) { ifs.open(file_name); } int Read(int maxline) { - for (int i = 0; i < 6; i++) { - data[i].clear(); - } - for (int i = 0; i < 6; i++) { - lod[i].clear(); - lod[i].push_back(0); - } + data.clear(); + lod.clear(); + std::string line; int cnt = 0; while (cnt < maxline && getline(ifs, line)) { - std::vector split1 = StringSplit(line, ";"); - for (int i = 1; i < 7; i++) { - std::vector split2 = StringSplit(split1[i], " "); + std::vector split1 = Split(line, ";"); + if (data.size() == 0) { + for (size_t i = 1; i < split1.size(); ++i) { + data.push_back(std::vector()); + lod.push_back({0}); + } + } + + for (size_t i = 1; i < split1.size(); i++) { + std::vector split2 = Split(split1[i], " "); if (split2.size() == 0) { split2.push_back("1280000"); } for (size_t j = 0; j < split2.size(); j++) { data[i - 1].push_back(std::stoi(split2[j].c_str(), nullptr, 0)); } - // if (i % 2 == 1) { - // lod[i / 2].push_back(lod[i / 2].back() + split2.size()); - //} lod[i - 1].push_back(lod[i - 1].back() + split2.size()); } cnt++; @@ -186,36 +201,47 @@ class MmdnnReader { TEST(MMDNN, test_mmdnn_lite_xpu) { lite_api::CxxConfig config; - config.set_model_dir(FLAGS_model_dir); - config.set_valid_places({lite_api::Place{TARGET(kXPU), PRECISION(kFloat)}, - lite_api::Place{TARGET(kXPU), PRECISION(kInt64)}, - lite_api::Place{TARGET(kX86), PRECISION(kFloat)}, - lite_api::Place{TARGET(kX86), PRECISION(kInt64)}, - lite_api::Place{TARGET(kHost), PRECISION(kFloat)}}); + // config.set_model_dir(FLAGS_model_dir); + config.set_model_file(FLAGS_model_dir + "/__model__"); + config.set_param_file(FLAGS_model_dir + "/__param__"); + config.set_xpu_dev_per_thread(FLAGS_perf_dev); + if (FLAGS_use_xpu) { + config.set_valid_places( + {lite_api::Place{TARGET(kXPU), PRECISION(kFloat)}, + lite_api::Place{TARGET(kXPU), PRECISION(kInt64)}, + lite_api::Place{TARGET(kX86), PRECISION(kFloat)}, + lite_api::Place{TARGET(kX86), PRECISION(kInt64)}, + lite_api::Place{TARGET(kHost), PRECISION(kFloat)}}); + } else { + config.set_valid_places( + {lite_api::Place{TARGET(kX86), PRECISION(kFloat)}, + lite_api::Place{TARGET(kX86), PRECISION(kInt64)}, + lite_api::Place{TARGET(kHost), PRECISION(kFloat)}}); + } config.set_xpu_workspace_l3_size_per_thread(); auto predictor = lite_api::CreatePaddlePredictor(config); if (FLAGS_perf) { - MmdnnReader reader; - reader.Init(FLAGS_perf_input); - int UB_batch = 40; // upper bound of batch + FileReader file_reader; + file_reader.Init(FLAGS_perf_input); + int UB_batch = FLAGS_perf_batch_size; // upper bound of batch int iter = 0; double tsc_sum = 0; while (true) { - int batch = reader.Read(UB_batch); + int batch = file_reader.Read(UB_batch); if (batch <= 0) { break; } ++iter; - for (int i = 0; i < 6; ++i) { + for (size_t i = 0; i < file_reader.data.size(); ++i) { auto input_x = predictor->GetInput(i); - input_x->Resize({(int64_t)reader.data[i].size(), 1}); - input_x->SetLoD({reader.lod[i]}); + input_x->Resize({(int64_t)file_reader.data[i].size(), 1}); + input_x->SetLoD({file_reader.lod[i]}); auto* data_x = input_x->mutable_data(); memcpy(data_x, - reader.data[i].data(), - reader.data[i].size() * sizeof(int64_t)); + file_reader.data[i].data(), + file_reader.data[i].size() * sizeof(int64_t)); } auto start = GetCurrentUS(); @@ -232,55 +258,17 @@ TEST(MMDNN, test_mmdnn_lite_xpu) { return; } - ParseInput(); + SampleReader sample_reader; + sample_reader.Read(); - { - std::vector input0_shape{(int64_t)input0.size(), 1}; - auto input_tensor0 = predictor->GetInput(0); - input_tensor0->Resize(input0_shape); - input_tensor0->SetLoD({input0_lod}); - auto* data0 = input_tensor0->mutable_data(); - memcpy(data0, input0.data(), sizeof(int64_t) * input0.size()); - } - { - std::vector input1_shape{(int64_t)input1.size(), 1}; - auto input_tensor1 = predictor->GetInput(1); - input_tensor1->Resize(input1_shape); - input_tensor1->SetLoD({input1_lod}); - auto* data1 = input_tensor1->mutable_data(); - memcpy(data1, input1.data(), sizeof(int64_t) * input1.size()); - } - { - std::vector input2_shape{(int64_t)input2.size(), 1}; - auto input_tensor2 = predictor->GetInput(2); - input_tensor2->Resize(input2_shape); - input_tensor2->SetLoD({input2_lod}); - auto* data2 = input_tensor2->mutable_data(); - memcpy(data2, input2.data(), sizeof(int64_t) * input2.size()); - } - { - std::vector input3_shape{(int64_t)input3.size(), 1}; - auto input_tensor3 = predictor->GetInput(3); - input_tensor3->Resize(input3_shape); - input_tensor3->SetLoD({input3_lod}); - auto* data3 = input_tensor3->mutable_data(); - memcpy(data3, input3.data(), sizeof(int64_t) * input3.size()); - } - { - std::vector input4_shape{(int64_t)input4.size(), 1}; - auto input_tensor4 = predictor->GetInput(4); - input_tensor4->Resize(input4_shape); - input_tensor4->SetLoD({input4_lod}); - auto* data4 = input_tensor4->mutable_data(); - memcpy(data4, input4.data(), sizeof(int64_t) * input4.size()); - } - { - std::vector input5_shape{(int64_t)input5.size(), 1}; - auto input_tensor5 = predictor->GetInput(5); - input_tensor5->Resize(input5_shape); - input_tensor5->SetLoD({input5_lod}); - auto* data5 = input_tensor5->mutable_data(); - memcpy(data5, input5.data(), sizeof(int64_t) * input5.size()); + for (size_t i = 0; i < sample_reader.data.size(); ++i) { + auto input_x = predictor->GetInput(i); + input_x->Resize({(int64_t)sample_reader.data[i].size(), 1}); + input_x->SetLoD({sample_reader.lod[i]}); + auto* data_x = input_x->mutable_data(); + memcpy(data_x, + sample_reader.data[i].data(), + sample_reader.data[i].size() * sizeof(int64_t)); } for (int i = 0; i < FLAGS_warmup; ++i) { diff --git a/lite/tests/api/test_transformer_with_mask_fp32_arm.cc b/lite/tests/api/test_transformer_with_mask_fp32_arm.cc new file mode 100644 index 0000000000000000000000000000000000000000..e65b017aa1440683d86d0da03686a2be9c4c6ee5 --- /dev/null +++ b/lite/tests/api/test_transformer_with_mask_fp32_arm.cc @@ -0,0 +1,274 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "lite/api/lite_api_test_helper.h" +#include "lite/api/paddle_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { + +template +void SetTensorData(const std::vector &data, + const std::vector &shape, + paddle::lite_api::Tensor *tensor, + const std::vector> &lod = {}) { + tensor->Resize(shape); + tensor->SetLoD(lod); + std::copy(data.begin(), data.end(), tensor->mutable_data()); +} + +void PrepareInputData( + const std::shared_ptr &predictor, + std::vector src_word_data, + int max_seq_len = 16, // padding + int max_out_len = 8, + int bos_idx = 0, + int eos_idx = 1, + int n_head = 8) { + // src_word + auto src_word = predictor->GetInput(0); + int seq_len = src_word_data.size(); + for (int i = seq_len; i < max_seq_len; i++) { + src_word_data.push_back(eos_idx); + } + std::vector src_word_shape{ + 1, static_cast(src_word_data.size())}; + SetTensorData(src_word_data, src_word_shape, src_word.get()); + // src_pos + auto src_pos = predictor->GetInput(1); + std::vector src_pos_data(src_word_data.size()); + std::iota(src_pos_data.begin(), src_pos_data.end(), 0); + std::vector src_pos_shape{1, + static_cast(src_pos_data.size())}; + SetTensorData(src_pos_data, src_pos_shape, src_pos.get()); + // src_slf_attn_bias + auto src_slf_attn_bias = predictor->GetInput(2); + std::vector src_slf_attn_bias_data(1 * n_head * src_word_data.size() * + src_word_data.size()); + int offset = 0; + for (int j = 0; j < 1 * n_head * src_word_data.size(); j++) { + for (int i = 0; i < seq_len; i++) { + src_slf_attn_bias_data[offset++] = 0.0f; + } + for (int i = seq_len; i < src_word_data.size(); i++) { + src_slf_attn_bias_data[offset++] = -1e9f; + } + } + std::vector src_slf_attn_bias_shape{ + 1, + n_head, + static_cast(src_word_data.size()), + static_cast(src_word_data.size())}; + SetTensorData( + src_slf_attn_bias_data, src_slf_attn_bias_shape, src_slf_attn_bias.get()); + // trg_word + auto trg_word = predictor->GetInput(3); + std::vector trg_word_data(2, 0); + std::vector trg_word_shape{2, 1}; + std::vector lod_level_0{0, 2}; + std::vector lod_level_1{0, 1, 2}; + std::vector> trg_word_lod(2); + trg_word_lod[0] = lod_level_0; + trg_word_lod[1] = lod_level_1; + SetTensorData( + trg_word_data, trg_word_shape, trg_word.get(), trg_word_lod); + // init_score + auto init_score = predictor->GetInput(4); + std::vector init_score_data(2); + init_score_data[0] = 0; + init_score_data[1] = -1e9f; + std::vector init_score_shape{2, 1}; + std::vector> init_score_lod(trg_word_lod); + SetTensorData( + init_score_data, init_score_shape, init_score.get(), init_score_lod); + // init_idx + auto init_idx = predictor->GetInput(5); + std::vector init_idx_data(2, 0); + std::vector init_idx_shape{2}; + SetTensorData(init_idx_data, init_idx_shape, init_idx.get()); + // trg_slf_attn_bias + auto trg_slf_attn_bias = predictor->GetInput(6); + std::vector trg_slf_attn_bias_data(max_out_len * n_head * 1 * + max_out_len); + offset = 0; + for (int k = 0; k < max_out_len; k++) { + for (int j = 0; j < n_head; j++) { + for (int i = 0; i < max_out_len; i++) { + trg_slf_attn_bias_data[offset++] = (i <= k) ? 0.0f : -1e9f; + } + } + } + std::vector trg_slf_attn_bias_shape{ + max_out_len, n_head, 1, max_out_len}; + SetTensorData( + trg_slf_attn_bias_data, trg_slf_attn_bias_shape, trg_slf_attn_bias.get()); + // trg_src_attn_bias + auto trg_src_attn_bias = predictor->GetInput(7); + std::vector trg_src_attn_bias_data(1 * n_head * 1 * + src_word_data.size()); + offset = 0; + for (int j = 0; j < 1 * n_head * 1; j++) { + for (int i = 0; i < seq_len; i++) { + trg_src_attn_bias_data[offset++] = 0.0f; + } + for (int i = seq_len; i < src_word_data.size(); i++) { + trg_src_attn_bias_data[offset++] = -1e9f; + } + } + std::vector trg_src_attn_bias_shape{ + 1, n_head, 1, static_cast(src_word_data.size())}; + SetTensorData( + trg_src_attn_bias_data, trg_src_attn_bias_shape, trg_src_attn_bias.get()); + // kv_padding_selection + auto kv_padding_selection = predictor->GetInput(8); + std::vector kv_padding_selection_data(max_out_len * n_head * + max_out_len * 1); + offset = 0; + for (int k = 0; k < max_out_len; k++) { + for (int j = 0; j < n_head; j++) { + for (int i = 0; i < max_out_len; i++) { + kv_padding_selection_data[offset++] = (i == k) ? 1.0f : 0.0f; + } + } + } + std::vector kv_padding_selection_shape{ + max_out_len, n_head, max_out_len, 1}; + SetTensorData(kv_padding_selection_data, + kv_padding_selection_shape, + kv_padding_selection.get()); +} + +void CheckOutputData( + const std::shared_ptr &predictor, + const std::vector &ref_seq_ids_data, + const std::vector &ref_seq_scores_data) { + // seq_ids + auto seq_ids = predictor->GetOutput(0); + auto seq_ids_shape = seq_ids->shape(); + auto seq_ids_size = std::accumulate(seq_ids_shape.begin(), + seq_ids_shape.end(), + 1, + std::multiplies()); + ASSERT_EQ(seq_ids_size, ref_seq_ids_data.size()); + auto *seq_ids_data = seq_ids->data(); + for (size_t i = 0; i < seq_ids_size; i++) { + EXPECT_EQ(seq_ids_data[i], ref_seq_ids_data[i]); + } + // seq_scores + auto seq_scores = predictor->GetOutput(1); + auto seq_scores_shape = seq_scores->shape(); + auto seq_scores_size = std::accumulate(seq_scores_shape.begin(), + seq_scores_shape.end(), + 1, + std::multiplies()); + ASSERT_EQ(seq_scores_size, ref_seq_scores_data.size()); + auto *seq_scores_data = seq_scores->data(); + for (size_t i = 0; i < seq_scores_size; i++) { + EXPECT_NEAR(seq_scores_data[i], ref_seq_scores_data[i], 1e-5); + } +} + +TEST(TransformerWithMask, test_transformer_with_mask_fp32) { + // Save the optimized model by using full api with CxxConfig + lite_api::CxxConfig cxx_config; + cxx_config.set_model_dir(FLAGS_model_dir); + cxx_config.set_valid_places( + {lite_api::Place{TARGET(kARM), PRECISION(kFloat)}, + lite_api::Place{TARGET(kARM), PRECISION(kInt64)}}); + auto predictor = lite_api::CreatePaddlePredictor(cxx_config); + predictor->SaveOptimizedModel(FLAGS_model_dir + ".nb", + paddle::lite_api::LiteModelType::kNaiveBuffer); + // Load the optimized model and run inference by using light api with + // MobileConfig + paddle::lite_api::MobileConfig mobile_config; + mobile_config.set_model_from_file(FLAGS_model_dir + ".nb"); + mobile_config.set_threads(1); + mobile_config.set_power_mode(paddle::lite_api::PowerMode::LITE_POWER_HIGH); + std::vector, + std::pair, std::vector>>> + test_cases = { + {{16, 16, 16, 1}, + {{0, 16, 16, 16, 16, 16, 16, 1, 0, 16, 16, 16, 16, 16, 9, 1}, + {0.0f, + -0.939061f, + -1.91494f, + -2.94378f, + -4.26457f, + -5.82675f, + -7.45856f, + -7.58065f, + 0.0f, + -0.939061f, + -1.91494f, + -2.94378f, + -4.26457f, + -5.82675f, + -8.70994f, + -8.8053f}}}, + {{16, 16, 16, 10, 1}, + {{0, 6, 53, 11, 1, 0, 6, 53, 56, 4, 1}, + {0.0f, + -2.36122f, + -4.1678f, + -6.19764f, + -7.69256f, + 0.0f, + -2.36122f, + -4.1678f, + -6.20145f, + -7.66355f, + -8.63024f}}}, + {{126, 4, 33, 1}, + {{0, 68, 5, 17, 1, 0, 68, 5, 13, 14, 1}, + {0.0f, + -0.829941f, + -1.20217f, + -2.23938f, + -2.98262f, + 0.0f, + -0.829941f, + -1.20217f, + -2.25051f, + -3.07555f, + -3.57711f}}}, + {{126, 4, 33, 99, 1}, + {{0, 14, 242, 17, 1, 0, 93, 38, 27, 68, 1}, + {0.f, + -1.8504f, + -2.66679f, + -3.09469f, + -3.63227f, + 0.0f, + -1.33829f, + -1.41656f, + -3.1333f, + -3.27901f, + -3.88582f}}}}; + for (auto &test_case : test_cases) { + PrepareInputData(predictor, test_case.first); + predictor->Run(); + CheckOutputData(predictor, test_case.second.first, test_case.second.second); + } +} + +} // namespace lite +} // namespace paddle diff --git a/lite/tests/kernels/prior_box_compute_test.cc b/lite/tests/kernels/prior_box_compute_test.cc index 73fd612c3a03c0a15ddaf3ce6c08ff0ed1a5a95b..ec0eda8cbb2b7f8d6ab01efa467ed857d817905a 100644 --- a/lite/tests/kernels/prior_box_compute_test.cc +++ b/lite/tests/kernels/prior_box_compute_test.cc @@ -21,7 +21,7 @@ namespace paddle { namespace lite { -const int MALLOC_ALIGN = 64; +const int MALLOC_ALIGN = 16; void* fast_malloc(size_t size) { size_t offset = sizeof(void*) + MALLOC_ALIGN - 1; diff --git a/lite/tests/math/gemm_int8_compute_test.cc b/lite/tests/math/gemm_int8_compute_test.cc index adae19d013e50fbd484257a99f55229c75b94263..57899c8d1e2e0c073f410e90d18119327f21f066 100644 --- a/lite/tests/math/gemm_int8_compute_test.cc +++ b/lite/tests/math/gemm_int8_compute_test.cc @@ -120,6 +120,10 @@ bool test_gemm_int8(bool tra, auto dc_fp32 = tc_fp32.mutable_data(); auto dc_basic_int8 = tc_basic_int8.mutable_data(); auto dc_basic_fp32 = tc_basic_fp32.mutable_data(); + // set intial input to be 0 + memset(reinterpret_cast(dc_basic_fp32), + 0, + tc_basic_fp32.numel() * sizeof(float)); auto dbias = tbias.mutable_data(); if (FLAGS_check_result) { diff --git a/lite/tests/math/gemv_int8_compute_test.cc b/lite/tests/math/gemv_int8_compute_test.cc index 99db53511446ecd4772fa2fd1b202337581506ef..3819c0dcd7f87c69a5805aae643a6a3a4a037f03 100644 --- a/lite/tests/math/gemv_int8_compute_test.cc +++ b/lite/tests/math/gemv_int8_compute_test.cc @@ -108,6 +108,10 @@ bool test_gemv_int8(bool tra, auto dc_basic_int8 = tc_basic_int8.mutable_data(); auto dc_basic_fp32 = tc_basic_fp32.mutable_data(); auto dbias = tbias.mutable_data(); + // set intial input to be 0 + memset(reinterpret_cast(dc_basic_fp32), + 0, + tc_basic_fp32.numel() * sizeof(float)); paddle::lite_api::ActivationType act = paddle::lite_api::ActivationType::kIndentity; diff --git a/lite/tests/math/sgemm_c4_compute_test.cc b/lite/tests/math/sgemm_c4_compute_test.cc index b5beeaffaed6bff8a260c158bdce234fce6c1349..ecdf77fd37fff1da2914eeca5e29ef931de09c53 100644 --- a/lite/tests/math/sgemm_c4_compute_test.cc +++ b/lite/tests/math/sgemm_c4_compute_test.cc @@ -92,6 +92,7 @@ bool test_sgemm_c4( auto db_c4 = tb_c4.mutable_data(); auto dc_basic = tc_basic.mutable_data(); auto dbias = tbias.mutable_data(); + memset(reinterpret_cast(dc_basic), 0, tc_basic.numel()); // trans A, B to c4 basic_trans_mat_to_c4(da, da_c4, k, m, k, true); diff --git a/lite/tests/math/sgemv_compute_test.cc b/lite/tests/math/sgemv_compute_test.cc index 91a1fe1770dfa3eeb3f3b94fcd2361f1c1634b1e..661c4f02aa7eafe807f77767dfd4db01a338993e 100644 --- a/lite/tests/math/sgemv_compute_test.cc +++ b/lite/tests/math/sgemv_compute_test.cc @@ -84,6 +84,7 @@ bool test_sgemv(bool tra, auto db = tb.mutable_data(); auto dc = tc.mutable_data(); auto dc_basic = tc_basic.mutable_data(); + memset(reinterpret_cast(dc_basic), 0, tc_basic.numel()); auto dbias = tbias.mutable_data(); paddle::lite_api::ActivationType act = paddle::lite_api::ActivationType::kIndentity; diff --git a/lite/tools/build.sh b/lite/tools/build.sh index e74a6176401b2bf29006804ec6e2f6e683b09696..6fc14d6033e8fb232bc66c6431db622ae80511fb 100755 --- a/lite/tools/build.sh +++ b/lite/tools/build.sh @@ -22,6 +22,7 @@ OPTMODEL_DIR="" BUILD_TAILOR=OFF BUILD_CV=OFF WITH_LOG=ON +WITH_EXCEPTION=OFF WITH_PROFILE=OFF BUILD_NPU=OFF NPU_DDK_ROOT="$(pwd)/ai_ddk_lib/" # Download HiAI DDK from https://developer.huawei.com/consumer/cn/hiai/ @@ -126,6 +127,7 @@ function make_tiny_publish_so { -DLITE_WITH_JAVA=$BUILD_JAVA \ -DLITE_WITH_PYTHON=$BUILD_PYTHON \ -DLITE_WITH_LOG=$WITH_LOG \ + -DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \ -DLITE_ON_TINY_PUBLISH=ON \ -DANDROID_STL_TYPE=$android_stl \ -DLITE_BUILD_EXTRA=$BUILD_EXTRA \ @@ -181,6 +183,7 @@ function make_opencl { -DWITH_TESTING=OFF \ -DLITE_BUILD_EXTRA=$BUILD_EXTRA \ -DLITE_WITH_LOG=$WITH_LOG \ + -DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \ -DLITE_WITH_CV=$BUILD_CV \ -DARM_TARGET_OS=$1 -DARM_TARGET_ARCH_ABI=$2 -DARM_TARGET_LANG=$3 @@ -219,6 +222,7 @@ function make_full_publish_so { -DLITE_WITH_JAVA=$BUILD_JAVA \ -DLITE_WITH_PYTHON=$BUILD_PYTHON \ -DLITE_WITH_LOG=$WITH_LOG \ + -DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \ -DLITE_WITH_PROFILE=${WITH_PROFILE} \ -DANDROID_STL_TYPE=$android_stl \ -DLITE_BUILD_EXTRA=$BUILD_EXTRA \ @@ -343,6 +347,8 @@ function make_cuda { -DLITE_WITH_STATIC_CUDA=OFF \ -DLITE_WITH_PYTHON=${BUILD_PYTHON} \ -DLITE_BUILD_EXTRA=ON \ + -DLITE_WITH_LOG=${WITH_LOG} \ + -DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \ -DLITE_WITH_XPU=$BUILD_XPU \ -DLITE_WITH_XTCL=$BUILD_XTCL \ -DXPU_SDK_ROOT=$XPU_SDK_ROOT @@ -379,6 +385,7 @@ function make_x86 { -DLITE_WITH_PYTHON=${BUILD_PYTHON} \ -DLITE_BUILD_EXTRA=ON \ -DLITE_WITH_LOG=${WITH_LOG} \ + -DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \ -DLITE_WITH_PROFILE=${WITH_PROFILE} \ -DLITE_WITH_XPU=$BUILD_XPU \ -DLITE_WITH_XTCL=$BUILD_XTCL \ @@ -409,6 +416,7 @@ function print_usage { echo echo -e "optional argument:" echo -e "--with_log: (OFF|ON); controls whether to print log information, default is ON" + echo -e "--with_exception: (OFF|ON); controls whether to throw the exception when error occurs, default is OFF" echo -e "--build_extra: (OFF|ON); controls whether to publish extra operators and kernels for (sequence-related model such as OCR or NLP)" echo -e "--build_train: (OFF|ON); controls whether to publish training operators and kernels, build_train is only for full_publish library now" echo -e "--build_python: (OFF|ON); controls whether to publish python api lib (ANDROID and IOS is not supported)" @@ -491,6 +499,17 @@ function main { WITH_LOG="${i#*=}" shift ;; + --with_exception=*) + WITH_EXCEPTION="${i#*=}" + if [[ $WITH_EXCEPTION == "ON" && $ARM_OS=="android" && $ARM_ABI == "armv7" && $ARM_LANG != "clang" ]]; then + set +x + echo + echo -e "error: only clang provide C++ exception handling support for 32-bit ARM." + echo + exit 1 + fi + shift + ;; --with_profile=*) WITH_PROFILE="${i#*=}" shift diff --git a/lite/tools/build_android.sh b/lite/tools/build_android.sh index 5713c4e21bb97d12bb840c99d1adbc7f2d781157..ecf34f0dfc4ddd141af9ea07dd6c4f15d1c0c16b 100755 --- a/lite/tools/build_android.sh +++ b/lite/tools/build_android.sh @@ -17,6 +17,8 @@ WITH_JAVA=ON WITH_CV=OFF # controls whether to hide log information, default is ON. WITH_LOG=ON +# controls whether to throw the exception when error occurs, default is OFF +WITH_EXCEPTION=OFF # options of striping lib according to input model. OPTMODEL_DIR="" WITH_STRIP=OFF @@ -145,6 +147,7 @@ function make_tiny_publish_so { local cmake_mutable_options=" -DLITE_BUILD_EXTRA=$WITH_EXTRA \ -DLITE_WITH_LOG=$WITH_LOG \ + -DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \ -DLITE_BUILD_TAILOR=$WITH_STRIP \ -DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \ -DLITE_WITH_JAVA=$WITH_JAVA \ @@ -194,6 +197,7 @@ function make_full_publish_so { local cmake_mutable_options=" -DLITE_BUILD_EXTRA=$WITH_EXTRA \ -DLITE_WITH_LOG=$WITH_LOG \ + -DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \ -DLITE_BUILD_TAILOR=$WITH_STRIP \ -DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \ -DLITE_WITH_JAVA=$WITH_JAVA \ @@ -237,6 +241,7 @@ function print_usage { echo -e "| --with_java: (OFF|ON); controls whether to publish java api lib, default is ON |" echo -e "| --with_cv: (OFF|ON); controls whether to compile cv functions into lib, default is OFF |" echo -e "| --with_log: (OFF|ON); controls whether to print log information, default is ON |" + echo -e "| --with_exception: (OFF|ON); controls whether to throw the exception when error occurs, default is OFF |" echo -e "| --with_extra: (OFF|ON); controls whether to publish extra operators and kernels for (sequence-related model such as OCR or NLP) |" echo -e "| |" echo -e "| arguments of striping lib according to input model:(armv8, gcc, c++_static) |" @@ -320,6 +325,18 @@ function main { WITH_LOG="${i#*=}" shift ;; + # ON or OFF, default OFF + --with_exception=*) + WITH_EXCEPTION="${i#*=}" + if [[ $WITH_EXCEPTION == "ON" && $ARCH == "armv7" && $TOOLCHAIN != "clang" ]]; then + set +x + echo + echo -e "Error: only clang provide C++ exception handling support for 32-bit ARM." + echo + exit 1 + fi + shift + ;; # compiling lib which can operate on opencl and cpu. --with_opencl=*) WITH_OPENCL="${i#*=}" diff --git a/lite/tools/build_ios.sh b/lite/tools/build_ios.sh index 3d4337aa8ecc20fd078b8906a950408927ea56c8..4eea073a058ba9e1e821e9f0746687baa0c38d5f 100755 --- a/lite/tools/build_ios.sh +++ b/lite/tools/build_ios.sh @@ -12,6 +12,8 @@ WITH_EXTRA=OFF WITH_CV=OFF # controls whether to hide log information, default is ON. WITH_LOG=ON +# controls whether to throw the exception when error occurs, default is OFF +WITH_EXCEPTION=OFF # absolute path of Paddle-Lite. workspace=$PWD/$(dirname $0)/../../ # options of striping lib according to input model. @@ -69,6 +71,7 @@ function make_ios { -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \ -DLITE_WITH_X86=OFF \ -DLITE_WITH_LOG=$WITH_LOG \ + -DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \ -DLITE_BUILD_TAILOR=$WITH_STRIP \ -DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \ -DARM_TARGET_ARCH_ABI=$arch \ @@ -96,6 +99,7 @@ function print_usage { echo -e "| --arch: (armv8|armv7), default is armv8 |" echo -e "| --with_cv: (OFF|ON); controls whether to compile cv functions into lib, default is OFF |" echo -e "| --with_log: (OFF|ON); controls whether to print log information, default is ON |" + echo -e "| --with_exception: (OFF|ON); controls whether to throw the exception when error occurs, default is OFF |" echo -e "| --with_extra: (OFF|ON); controls whether to publish extra operators and kernels for (sequence-related model such as OCR or NLP) |" echo -e "| |" echo -e "| arguments of striping lib according to input model:(armv8, gcc, c++_static) |" @@ -140,6 +144,10 @@ function main { WITH_LOG="${i#*=}" shift ;; + --with_exception=*) + WITH_EXCEPTION="${i#*=}" + shift + ;; help) print_usage exit 0 diff --git a/lite/tools/build_linux.sh b/lite/tools/build_linux.sh index 5ed491cb7da7b33357b7e66ab8267e60815b5348..f6de128feb6073fe206d03b68c5d8bc04dc9f16c 100755 --- a/lite/tools/build_linux.sh +++ b/lite/tools/build_linux.sh @@ -17,6 +17,8 @@ PY_VERSION="" WITH_CV=OFF # controls whether to print log information, default is ON. WITH_LOG=ON +# controls whether to throw the exception when error occurs, default is OFF +WITH_EXCEPTION=OFF # options of striping lib according to input model. WITH_STRIP=OFF OPTMODEL_DIR="" @@ -60,6 +62,7 @@ function init_cmake_mutable_options { -DPY_VERSION=$PY_VERSION \ -DLITE_WITH_CV=$WITH_CV \ -DLITE_WITH_LOG=$WITH_LOG \ + -DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \ -DLITE_BUILD_TAILOR=$WITH_STRIP \ -DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \ -DLITE_WITH_OPENCL=$WITH_OPENCL \ @@ -210,6 +213,7 @@ function print_usage { echo -e "| --python_version: (2.7|3.5|3.7); controls python version to compile whl, default is None |" echo -e "| --with_cv: (OFF|ON); controls whether to compile cv functions into lib, default is OFF |" echo -e "| --with_log: (OFF|ON); controls whether to print log information, default is ON |" + echo -e "| --with_exception: (OFF|ON); controls whether to throw the exception when error occurs, default is OFF |" echo -e "| |" echo -e "| arguments of striping lib according to input model: |" echo -e "| ./lite/tools/build_linux.sh --with_strip=ON --opt_model_dir=YourOptimizedModelDir |" @@ -280,6 +284,11 @@ function main { shift ;; # ON or OFF, default OFF + --with_exception=*) + WITH_EXCEPTION="${i#*=}" + shift + ;; + # ON or OFF, default OFF --with_strip=*) BUILD_TAILOR="${i#*=}" shift diff --git a/lite/tools/ci_build.sh b/lite/tools/ci_build.sh index 05494931b9025d5fa64b9069325445b363201ba0..713eaa808ffb3c96a048b66303062e08a357edb3 100755 --- a/lite/tools/ci_build.sh +++ b/lite/tools/ci_build.sh @@ -415,7 +415,7 @@ function test_arm_android { echo "test name: ${test_name}" adb_work_dir="/data/local/tmp" - skip_list=("test_model_parser" "test_mobilenetv1" "test_mobilenetv2" "test_resnet50" "test_inceptionv4" "test_light_api" "test_apis" "test_paddle_api" "test_cxx_api" "test_gen_code" "test_mobilenetv1_int8" "test_subgraph_pass" "test_grid_sampler_image_opencl" "test_lrn_image_opencl" "test_pad2d_image_opencl") + skip_list=("test_model_parser" "test_mobilenetv1" "test_mobilenetv2" "test_resnet50" "test_inceptionv4" "test_light_api" "test_apis" "test_paddle_api" "test_cxx_api" "test_gen_code" "test_mobilenetv1_int8" "test_subgraph_pass" "test_grid_sampler_image_opencl" "test_lrn_image_opencl" "test_pad2d_image_opencl" "test_transformer_with_mask_fp32_arm") for skip_name in ${skip_list[@]} ; do [[ $skip_name =~ (^|[[:space:]])$test_name($|[[:space:]]) ]] && echo "skip $test_name" && return done @@ -1199,6 +1199,7 @@ function main { build_test_arm_subtask_model test_mobilenetv2 mobilenet_v2_relu build_test_arm_subtask_model test_resnet50 resnet50 build_test_arm_subtask_model test_inceptionv4 inception_v4_simple + build_test_arm_subtask_model test_transformer_with_mask_fp32_arm transformer_with_mask_fp32 shift ;; build_test_arm_subtask_armlinux) diff --git a/lite/utils/env.h b/lite/utils/env.h index f3bb8b58e1b63ed2c0ed05792020d11ea307690c..1d26148cea1ed499c8d5ca408ae9235788be6e91 100644 --- a/lite/utils/env.h +++ b/lite/utils/env.h @@ -15,14 +15,23 @@ #pragma once #include #include - #include #include +// Specify the path of configuration file for the subgraph segmentation, an +// example is shown as below: +// op_type:in_var_name_0,in_var_name1:out_var_name_0,out_var_name1 +// op_type::out_var_name_0 +// op_type:in_var_name_0 +// op_type #define SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE \ "SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE" -#define SUBGRAPH_DISABLE_ONLINE_MODE "SUBGRAPH_DISABLE_ONLINE_MODE" +// The original weight/local/unused variables in the subblock of the subgraph op +// will be saved only if 'SUBGRAPH_ONLINE_MODE' is set to true(default) during +// the analysis phase, it ensure the ops in the subblock can be converted to the +// target device model online during the execution phase. +#define SUBGRAPH_ONLINE_MODE "SUBGRAPH_ONLINE_MODE" namespace paddle { namespace lite { diff --git a/lite/utils/logging.h b/lite/utils/logging.h index f292f220c006135af664ea34acc03525a5c112ab..c7fa8d4cf113abebb29c4ebe972e243a39573cf0 100644 --- a/lite/utils/logging.h +++ b/lite/utils/logging.h @@ -57,7 +57,7 @@ static int gettimeofday(struct timeval* tp, void* tzp) { #include "lite/utils/replace_stl/stream.h" #include "lite/utils/string.h" -#ifdef LITE_WITH_ANDROID +#if defined(LITE_WITH_LOG) && defined(LITE_WITH_ANDROID) #include // Android log macors #define ANDROID_LOG_TAG "Paddle-Lite" @@ -143,8 +143,10 @@ class LogMessage { ANDROID_LOG_I(log_stream_.str().c_str()); } else if (level_ == "W") { ANDROID_LOG_W(log_stream_.str().c_str()); + } else if (level_ == "F") { + ANDROID_LOG_F(log_stream_.str().c_str()); } else { - fprintf(stderr, "Unsupported log level: %s", level_.c_str()); + fprintf(stderr, "Unsupported log level: %s\n", level_.c_str()); assert(false); } #endif @@ -170,17 +172,25 @@ class LogMessageFatal : public LogMessage { const char* level = "F") : LogMessage(file, func, lineno, level) {} - ~LogMessageFatal() { + ~LogMessageFatal() +#ifdef LITE_WITH_EXCEPTION + noexcept(false) +#endif + { log_stream_ << '\n'; #ifdef LITE_WITH_ANDROID ANDROID_LOG_F(log_stream_.str().c_str()); #endif fprintf(stderr, "%s", log_stream_.str().c_str()); +#ifdef LITE_WITH_EXCEPTION + throw std::exception(); +#else #ifndef LITE_ON_TINY_PUBLISH abort(); #else assert(false); +#endif #endif } }; @@ -237,7 +247,11 @@ class Voidify { class VoidifyFatal : public Voidify { public: +#ifdef LITE_WITH_EXCEPTION + ~VoidifyFatal() noexcept(false) { throw std::exception(); } +#else ~VoidifyFatal() { assert(false); } +#endif }; #endif