提交 47ebc8c4 编写于 作者: J jiweibo

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle-Lite into stream_manage

...@@ -98,6 +98,7 @@ lite_option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" OF ...@@ -98,6 +98,7 @@ lite_option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" OF
lite_option(LITE_WITH_PROFILE "Enable profile mode in lite framework" OFF) lite_option(LITE_WITH_PROFILE "Enable profile mode in lite framework" OFF)
lite_option(LITE_WITH_PRECISION_PROFILE "Enable precision profile in profile mode ON in lite" OFF) lite_option(LITE_WITH_PRECISION_PROFILE "Enable precision profile in profile mode ON in lite" OFF)
lite_option(LITE_WITH_LOG "Enable log printing or not." ON) lite_option(LITE_WITH_LOG "Enable log printing or not." ON)
lite_option(LITE_WITH_EXCEPTION "Enable throwing the exception when error occurs in lite" OFF)
lite_option(LITE_WITH_NVTX "Enable nvtx or not, please enable LITE_WITH_CUDA first." OFF) lite_option(LITE_WITH_NVTX "Enable nvtx or not, please enable LITE_WITH_CUDA first." OFF)
lite_option(LITE_ON_TINY_PUBLISH "Publish tiny predictor lib." OFF) lite_option(LITE_ON_TINY_PUBLISH "Publish tiny predictor lib." OFF)
lite_option(LITE_ON_MODEL_OPTIMIZE_TOOL "Build the model optimize tool" OFF) lite_option(LITE_ON_MODEL_OPTIMIZE_TOOL "Build the model optimize tool" OFF)
......
...@@ -190,6 +190,10 @@ if (LITE_WITH_LOG) ...@@ -190,6 +190,10 @@ if (LITE_WITH_LOG)
add_definitions("-DLITE_WITH_LOG") add_definitions("-DLITE_WITH_LOG")
endif() endif()
if (LITE_WITH_EXCEPTION)
add_definitions("-DLITE_WITH_EXCEPTION")
endif()
if (LITE_ON_TINY_PUBLISH) if (LITE_ON_TINY_PUBLISH)
add_definitions("-DLITE_ON_TINY_PUBLISH") add_definitions("-DLITE_ON_TINY_PUBLISH")
endif() endif()
......
...@@ -80,6 +80,17 @@ if (ARM_TARGET_LANG STREQUAL "clang") ...@@ -80,6 +80,17 @@ if (ARM_TARGET_LANG STREQUAL "clang")
elseif(ARM_TARGET_ARCH_ABI STREQUAL "armv7") elseif(ARM_TARGET_ARCH_ABI STREQUAL "armv7")
set(triple arm-v7a-linux-android) set(triple arm-v7a-linux-android)
set(LITE_WITH_OPENMP OFF CACHE STRING "Due to libomp's bug(For ARM64, it has been fixed by https://reviews.llvm.org/D19879, but still exists on ARM32), disable OpenMP on armv7 when cross-compiling using Clang" FORCE) set(LITE_WITH_OPENMP OFF CACHE STRING "Due to libomp's bug(For ARM64, it has been fixed by https://reviews.llvm.org/D19879, but still exists on ARM32), disable OpenMP on armv7 when cross-compiling using Clang" FORCE)
if(ANDROID_STL_TYPE MATCHES "^c\\+\\+_")
# Use CMAKE_CXX_STANDARD_LIBRARIES_INIT to ensure libunwind and libc++ is linked in the right order
set(CMAKE_CXX_STANDARD_LIBRARIES_INIT "${CMAKE_CXX_STANDARD_LIBRARIES_INIT} ${ANDROID_NDK}/sources/cxx-stl/llvm-libc++/libs/${ANDROID_ARCH_ABI}/libunwind.a")
if(ANDROID_STL_TYPE STREQUAL "c++_shared")
set(CMAKE_CXX_STANDARD_LIBRARIES_INIT "${CMAKE_CXX_STANDARD_LIBRARIES_INIT} ${ANDROID_NDK}/sources/cxx-stl/llvm-libc++/libs/${ANDROID_ARCH_ABI}/libc++_shared.so")
elseif(ANDROID_STL_TYPE STREQUAL "c++_static")
set(CMAKE_CXX_STANDARD_LIBRARIES_INIT "${CMAKE_CXX_STANDARD_LIBRARIES_INIT} ${ANDROID_NDK}/sources/cxx-stl/llvm-libc++/libs/${ANDROID_ARCH_ABI}/libc++_static.a")
else()
message(FATAL_ERROR "Invalid Android STL TYPE: ${ANDROID_STL_TYPE}.")
endif()
endif()
else() else()
message(FATAL_ERROR "Clang do not support this ${ARM_TARGET_ARCH_ABI}, use armv8 or armv7") message(FATAL_ERROR "Clang do not support this ${ARM_TARGET_ARCH_ABI}, use armv8 or armv7")
endif() endif()
......
...@@ -23,6 +23,21 @@ if(ANDROID) ...@@ -23,6 +23,21 @@ if(ANDROID)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -llog -fPIC") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -llog -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -llog -fPIC") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -llog -fPIC")
# Don't re-export libgcc symbols
set(REMOVE_ATOMIC_GCC_SYMBOLS "-Wl,--exclude-libs,libatomic.a -Wl,--exclude-libs,libgcc.a")
set(CMAKE_SHARED_LINKER_FLAGS "${REMOVE_ATOMIC_GCC_SYMBOLS} ${CMAKE_SHARED_LINKER_FLAGS}")
set(CMAKE_MODULE_LINKER_FLAGS "${REMOVE_ATOMIC_GCC_SYMBOLS} ${CMAKE_MODULE_LINKER_FLAGS}")
set(CMAKE_EXE_LINKER_FLAGS "${REMOVE_ATOMIC_GCC_SYMBOLS} ${CMAKE_EXE_LINKER_FLAGS}")
# Only the libunwind.a from clang(with libc++) provide C++ exception handling support for 32-bit ARM
# Refer to https://android.googlesource.com/platform/ndk/+/master/docs/BuildSystemMaintainers.md#Unwinding
if (ARM_TARGET_LANG STREQUAL "clang" AND ARM_TARGET_ARCH_ABI STREQUAL "armv7" AND ANDROID_STL_TYPE MATCHES "^c\\+\\+_")
set(REMOVE_UNWIND_SYMBOLS "-Wl,--exclude-libs,libunwind.a")
set(CMAKE_SHARED_LINKER_FLAGS "${REMOVE_UNWIND_SYMBOLS} ${CMAKE_SHARED_LINKER_FLAGS}")
set(CMAKE_MODULE_LINKER_FLAGS "${REMOVE_UNWIND_SYMBOLS} ${CMAKE_MODULE_LINKER_FLAGS}")
set(CMAKE_EXE_LINKER_FLAGS "${REMOVE_UNWIND_SYMBOLS} ${CMAKE_EXE_LINKER_FLAGS}")
endif()
endif() endif()
if(ARMLINUX) if(ARMLINUX)
...@@ -59,14 +74,13 @@ function(check_linker_flag) ...@@ -59,14 +74,13 @@ function(check_linker_flag)
endfunction() endfunction()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
if((LITE_WITH_OPENCL AND (ARM_TARGET_LANG STREQUAL "clang")) OR LITE_WITH_PYTHON OR LITE_WITH_EXCEPTION OR (NOT LITE_ON_TINY_PUBLISH))
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions -fasynchronous-unwind-tables -funwind-tables")
else ()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions -fno-asynchronous-unwind-tables -fno-unwind-tables")
endif()
if (LITE_ON_TINY_PUBLISH) if (LITE_ON_TINY_PUBLISH)
if((NOT LITE_WITH_PYTHON)) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ffast-math -Ofast -Os -fomit-frame-pointer")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions")
endif()
if(LITE_WITH_OPENCL AND (ARM_TARGET_LANG STREQUAL "clang"))
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions")
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ffast-math -Ofast -Os -fomit-frame-pointer -fno-asynchronous-unwind-tables -fno-unwind-tables")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden -fvisibility-inlines-hidden -ffunction-sections") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden -fvisibility-inlines-hidden -ffunction-sections")
check_linker_flag(-Wl,--gc-sections) check_linker_flag(-Wl,--gc-sections)
endif() endif()
......
...@@ -54,6 +54,11 @@ find_library(NPU_DDK_IR_BUILD_FILE NAMES hiai_ir_build ...@@ -54,6 +54,11 @@ find_library(NPU_DDK_IR_BUILD_FILE NAMES hiai_ir_build
PATHS ${NPU_DDK_ROOT}/${NPU_SUB_LIB_PATH} PATHS ${NPU_DDK_ROOT}/${NPU_SUB_LIB_PATH}
NO_DEFAULT_PATH) NO_DEFAULT_PATH)
# Added in HiAI DDK 320 or later version
find_library(NPU_DDK_HCL_FILE NAMES hcl
PATHS ${NPU_DDK_ROOT}/${NPU_SUB_LIB_PATH}
NO_DEFAULT_PATH)
if(NOT NPU_DDK_HIAI_FILE) if(NOT NPU_DDK_HIAI_FILE)
message(FATAL_ERROR "Can not find NPU_DDK_HIAI_FILE in ${NPU_DDK_ROOT}") message(FATAL_ERROR "Can not find NPU_DDK_HIAI_FILE in ${NPU_DDK_ROOT}")
else() else()
...@@ -78,5 +83,13 @@ else() ...@@ -78,5 +83,13 @@ else()
set_property(TARGET npu_ddk_ir_build PROPERTY IMPORTED_LOCATION ${NPU_DDK_IR_BUILD_FILE}) set_property(TARGET npu_ddk_ir_build PROPERTY IMPORTED_LOCATION ${NPU_DDK_IR_BUILD_FILE})
endif() endif()
set(npu_runtime_libs npu_ddk_hiai CACHE INTERNAL "npu ddk runtime libs") if(NOT NPU_DDK_HCL_FILE)
# message(FATAL_ERROR "Can not find NPU_DDK_HCL_FILE in ${NPU_DDK_ROOT}")
else()
message(STATUS "Found NPU_DDK HCL Library: ${NPU_DDK_HCL_FILE}")
add_library(npu_ddk_hcl SHARED IMPORTED GLOBAL)
set_property(TARGET npu_ddk_hcl PROPERTY IMPORTED_LOCATION ${NPU_DDK_HCL_FILE})
endif()
set(npu_runtime_libs npu_ddk_hiai npu_ddk_hcl CACHE INTERNAL "npu ddk runtime libs")
set(npu_builder_libs npu_ddk_ir npu_ddk_ir_build CACHE INTERNAL "npu ddk builder libs") set(npu_builder_libs npu_ddk_ir npu_ddk_ir_build CACHE INTERNAL "npu ddk builder libs")
...@@ -45,6 +45,7 @@ if (WITH_TESTING) ...@@ -45,6 +45,7 @@ if (WITH_TESTING)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4_simple.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4_simple.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "MobileNetV1_quant.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "MobileNetV1_quant.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "transformer_with_mask_fp32.tar.gz")
endif() endif()
if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "GoogleNet_inference.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "GoogleNet_inference.tar.gz")
......
...@@ -37,8 +37,7 @@ void Predictor::SaveModel(const std::string &dir, ...@@ -37,8 +37,7 @@ void Predictor::SaveModel(const std::string &dir,
if (!program_) { if (!program_) {
GenRuntimeProgram(); GenRuntimeProgram();
} }
program_->SaveOpInfosToProgram(program_desc_.get()); program_->SaveToProgram(program_desc_);
program_->UpdateVarsOfProgram(program_desc_.get());
switch (model_type) { switch (model_type) {
case lite_api::LiteModelType::kProtobuf: case lite_api::LiteModelType::kProtobuf:
SaveModelPb(dir, *program_->exec_scope(), *program_desc_.get(), true); SaveModelPb(dir, *program_->exec_scope(), *program_desc_.get(), true);
...@@ -58,17 +57,21 @@ void Predictor::SaveModel(const std::string &dir, ...@@ -58,17 +57,21 @@ void Predictor::SaveModel(const std::string &dir,
void Predictor::SaveOpKernelInfo(const std::string &model_dir) { void Predictor::SaveOpKernelInfo(const std::string &model_dir) {
std::set<std::string> ops_info; std::set<std::string> ops_info;
std::set<std::string> kernels_info; std::set<std::string> kernels_info;
const auto &instructions_ = program_->instructions(); auto block_size = program_->block_size();
for (auto &node : instructions_) { for (size_t block_idx = 0; block_idx < block_size; ++block_idx) {
// parse op type infomation const auto &insts = program_->instructions(block_idx);
auto op = node.op()->op_info(); for (auto &inst : insts) {
ops_info.insert(op->Type()); // parse op type infomation
// parse kernel type information auto op = inst.op()->op_info();
std::string kernel_type_str = ops_info.insert(op->Type());
node.kernel()->op_type() + "," + TargetRepr(node.kernel()->target()) + // parse kernel type information
"," + PrecisionRepr(node.kernel()->precision()) + "," + std::string kernel_type_str =
DataLayoutRepr(node.kernel()->layout()) + "," + node.kernel()->alias(); inst.kernel()->op_type() + "," + TargetRepr(inst.kernel()->target()) +
kernels_info.insert(kernel_type_str); "," + PrecisionRepr(inst.kernel()->precision()) + "," +
DataLayoutRepr(inst.kernel()->layout()) + "," +
inst.kernel()->alias();
kernels_info.insert(kernel_type_str);
}
} }
// get souce_file name from op type and kernel type // get souce_file name from op type and kernel type
...@@ -170,9 +173,9 @@ void Predictor::PrepareFeedFetch() { ...@@ -170,9 +173,9 @@ void Predictor::PrepareFeedFetch() {
std::vector<const cpp::OpDesc *> feeds; std::vector<const cpp::OpDesc *> feeds;
std::vector<const cpp::OpDesc *> fetchs; std::vector<const cpp::OpDesc *> fetchs;
const auto &insts = program_->instructions(); const auto &insts = program_->instructions(kRootBlockIdx);
for (size_t i = 0; i < program_->num_instructions(); i++) { for (auto &inst : insts) {
const auto &op = insts[i].op()->op_info(); const auto &op = inst.op()->op_info();
if (op->Type() == "feed") { if (op->Type() == "feed") {
feeds.push_back(op); feeds.push_back(op);
} else if (op->Type() == "fetch") { } else if (op->Type() == "fetch") {
...@@ -255,7 +258,6 @@ void Predictor::Build(const lite_api::CxxConfig &config, ...@@ -255,7 +258,6 @@ void Predictor::Build(const lite_api::CxxConfig &config,
} else { } else {
LOG(INFO) << "Load model from file."; LOG(INFO) << "Load model from file.";
} }
Build(model_path, Build(model_path,
model_file, model_file,
param_file, param_file,
...@@ -296,10 +298,10 @@ void Predictor::Build(const std::string &model_path, ...@@ -296,10 +298,10 @@ void Predictor::Build(const std::string &model_path,
Build(program_desc_, valid_places, passes); Build(program_desc_, valid_places, passes);
} }
void Predictor::Build(const std::shared_ptr<cpp::ProgramDesc> &desc, void Predictor::Build(const std::shared_ptr<cpp::ProgramDesc> &program_desc,
const std::vector<Place> &valid_places, const std::vector<Place> &valid_places,
const std::vector<std::string> &passes) { const std::vector<std::string> &passes) {
program_desc_ = desc; program_desc_ = program_desc;
// `inner_places` is used to optimize passes // `inner_places` is used to optimize passes
std::vector<Place> inner_places = valid_places; std::vector<Place> inner_places = valid_places;
for (auto &valid_place : valid_places) { for (auto &valid_place : valid_places) {
...@@ -336,7 +338,7 @@ void Predictor::Build(const std::shared_ptr<cpp::ProgramDesc> &desc, ...@@ -336,7 +338,7 @@ void Predictor::Build(const std::shared_ptr<cpp::ProgramDesc> &desc,
Place{TARGET(kARM), PRECISION(kInt8)}); Place{TARGET(kARM), PRECISION(kInt8)});
} }
Program program(*desc.get(), scope_, inner_places); Program program(program_desc_, scope_, inner_places);
valid_places_ = inner_places; valid_places_ = inner_places;
core::KernelPickFactor factor; core::KernelPickFactor factor;
......
...@@ -58,13 +58,12 @@ class LITE_API Predictor { ...@@ -58,13 +58,12 @@ class LITE_API Predictor {
// Create a predictor with the weight variable scope set. // Create a predictor with the weight variable scope set.
explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope) explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope)
: scope_(root_scope) {} : scope_(root_scope) {}
Predictor(const std::shared_ptr<cpp::ProgramDesc>& desc, Predictor(const std::shared_ptr<cpp::ProgramDesc>& program_desc,
const std::shared_ptr<Scope>& root, const std::shared_ptr<Scope>& root_scope,
const std::vector<Place>& valid_places, const std::vector<Place>& valid_places,
const std::vector<std::string>& var_names = {}) const std::vector<std::string>& vars_to_clone = {})
: program_desc_(desc), scope_(root) { : program_desc_(program_desc), scope_(root_scope) {
Program program(*desc.get(), scope_, valid_places, var_names); Program program(program_desc_, scope_, valid_places, vars_to_clone);
// TODO(wilber): rethink a new way to associate config and passes.
optimizer_ = Optimizer(std::move(program), valid_places); optimizer_ = Optimizer(std::move(program), valid_places);
exec_scope_ = optimizer_.exec_scope(); exec_scope_ = optimizer_.exec_scope();
valid_places_ = valid_places; valid_places_ = valid_places;
...@@ -86,30 +85,28 @@ class LITE_API Predictor { ...@@ -86,30 +85,28 @@ class LITE_API Predictor {
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf, lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf,
bool memory_from_memory = false); bool memory_from_memory = false);
void Build(const std::shared_ptr<cpp::ProgramDesc>& desc, void Build(const std::shared_ptr<cpp::ProgramDesc>& program_desc,
const std::vector<Place>& valid_places, const std::vector<Place>& valid_places,
const std::vector<std::string>& passes = {}); const std::vector<std::string>& passes = {});
std::shared_ptr<Predictor> Clone() const { std::shared_ptr<Predictor> Clone() const {
auto predictor = return std::make_shared<Predictor>(program_desc_, scope_, valid_places_);
std::make_shared<Predictor>(program_desc_, scope_, valid_places_);
return predictor;
} }
std::shared_ptr<Predictor> Clone( std::shared_ptr<Predictor> Clone(
const std::vector<std::string>& var_names) const { const std::vector<std::string>& vars_to_clone) const {
CHECK(program_desc_) << "Both program and scope of current predicotr " CHECK(program_desc_) << "Both program and scope of current predicotr "
"should be not be nullptr in Clone mode."; "should be not be nullptr in Clone mode.";
CHECK(scope_) << "Both program and scope of current predicotr should be " CHECK(scope_) << "Both program and scope of current predicotr should be "
"not be nullptr in Clone mode."; "not be nullptr in Clone mode.";
auto predictor = std::make_shared<Predictor>( auto predictor = std::make_shared<Predictor>(
program_desc_, scope_, valid_places_, var_names); program_desc_, scope_, valid_places_, vars_to_clone);
for (auto i : var_names) { for (auto var_name : vars_to_clone) {
predictor->exec_scope_->LocalVar(i); predictor->exec_scope_->LocalVar(var_name);
auto* tensor = predictor->scope_->Var(i)->GetMutable<lite::Tensor>(); auto* tensor = predictor->scope_->Var(var_name)->GetMutable<Tensor>();
auto* sub_tensor = auto* sub_tensor =
predictor->exec_scope_->Var(i)->GetMutable<lite::Tensor>(); predictor->exec_scope_->Var(var_name)->GetMutable<Tensor>();
sub_tensor->CopyDataFrom(*tensor); sub_tensor->CopyDataFrom(*tensor);
} }
return predictor; return predictor;
...@@ -147,6 +144,7 @@ class LITE_API Predictor { ...@@ -147,6 +144,7 @@ class LITE_API Predictor {
// get a const tensor according to its name // get a const tensor according to its name
const lite::Tensor* GetTensor(const std::string& name) const; const lite::Tensor* GetTensor(const std::string& name) const;
const RuntimeProgram& runtime_program() const; const RuntimeProgram& runtime_program() const;
Scope* scope() { return scope_.get(); }
// This method is disabled in mobile, for unnecessary dependencies required. // This method is disabled in mobile, for unnecessary dependencies required.
void SaveModel( void SaveModel(
......
...@@ -75,8 +75,10 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { ...@@ -75,8 +75,10 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
mode_ = config.power_mode(); mode_ = config.power_mode();
threads_ = config.threads(); threads_ = config.threads();
#ifdef LITE_WITH_NPU #ifdef LITE_WITH_NPU
// Store the model-level configuration into scope for kernels, and use
// exe_scope to store the execution-level configuration
Context<TargetType::kNPU>::SetSubgraphModelCacheDir( Context<TargetType::kNPU>::SetSubgraphModelCacheDir(
config.subgraph_model_cache_dir()); raw_predictor_->scope(), config.subgraph_model_cache_dir());
#endif #endif
#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \ #if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \
!(defined LITE_ON_MODEL_OPTIMIZE_TOOL) !(defined LITE_ON_MODEL_OPTIMIZE_TOOL)
......
...@@ -22,16 +22,16 @@ namespace lite { ...@@ -22,16 +22,16 @@ namespace lite {
void LightPredictor::Build(const std::string& lite_model_file, void LightPredictor::Build(const std::string& lite_model_file,
bool model_from_memory) { bool model_from_memory) {
if (model_from_memory) { if (model_from_memory) {
LoadModelNaiveFromMemory(lite_model_file, scope_.get(), &cpp_program_desc_); LoadModelNaiveFromMemory(
lite_model_file, scope_.get(), program_desc_.get());
} else { } else {
LoadModelNaiveFromFile(lite_model_file, scope_.get(), &cpp_program_desc_); LoadModelNaiveFromFile(lite_model_file, scope_.get(), program_desc_.get());
} }
// For weight quantization of post training, load the int8/16 weights // For weight quantization of post training, load the int8/16 weights
// for optimized model, and dequant it to fp32. // for optimized model, and dequant it to fp32.
DequantizeWeight(); DequantizeWeight();
BuildRuntimeProgram(program_desc_);
BuildRuntimeProgram(cpp_program_desc_);
PrepareFeedFetch(); PrepareFeedFetch();
} }
...@@ -43,15 +43,15 @@ void LightPredictor::Build(const std::string& model_dir, ...@@ -43,15 +43,15 @@ void LightPredictor::Build(const std::string& model_dir,
switch (model_type) { switch (model_type) {
#ifndef LITE_ON_TINY_PUBLISH #ifndef LITE_ON_TINY_PUBLISH
case lite_api::LiteModelType::kProtobuf: case lite_api::LiteModelType::kProtobuf:
LoadModelPb(model_dir, "", "", scope_.get(), &cpp_program_desc_); LoadModelPb(model_dir, "", "", scope_.get(), program_desc_.get());
break; break;
#endif #endif
case lite_api::LiteModelType::kNaiveBuffer: { case lite_api::LiteModelType::kNaiveBuffer: {
if (model_from_memory) { if (model_from_memory) {
LoadModelNaiveFromMemory( LoadModelNaiveFromMemory(
model_buffer, param_buffer, scope_.get(), &cpp_program_desc_); model_buffer, param_buffer, scope_.get(), program_desc_.get());
} else { } else {
LoadModelNaive(model_dir, scope_.get(), &cpp_program_desc_); LoadModelNaive(model_dir, scope_.get(), program_desc_.get());
} }
break; break;
} }
...@@ -60,7 +60,7 @@ void LightPredictor::Build(const std::string& model_dir, ...@@ -60,7 +60,7 @@ void LightPredictor::Build(const std::string& model_dir,
} }
DequantizeWeight(); DequantizeWeight();
BuildRuntimeProgram(cpp_program_desc_); BuildRuntimeProgram(program_desc_);
PrepareFeedFetch(); PrepareFeedFetch();
} }
...@@ -109,15 +109,17 @@ std::vector<std::string> LightPredictor::GetOutputNames() { ...@@ -109,15 +109,17 @@ std::vector<std::string> LightPredictor::GetOutputNames() {
} }
// append the names of inputs and outputs into input_names_ and output_names_ // append the names of inputs and outputs into input_names_ and output_names_
void LightPredictor::PrepareFeedFetch() { void LightPredictor::PrepareFeedFetch() {
auto current_block = cpp_program_desc_.GetBlock<cpp::BlockDesc>(0); std::vector<const cpp::OpDesc*> feeds;
std::vector<cpp::OpDesc*> feeds; std::vector<const cpp::OpDesc*> fetchs;
std::vector<cpp::OpDesc*> fetchs; std::shared_ptr<const cpp::ProgramDesc> program_desc = program_desc_;
for (size_t i = 0; i < current_block->OpsSize(); i++) { auto main_block = program_desc->GetBlock<cpp::BlockDesc>(kRootBlockIdx);
auto op = current_block->GetOp<cpp::OpDesc>(i); auto op_size = main_block->OpsSize();
if (op->Type() == "feed") { for (size_t op_idx = 0; op_idx < op_size; ++op_idx) {
feeds.push_back(op); auto op_desc = main_block->GetOp<cpp::OpDesc>(op_idx);
} else if (op->Type() == "fetch") { if (op_desc->Type() == "feed") {
fetchs.push_back(op); feeds.push_back(op_desc);
} else if (op_desc->Type() == "fetch") {
fetchs.push_back(op_desc);
} }
} }
input_names_.resize(feeds.size()); input_names_.resize(feeds.size());
...@@ -132,54 +134,35 @@ void LightPredictor::PrepareFeedFetch() { ...@@ -132,54 +134,35 @@ void LightPredictor::PrepareFeedFetch() {
} }
} }
void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) { void LightPredictor::BuildRuntimeProgram(
std::vector<Instruction> insts; const std::shared_ptr<const cpp::ProgramDesc>& program_desc) {
// 1. Create op first auto* exe_scope = &scope_->NewScope();
Program program(prog, scope_, {}); // Prepare workspace
scope_->Var("feed")->GetMutable<std::vector<lite::Tensor>>();
// 2. Create Instructs scope_->Var("fetch")->GetMutable<std::vector<lite::Tensor>>();
#ifdef LITE_WITH_OPENCL CHECK(program_desc);
using OpenCLContext = Context<TargetType::kOpenCL>; auto block_size = program_desc->BlocksSize();
std::unique_ptr<KernelContext> local_ctx(new KernelContext()); CHECK(block_size);
local_ctx->As<OpenCLContext>().InitOnce(); for (size_t block_idx = 0; block_idx < block_size; ++block_idx) {
#endif auto block_desc = program_desc->GetBlock<cpp::BlockDesc>(block_idx);
auto var_size = block_desc->VarsSize();
// Create the kernels of the target places, and filter out the specific for (size_t var_idx = 0; var_idx < var_size; ++var_idx) {
// kernel with the target alias. auto var_desc = block_desc->GetVar<cpp::VarDesc>(var_idx);
for (auto& op : program.ops()) { if (!var_desc->Persistable()) {
auto kernel_type = op->op_info()->GetAttr<std::string>(kKernelTypeAttr); exe_scope->Var(var_desc->Name());
std::string op_type, alias; } else {
Place place; if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") continue;
KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place); scope_->Var(var_desc->Name());
auto kernels = op->CreateKernels({place}); }
// filter out a kernel
auto it = std::find_if(
kernels.begin(), kernels.end(), [&](std::unique_ptr<KernelBase>& it) {
return it->alias() == alias;
});
CHECK(it != kernels.end());
#ifdef LITE_WITH_OPENCL
if ((*it)->target() == TARGET(kOpenCL)) {
std::unique_ptr<KernelContext> ctx(new KernelContext());
(*local_ctx).As<OpenCLContext>().CopySharedTo(&ctx->As<OpenCLContext>());
(*it)->SetContext(std::move(ctx));
} else {
(*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target()));
} }
#else
(*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target()));
#endif
insts.emplace_back(op, std::move(*it));
} }
program_.reset(new RuntimeProgram(std::move(insts))); // Only extracting the ops and generate the runtime program from the main
// block desc
CHECK(program.exec_scope()); program_.reset(new RuntimeProgram(program_desc, exe_scope, kRootBlockIdx));
program_->set_exec_scope(program.exec_scope());
} }
void LightPredictor::DequantizeWeight() { void LightPredictor::DequantizeWeight() {
std::shared_ptr<const cpp::ProgramDesc> program_desc = program_desc_;
#define PROCESS_CONV2D_DATA() \ #define PROCESS_CONV2D_DATA() \
for (int64_t i = 0; i < ch; ++i) { \ for (int64_t i = 0; i < ch; ++i) { \
for (int64_t j = 0; j < offset; ++j) { \ for (int64_t j = 0; j < offset; ++j) { \
...@@ -205,10 +188,9 @@ void LightPredictor::DequantizeWeight() { ...@@ -205,10 +188,9 @@ void LightPredictor::DequantizeWeight() {
} }
return result; return result;
}; };
Tensor tmp_tensor; Tensor tmp_tensor;
for (size_t i = 0; i < cpp_program_desc_.BlocksSize(); i++) { for (size_t i = 0; i < program_desc->BlocksSize(); i++) {
auto* block = cpp_program_desc_.GetBlock<cpp::BlockDesc>(i); auto* block = program_desc->GetBlock<cpp::BlockDesc>(i);
for (size_t k = 0; k < block->OpsSize(); ++k) { for (size_t k = 0; k < block->OpsSize(); ++k) {
auto* op_desc = block->GetOp<cpp::OpDesc>(k); auto* op_desc = block->GetOp<cpp::OpDesc>(k);
if (is_weight_quantized_op(op_desc)) { if (is_weight_quantized_op(op_desc)) {
......
...@@ -46,6 +46,7 @@ class LITE_API LightPredictor { ...@@ -46,6 +46,7 @@ class LITE_API LightPredictor {
LightPredictor(const std::string& lite_model_file, LightPredictor(const std::string& lite_model_file,
bool model_from_memory = false) { bool model_from_memory = false) {
scope_ = std::make_shared<Scope>(); scope_ = std::make_shared<Scope>();
program_desc_ = std::make_shared<cpp::ProgramDesc>();
Build(lite_model_file, model_from_memory); Build(lite_model_file, model_from_memory);
} }
...@@ -57,6 +58,7 @@ class LITE_API LightPredictor { ...@@ -57,6 +58,7 @@ class LITE_API LightPredictor {
lite_api::LiteModelType model_type = lite_api::LiteModelType model_type =
lite_api::LiteModelType::kNaiveBuffer) { lite_api::LiteModelType::kNaiveBuffer) {
scope_ = std::make_shared<Scope>(); scope_ = std::make_shared<Scope>();
program_desc_ = std::make_shared<cpp::ProgramDesc>();
Build(model_dir, model_buffer, param_buffer, model_type, model_from_memory); Build(model_dir, model_buffer, param_buffer, model_type, model_from_memory);
} }
...@@ -78,6 +80,7 @@ class LITE_API LightPredictor { ...@@ -78,6 +80,7 @@ class LITE_API LightPredictor {
std::vector<std::string> GetInputNames(); std::vector<std::string> GetInputNames();
std::vector<std::string> GetOutputNames(); std::vector<std::string> GetOutputNames();
void PrepareFeedFetch(); void PrepareFeedFetch();
Scope* scope() { return scope_.get(); }
private: private:
void Build(const std::string& lite_model_file, void Build(const std::string& lite_model_file,
...@@ -91,14 +94,15 @@ class LITE_API LightPredictor { ...@@ -91,14 +94,15 @@ class LITE_API LightPredictor {
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf, lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf,
bool model_from_memory = false); bool model_from_memory = false);
void BuildRuntimeProgram(const cpp::ProgramDesc& prog); void BuildRuntimeProgram(
const std::shared_ptr<const cpp::ProgramDesc>& program_desc);
void DequantizeWeight(); void DequantizeWeight();
private: private:
std::shared_ptr<Scope> scope_; std::shared_ptr<Scope> scope_;
std::unique_ptr<RuntimeProgram> program_; std::unique_ptr<RuntimeProgram> program_;
cpp::ProgramDesc cpp_program_desc_; std::shared_ptr<cpp::ProgramDesc> program_desc_;
std::vector<std::string> input_names_; std::vector<std::string> input_names_;
std::vector<std::string> output_names_; std::vector<std::string> output_names_;
}; };
......
...@@ -38,8 +38,10 @@ void LightPredictorImpl::Init(const lite_api::MobileConfig& config) { ...@@ -38,8 +38,10 @@ void LightPredictorImpl::Init(const lite_api::MobileConfig& config) {
threads_ = config.threads(); threads_ = config.threads();
#ifdef LITE_WITH_NPU #ifdef LITE_WITH_NPU
// Store the model-level configuration into scope for kernels, and use
// exe_scope to store the execution-level configuration
Context<TargetType::kNPU>::SetSubgraphModelCacheDir( Context<TargetType::kNPU>::SetSubgraphModelCacheDir(
config.subgraph_model_cache_dir()); raw_predictor_->scope(), config.subgraph_model_cache_dir());
#endif #endif
} }
......
...@@ -28,6 +28,7 @@ USE_MIR_PASS(graph_visualize_pass); ...@@ -28,6 +28,7 @@ USE_MIR_PASS(graph_visualize_pass);
USE_MIR_PASS(remove_tf_redundant_ops_pass); USE_MIR_PASS(remove_tf_redundant_ops_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass); USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(lite_conv_conv_fuse_pass);
USE_MIR_PASS(lite_fc_fuse_pass); USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(lite_shuffle_channel_fuse_pass); USE_MIR_PASS(lite_shuffle_channel_fuse_pass);
USE_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass); USE_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass);
...@@ -53,6 +54,7 @@ USE_MIR_PASS(mlu_postprocess_pass); ...@@ -53,6 +54,7 @@ USE_MIR_PASS(mlu_postprocess_pass);
USE_MIR_PASS(weight_quantization_preprocess_pass); USE_MIR_PASS(weight_quantization_preprocess_pass);
USE_MIR_PASS(apu_subgraph_pass); USE_MIR_PASS(apu_subgraph_pass);
USE_MIR_PASS(quantized_op_attributes_inference_pass); USE_MIR_PASS(quantized_op_attributes_inference_pass);
USE_MIR_PASS(control_flow_op_unused_inputs_and_outputs_eliminate_pass)
USE_MIR_PASS(lite_scale_activation_fuse_pass); USE_MIR_PASS(lite_scale_activation_fuse_pass);
USE_MIR_PASS(__xpu__resnet_fuse_pass); USE_MIR_PASS(__xpu__resnet_fuse_pass);
USE_MIR_PASS(__xpu__resnet_cbam_fuse_pass); USE_MIR_PASS(__xpu__resnet_cbam_fuse_pass);
......
...@@ -234,7 +234,7 @@ void beam_search(const Tensor *pre_ids, ...@@ -234,7 +234,7 @@ void beam_search(const Tensor *pre_ids,
selected_ids->Resize(dims); selected_ids->Resize(dims);
selected_scores->Resize(dims); selected_scores->Resize(dims);
if (parent_idx) { if (parent_idx) {
parent_idx->Resize(dims); parent_idx->Resize({static_cast<int64_t>(num_instances)});
} }
auto *selected_ids_data = selected_ids->mutable_data<int64_t>(); auto *selected_ids_data = selected_ids->mutable_data<int64_t>();
auto *selected_scores_data = selected_scores->mutable_data<float>(); auto *selected_scores_data = selected_scores->mutable_data<float>();
......
...@@ -139,6 +139,71 @@ static bool conv_trans_weights_numc(const dtype* din, ...@@ -139,6 +139,71 @@ static bool conv_trans_weights_numc(const dtype* din,
} }
return true; return true;
} }
// for example: m = 4, n = 4
// din = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9 , 10 ,11], [12, 13, 14, 15]]
// dout = [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]
/*
m = 8 n = 8: 0 1 2 3 4 5 6 7 0 8 16 24 32 40 48 56
16 17 18 19 20 21 22 23 2 10 18 26 34 42 50 58
24 25 26 27 28 29 30 31 3 11 19 27 35 43 51 59
32 33 34 35 36 37 38 39 4 12 20 28 36 44 52 60 ...
}
}
*/
template <typename Dtype>
void local_transpose(const Dtype* din, Dtype* dout, int m, int n) {
// n % 4 == 0 && m % 4 == 0
// n * m ==> n * m data trans
int offset_m = m << 2;
const Dtype* din_ptr = din;
Dtype* dout_ptr = dout;
for (int i = 0; i < n; i += 4) {
Dtype* out_ptr0 = dout_ptr;
Dtype* out_ptr1 = dout_ptr + m;
Dtype* out_ptr2 = out_ptr1 + m;
Dtype* out_ptr3 = out_ptr2 + m;
const Dtype* in_ptr0 = din_ptr;
const Dtype* in_ptr1 = din_ptr + m;
const Dtype* in_ptr2 = in_ptr1 + m;
const Dtype* in_ptr3 = in_ptr2 + m;
for (int j = 0; j < m; j += 4) {
float32x4_t vin0 = vld1q_f32(in_ptr0);
float32x4_t vin1 = vld1q_f32(in_ptr1);
float32x4_t vin2 = vld1q_f32(in_ptr2);
float32x4_t vin3 = vld1q_f32(in_ptr3);
// a00 b00 a02 b02 a01 b01 a03 b03
float32x4x2_t tmp0 = vtrnq_f32(vin0, vin1);
// c00 d00 c02 d02 c01 d01 c03 d03
float32x4x2_t tmp2 = vtrnq_f32(vin2, vin3);
in_ptr0 = in_ptr3 + m;
in_ptr1 = in_ptr3 + 2 * m;
float tmp_val1 = tmp0.val[0][2];
float tmp_val2 = tmp0.val[0][3];
tmp0.val[0][2] = tmp2.val[0][0];
tmp0.val[0][3] = tmp2.val[0][1];
float tmp_val3 = tmp0.val[1][2];
float tmp_val4 = tmp0.val[1][3];
tmp2.val[0][0] = tmp_val1;
tmp2.val[0][1] = tmp_val2;
tmp0.val[1][2] = tmp2.val[1][0];
tmp0.val[1][3] = tmp2.val[1][1];
tmp2.val[1][0] = tmp_val3;
tmp2.val[1][1] = tmp_val4;
in_ptr2 = in_ptr1 + m;
in_ptr3 = in_ptr1 + 2 * m;
vst1q_f32(out_ptr0, tmp0.val[0]);
vst1q_f32(out_ptr1, tmp0.val[1]);
out_ptr0 += 4;
out_ptr1 += 4;
vst1q_f32(out_ptr2, tmp2.val[0]);
vst1q_f32(out_ptr3, tmp2.val[1]);
out_ptr2 += 4;
out_ptr3 += 4;
}
dout_ptr += offset_m;
din_ptr += 4;
}
}
template <typename Dtype> template <typename Dtype>
void transpose(const Dtype* din, Dtype* dout, int m, int n) { void transpose(const Dtype* din, Dtype* dout, int m, int n) {
// nxm == mxn // nxm == mxn
......
...@@ -747,6 +747,16 @@ void elementwise_mul<int>(const int* dinx, ...@@ -747,6 +747,16 @@ void elementwise_mul<int>(const int* dinx,
} }
} }
template <>
void elementwise_mul<int64_t>(const int64_t* dinx,
const int64_t* diny,
int64_t* dout,
int num) {
for (int i = 0; i < num; i++) {
dout[i] = dinx[i] * diny[i];
}
}
template <> template <>
void elementwise_mul_relu<float>(const float* dinx, void elementwise_mul_relu<float>(const float* dinx,
const float* diny, const float* diny,
...@@ -801,6 +811,17 @@ void elementwise_mul_relu<float>(const float* dinx, ...@@ -801,6 +811,17 @@ void elementwise_mul_relu<float>(const float* dinx,
} }
} }
template <>
void elementwise_mul_relu<int64_t>(const int64_t* dinx,
const int64_t* diny,
int64_t* dout,
int num) {
for (int i = 0; i < num; i++) {
int64_t tmp = dinx[i] * diny[i];
dout[i] = tmp > 0 ? tmp : 0;
}
}
template <> template <>
void elementwise_mul_broadcast<float>(const float* dinx, void elementwise_mul_broadcast<float>(const float* dinx,
const float* diny, const float* diny,
...@@ -935,6 +956,29 @@ void elementwise_mul_broadcast<int>(const int* dinx, ...@@ -935,6 +956,29 @@ void elementwise_mul_broadcast<int>(const int* dinx,
} }
} }
template <>
void elementwise_mul_broadcast<int64_t>(const int64_t* dinx,
const int64_t* diny,
int64_t* dout,
int batch,
int channels,
int num) {
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const int64_t* dinx_ptr = dinx + offset;
const int64_t diny_data = diny[j];
int64_t* dout_ptr = dout + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *dinx_ptr * diny_data;
dout_ptr++;
dinx_ptr++;
}
}
}
}
template <> template <>
void elementwise_mul_relu_broadcast<float>(const float* dinx, void elementwise_mul_relu_broadcast<float>(const float* dinx,
const float* diny, const float* diny,
...@@ -1014,6 +1058,30 @@ void elementwise_mul_relu_broadcast<float>(const float* dinx, ...@@ -1014,6 +1058,30 @@ void elementwise_mul_relu_broadcast<float>(const float* dinx,
} }
} }
template <>
void elementwise_mul_relu_broadcast<int64_t>(const int64_t* dinx,
const int64_t* diny,
int64_t* dout,
int batch,
int channels,
int num) {
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const int64_t* dinx_ptr = dinx + offset;
const int64_t diny_data = diny[j];
int64_t* dout_ptr = dout + offset;
for (int k = 0; k < num; ++k) {
int64_t tmp = *dinx_ptr * diny_data;
*dout_ptr = tmp > 0 ? tmp : 0;
dout_ptr++;
dinx_ptr++;
}
}
}
}
template <> template <>
void elementwise_max<float>(const float* dinx, void elementwise_max<float>(const float* dinx,
const float* diny, const float* diny,
......
...@@ -21,7 +21,7 @@ namespace lite { ...@@ -21,7 +21,7 @@ namespace lite {
namespace arm { namespace arm {
namespace math { namespace math {
const int MALLOC_ALIGN = 64; const int MALLOC_ALIGN = 16;
void* fast_malloc(size_t size) { void* fast_malloc(size_t size) {
size_t offset = sizeof(void*) + MALLOC_ALIGN - 1; size_t offset = sizeof(void*) + MALLOC_ALIGN - 1;
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
const int MALLOC_ALIGN = 64; const int MALLOC_ALIGN = 16;
void* TargetWrapper<TARGET(kHost)>::Malloc(size_t size) { void* TargetWrapper<TARGET(kHost)>::Malloc(size_t size) {
size_t offset = sizeof(void*) + MALLOC_ALIGN - 1; size_t offset = sizeof(void*) + MALLOC_ALIGN - 1;
...@@ -30,7 +30,6 @@ void* TargetWrapper<TARGET(kHost)>::Malloc(size_t size) { ...@@ -30,7 +30,6 @@ void* TargetWrapper<TARGET(kHost)>::Malloc(size_t size) {
void* r = reinterpret_cast<void*>(reinterpret_cast<size_t>(p + offset) & void* r = reinterpret_cast<void*>(reinterpret_cast<size_t>(p + offset) &
(~(MALLOC_ALIGN - 1))); (~(MALLOC_ALIGN - 1)));
static_cast<void**>(r)[-1] = p; static_cast<void**>(r)[-1] = p;
memset(r, 0, size);
return r; return r;
} }
void TargetWrapper<TARGET(kHost)>::Free(void* ptr) { void TargetWrapper<TARGET(kHost)>::Free(void* ptr) {
......
...@@ -33,7 +33,7 @@ std::shared_ptr<hiai::AiModelMngerClient> Device::Load( ...@@ -33,7 +33,7 @@ std::shared_ptr<hiai::AiModelMngerClient> Device::Load(
// Check HiAI DDK version // Check HiAI DDK version
const char* ddk_version = model_client->GetVersion(); const char* ddk_version = model_client->GetVersion();
if (ddk_version) { if (ddk_version) {
LOG(INFO) << "[NPU] HiAI DDK version: " << ddk_version; VLOG(3) << "[NPU] HiAI DDK version: " << ddk_version;
} else { } else {
LOG(WARNING) << "[NPU] Unable to get HiAI DDK version!"; LOG(WARNING) << "[NPU] Unable to get HiAI DDK version!";
} }
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include "lite/backends/xpu/xpu_header_sitter.h" #include "lite/backends/xpu/target_wrapper.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -82,8 +82,8 @@ void DumpXPUMem(const T* ptr, ...@@ -82,8 +82,8 @@ void DumpXPUMem(const T* ptr,
size_t item_per_line = 30) { size_t item_per_line = 30) {
size_t after_stride_len = (len + stride - 1) / stride; size_t after_stride_len = (len + stride - 1) / stride;
std::unique_ptr<T[]> cpu_mem(new T[len]); std::unique_ptr<T[]> cpu_mem(new T[len]);
xpu_memcpy( XPU_CALL(xpu_memcpy(
cpu_mem.get(), ptr, len * sizeof(T), XPUMemcpyKind::XPU_DEVICE_TO_HOST); cpu_mem.get(), ptr, len * sizeof(T), XPUMemcpyKind::XPU_DEVICE_TO_HOST));
std::unique_ptr<T[]> after_stride(new T[after_stride_len]); std::unique_ptr<T[]> after_stride(new T[after_stride_len]);
for (size_t i = 0; i < after_stride_len; ++i) { for (size_t i = 0; i < after_stride_len; ++i) {
after_stride[i] = cpu_mem[i * stride]; after_stride[i] = cpu_mem[i * stride];
......
...@@ -19,11 +19,11 @@ namespace lite { ...@@ -19,11 +19,11 @@ namespace lite {
void* TargetWrapperXPU::Malloc(size_t size) { void* TargetWrapperXPU::Malloc(size_t size) {
void* ptr{nullptr}; void* ptr{nullptr};
xpu_malloc(&ptr, size); XPU_CALL(xpu_malloc(&ptr, size));
return ptr; return ptr;
} }
void TargetWrapperXPU::Free(void* ptr) { xpu_free(ptr); } void TargetWrapperXPU::Free(void* ptr) { XPU_CALL(xpu_free(ptr)); }
void TargetWrapperXPU::MemcpySync(void* dst, void TargetWrapperXPU::MemcpySync(void* dst,
const void* src, const void* src,
...@@ -31,10 +31,10 @@ void TargetWrapperXPU::MemcpySync(void* dst, ...@@ -31,10 +31,10 @@ void TargetWrapperXPU::MemcpySync(void* dst,
IoDirection dir) { IoDirection dir) {
switch (dir) { switch (dir) {
case IoDirection::HtoD: case IoDirection::HtoD:
xpu_memcpy(dst, src, size, XPU_HOST_TO_DEVICE); XPU_CALL(xpu_memcpy(dst, src, size, XPU_HOST_TO_DEVICE));
break; break;
case IoDirection::DtoH: case IoDirection::DtoH:
xpu_memcpy(dst, src, size, XPU_DEVICE_TO_HOST); XPU_CALL(xpu_memcpy(dst, src, size, XPU_DEVICE_TO_HOST));
break; break;
default: default:
LOG(FATAL) << "Unsupported IoDirection " << static_cast<int>(dir); LOG(FATAL) << "Unsupported IoDirection " << static_cast<int>(dir);
...@@ -49,7 +49,7 @@ XPUScratchPadGuard TargetWrapperXPU::MallocScratchPad(size_t size, ...@@ -49,7 +49,7 @@ XPUScratchPadGuard TargetWrapperXPU::MallocScratchPad(size_t size,
} else { } else {
ptr = TargetWrapperXPU::Malloc(size); ptr = TargetWrapperXPU::Malloc(size);
} }
CHECK(ptr != nullptr); CHECK(ptr != nullptr) << "size = " << size << ", use_l3 = " << use_l3;
return XPUScratchPadGuard(new XPUScratchPad(ptr, use_l3)); return XPUScratchPadGuard(new XPUScratchPad(ptr, use_l3));
} }
......
...@@ -16,11 +16,23 @@ ...@@ -16,11 +16,23 @@
#include <memory> // std::unique_ptr #include <memory> // std::unique_ptr
#include "lite/backends/xpu/xpu_header_sitter.h" // xpu_free #include "lite/backends/xpu/xpu_header_sitter.h" // xpu_free
#include "lite/core/target_wrapper.h" #include "lite/core/target_wrapper.h" // TargetWrapper
#include "lite/utils/cp_logging.h" // CHECK_EQ
#define XPU_CALL(func) \
{ \
auto e = (func); \
CHECK_EQ(e, 0) << "XPU: (" << #func << ") returns " << e; \
}
namespace paddle { namespace paddle {
namespace lite { namespace lite {
// MAX(lod.size()) = 64
const int XPU_MAX_LOD_SIZE = 64;
// MAX(lod[i + 1] - lod[i]) = 512
const int XPU_MAX_LOD_SEQ_LEN = 512;
using TargetWrapperXPU = TargetWrapper<TARGET(kXPU)>; using TargetWrapperXPU = TargetWrapper<TARGET(kXPU)>;
struct XPUScratchPad { struct XPUScratchPad {
...@@ -33,7 +45,7 @@ struct XPUScratchPad { ...@@ -33,7 +45,7 @@ struct XPUScratchPad {
struct XPUScratchPadDeleter { struct XPUScratchPadDeleter {
void operator()(XPUScratchPad* sp) const { void operator()(XPUScratchPad* sp) const {
if (!sp->is_l3_) { if (!sp->is_l3_) {
xpu_free(sp->addr_); XPU_CALL(xpu_free(sp->addr_));
} }
delete sp; delete sp;
} }
...@@ -55,7 +67,7 @@ class TargetWrapper<TARGET(kXPU)> { ...@@ -55,7 +67,7 @@ class TargetWrapper<TARGET(kXPU)> {
size_t size, size_t size,
IoDirection dir); IoDirection dir);
static XPUScratchPadGuard MallocScratchPad(size_t size, bool use_l3 = true); static XPUScratchPadGuard MallocScratchPad(size_t size, bool use_l3 = false);
static xdnn::Context* GetRawContext() { static xdnn::Context* GetRawContext() {
if (tls_raw_ctx_ == nullptr) { if (tls_raw_ctx_ == nullptr) {
...@@ -77,11 +89,10 @@ class TargetWrapper<TARGET(kXPU)> { ...@@ -77,11 +89,10 @@ class TargetWrapper<TARGET(kXPU)> {
static void SetDev(int dev_no = 0) { static void SetDev(int dev_no = 0) {
const char* dev_env = getenv("LITE_XPU_DEV"); const char* dev_env = getenv("LITE_XPU_DEV");
if (dev_env) { if (dev_env) {
xpu_set_device(atoi(dev_env)); dev_no = atoi(dev_env);
return;
} }
xpu_set_device(dev_no); XPU_CALL(xpu_set_device(dev_no));
} }
static std::string multi_encoder_precision; // NOLINT static std::string multi_encoder_precision; // NOLINT
......
...@@ -32,25 +32,27 @@ void TestCase::CreateInstruction() { ...@@ -32,25 +32,27 @@ void TestCase::CreateInstruction() {
#endif #endif
if (enable_subgraph_op) { if (enable_subgraph_op) {
// Create a new block desc to wrap the original op desc // Create a new block desc to wrap the original op desc
auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
int sub_block_idx = 0; int sub_block_idx = 0;
auto sub_block_desc = new cpp::BlockDesc(); auto sub_block_desc = sub_program_desc->AddBlock<cpp::BlockDesc>();
sub_block_desc->ClearOps(); sub_block_desc->ClearOps();
sub_block_desc->ClearVars(); sub_block_desc->ClearVars();
auto sub_block_op_desc = sub_block_desc->AddOp<cpp::OpDesc>(); auto sub_op_desc = sub_block_desc->AddOp<cpp::OpDesc>();
*sub_block_op_desc = *op_desc_; *sub_op_desc = *op_desc_;
// Add the block desc into the subgraph op which used to replace the // Add the block desc into the subgraph op which used to replace the
// original op // original op
op_desc_.reset(new cpp::OpDesc()); op_desc_.reset(new cpp::OpDesc());
op_desc_->SetType("subgraph"); op_desc_->SetType("subgraph");
op_desc_->SetAttr<int32_t>("sub_block", sub_block_idx); op_desc_->SetAttr<int32_t>("sub_block", sub_block_idx);
auto in_names = sub_block_op_desc->input_vars(); auto in_names = sub_op_desc->input_vars();
auto out_names = sub_block_op_desc->output_vars(); auto out_names = sub_op_desc->output_vars();
op_desc_->SetInput("Inputs", in_names); op_desc_->SetInput("Inputs", in_names);
op_desc_->SetOutput("Outputs", out_names); op_desc_->SetOutput("Outputs", out_names);
op_desc_->SetAttr<std::vector<std::string>>("input_data_names", in_names); op_desc_->SetAttr<std::vector<std::string>>("input_data_names", in_names);
op_desc_->SetAttr<std::vector<std::string>>("output_data_names", out_names); op_desc_->SetAttr<std::vector<std::string>>("output_data_names", out_names);
op = LiteOpRegistry::Global().Create(op_desc().Type()); op = LiteOpRegistry::Global().Create(op_desc().Type());
static_cast<operators::SubgraphOp*>(op.get())->SetSubBlock(sub_block_desc); static_cast<operators::SubgraphOp*>(op.get())->SetProgramDesc(
sub_program_desc);
} else { } else {
op = LiteOpRegistry::Global().Create(op_desc().Type()); op = LiteOpRegistry::Global().Create(op_desc().Type());
} }
...@@ -60,7 +62,7 @@ void TestCase::CreateInstruction() { ...@@ -60,7 +62,7 @@ void TestCase::CreateInstruction() {
// filter out the target kernel // filter out the target kernel
CHECK(!kernels.empty()) << "No kernel found for place " CHECK(!kernels.empty()) << "No kernel found for place "
<< place_.DebugString(); << place_.DebugString();
auto it = std::remove_if( auto it = std::find_if(
kernels.begin(), kernels.end(), [&](std::unique_ptr<KernelBase>& k) { kernels.begin(), kernels.end(), [&](std::unique_ptr<KernelBase>& k) {
return k->alias() == alias_; return k->alias() == alias_;
}); });
...@@ -234,19 +236,6 @@ bool TestCase::CheckPrecision(const std::string& var_name, ...@@ -234,19 +236,6 @@ bool TestCase::CheckPrecision(const std::string& var_name,
return success; return success;
} }
TestCase::~TestCase() {
if (op_desc_->Type() == "subgraph") {
// Release the subblock desc of Subgraph op
auto subgraph_op = const_cast<operators::SubgraphOp*>(
static_cast<const operators::SubgraphOp*>(instruction_->op()));
CHECK(subgraph_op);
auto sub_block_desc = subgraph_op->GetSubBlock();
if (sub_block_desc) {
delete sub_block_desc;
}
}
}
} // namespace arena } // namespace arena
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -46,7 +46,7 @@ class TestCase { ...@@ -46,7 +46,7 @@ class TestCase {
base_scope_(new Scope) { base_scope_(new Scope) {
ctx_ = ContextScheduler::Global().NewContext(place_.target); ctx_ = ContextScheduler::Global().NewContext(place_.target);
} }
virtual ~TestCase(); virtual ~TestCase() {}
void Prepare() { void Prepare() {
PrepareData(); PrepareData();
......
...@@ -17,10 +17,6 @@ ...@@ -17,10 +17,6 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
#ifdef LITE_WITH_NPU
std::string Context<TargetType::kNPU>::subgraph_model_cache_dir_{""}; // NOLINT
#endif
#ifdef LITE_WITH_MLU #ifdef LITE_WITH_MLU
int Context<TargetType::kMLU>::next_queue_id_{0}; int Context<TargetType::kMLU>::next_queue_id_{0};
std::map<int, int> Context<TargetType::kMLU>::queue_id_map_; std::map<int, int> Context<TargetType::kMLU>::queue_id_map_;
......
...@@ -39,6 +39,7 @@ ...@@ -39,6 +39,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/device_info.h" #include "lite/core/device_info.h"
#include "lite/core/scope.h"
#include "lite/core/target_wrapper.h" #include "lite/core/target_wrapper.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
#include "lite/utils/all.h" #include "lite/utils/all.h"
...@@ -84,15 +85,19 @@ class Context<TargetType::kNPU> { ...@@ -84,15 +85,19 @@ class Context<TargetType::kNPU> {
NPUContext& operator=(const NPUContext& ctx) {} NPUContext& operator=(const NPUContext& ctx) {}
std::string name() const { return "NPUContext"; } std::string name() const { return "NPUContext"; }
static void SetSubgraphModelCacheDir(std::string subgraph_model_cache_dir) { static void SetSubgraphModelCacheDir(Scope* scope,
subgraph_model_cache_dir_ = subgraph_model_cache_dir; std::string subgraph_model_cache_dir) {
auto var = scope->Var("SUBGRAPH_MODEL_CACHE_DIR");
CHECK(var);
auto data = var->GetMutable<std::string>();
CHECK(data);
*data = subgraph_model_cache_dir;
} }
static std::string SubgraphModelCacheDir() { static std::string SubgraphModelCacheDir(Scope* scope) {
return subgraph_model_cache_dir_; auto var = scope->FindVar("SUBGRAPH_MODEL_CACHE_DIR");
if (!var) return "";
return var->Get<std::string>();
} }
private:
static std::string subgraph_model_cache_dir_;
}; };
#endif #endif
......
...@@ -18,6 +18,7 @@ lite_cc_library(mir_passes ...@@ -18,6 +18,7 @@ lite_cc_library(mir_passes
fusion/conv_activation_fuse_pass.cc fusion/conv_activation_fuse_pass.cc
fusion/var_conv_2d_activation_fuse_pass.cc fusion/var_conv_2d_activation_fuse_pass.cc
fusion/conv_bn_fuse_pass.cc fusion/conv_bn_fuse_pass.cc
fusion/conv_conv_fuse_pass.cc
fusion/elementwise_add_activation_fuse_pass.cc fusion/elementwise_add_activation_fuse_pass.cc
fusion/quant_dequant_fuse_pass.cc fusion/quant_dequant_fuse_pass.cc
fusion/sequence_pool_concat_fuse_pass.cc fusion/sequence_pool_concat_fuse_pass.cc
...@@ -32,6 +33,7 @@ lite_cc_library(mir_passes ...@@ -32,6 +33,7 @@ lite_cc_library(mir_passes
elimination/identity_dropout_eliminate_pass.cc elimination/identity_dropout_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc elimination/elementwise_mul_constant_eliminate_pass.cc
elimination/remove_tf_redundant_ops_pass.cc elimination/remove_tf_redundant_ops_pass.cc
elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.cc
static_kernel_pick_pass.cc static_kernel_pick_pass.cc
variable_place_inference_pass.cc variable_place_inference_pass.cc
type_target_cast_pass.cc type_target_cast_pass.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.h"
#include <algorithm>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
// Remove all of the unused nodes from the contorl flow op and update the inputs
// and outputs of the op info The unused nodes are defined as the nodes which
// are only linked to the control flow op nodes but nerver linked to the other
// op nodes.
//
// For example:
// graph[0]: main block
// in_x
// in_f | in_z(unused node)
// \ | /
// \ | /
// in_w ------- while ------- in_y(unused_node)
// / |
// / |
// (unused node)out_y |
// out_x
//
// graph[1]: sub block
// in_x
// |
// |
// conv2d----in_f
// |
// |
// fc ------in_w
// |
// |
// softmax
// |
// |
// out_x
//
// After the pass is applied:
// in_x
// in_f |
// \ |
// \ |
// in_w ------- while
// |
// |
// |
// out_x
// Remove the var node from var2rm if it is recursively referred to any op in
// the subblock
void CollectUnusedInputOutputNodes(
int block_idx,
std::vector<std::unique_ptr<mir::SSAGraph>>* graphs,
const std::unordered_set<std::string>& control_flow_op_types,
std::unordered_map<std::string, Node*>* in_vars2rm,
std::unordered_map<std::string, Node*>* out_vars2rm) {
auto block_size = graphs->size();
for (auto& op_node : (*graphs)[block_idx]->StmtTopologicalOrder()) {
if (!op_node->IsStmt()) continue;
auto op_info = op_node->AsStmt().op_info();
auto op_type = op_info->Type();
if (control_flow_op_types.count(op_type)) {
int sub_block_idx = op_info->GetAttr<int32_t>("sub_block");
CHECK(block_idx >= 0 && block_idx < block_size);
CollectUnusedInputOutputNodes(sub_block_idx,
graphs,
control_flow_op_types,
in_vars2rm,
out_vars2rm);
} else {
for (auto& var_node : op_node->inlinks) {
auto& var_name = var_node->AsArg().name;
if (in_vars2rm->count(var_name)) {
in_vars2rm->erase(var_name);
}
}
for (auto& var_node : op_node->outlinks) {
auto& var_name = var_node->AsArg().name;
// Tensor array may be only used as the output vars in the sublock
if (in_vars2rm->count(var_name)) {
in_vars2rm->erase(var_name);
}
if (out_vars2rm->count(var_name)) {
out_vars2rm->erase(var_name);
}
}
}
}
}
// Remove the unused var nodes from the graph and update the op_info of the
// control flow op
void RemoveNodesFromGraphAndUpdateOpInfo(
SSAGraph* graph,
Node* op_node,
const std::unordered_map<std::string, Node*>& in_vars2rm,
const std::unordered_map<std::string, Node*>& out_vars2rm) {
auto op_info = op_node->AsStmt().mutable_op_info();
auto op_type = op_info->Type();
// Unlink the in_vars2rm and out_vars2rm from the control flow op node, and
// remove them if nerver used.
for (auto& var_node : in_vars2rm) {
VLOG(3) << "in var node '" << var_node.first << "' is unlinked to "
<< op_type;
RemoveDirectedLink(var_node.second, op_node);
}
for (auto& var_node : out_vars2rm) {
VLOG(3) << "out var node '" << var_node.first << "' is unlinked from "
<< op_type;
RemoveDirectedLink(op_node, var_node.second);
// Unlink from all of the out op nodes.
std::unordered_set<Node*> out_op_nodes;
for (auto* out_op_node : var_node.second->outlinks) {
if (!out_op_nodes.count(out_op_node)) {
out_op_nodes.insert(out_op_node);
}
}
for (auto* out_op_node : out_op_nodes) {
RemoveDirectedLink(var_node.second, out_op_node);
}
}
// Remove the unused nodes from the graph if their inlinks and outlinks are
// empty
std::unordered_set<const Node*> removed_var_nodes;
for (auto& var_node : in_vars2rm) {
if (var_node.second->inlinks.empty() && var_node.second->outlinks.empty() &&
!removed_var_nodes.count(var_node.second)) {
removed_var_nodes.insert(var_node.second);
graph->RemoveNode(var_node.second);
VLOG(3) << "in var node " << var_node.first << " is removed";
}
}
for (auto& var_node : out_vars2rm) {
if (var_node.second->inlinks.empty() && var_node.second->outlinks.empty() &&
!removed_var_nodes.count(var_node.second)) {
removed_var_nodes.insert(var_node.second);
graph->RemoveNode(var_node.second);
VLOG(3) << "out var node " << var_node.first << " is removed";
}
}
// Update the op info of the control flow op
for (auto& input : *op_info->mutable_inputs()) {
for (auto var = input.second.begin(); var != input.second.end();) {
if (in_vars2rm.count(*var)) {
var = input.second.erase(var);
} else {
++var;
}
}
}
for (auto& output : *op_info->mutable_outputs()) {
for (auto var = output.second.begin(); var != output.second.end();) {
if (out_vars2rm.count(*var)) {
var = output.second.erase(var);
} else {
++var;
}
}
}
}
void ControlFlowOpUnusedInputsAndOutputsEliminatePass::SetAllGraphs(
std::vector<std::unique_ptr<mir::SSAGraph>>* graphs) {
CHECK(graphs && !graphs->empty());
graphs_ = graphs;
}
void ControlFlowOpUnusedInputsAndOutputsEliminatePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
// Remove the unused input and output nodes from the control flow op nodes
// Which are only linked to the control flow op nodes but nerver linked to the
// other op nodes
const std::unordered_set<std::string> control_flow_op_types = {
"while", "conditional_block"};
auto block_size = graphs_->size();
for (auto& op_node : graph->StmtTopologicalOrder()) {
if (!op_node->IsStmt()) continue;
auto op_info = op_node->AsStmt().mutable_op_info();
auto op_type = op_info->Type();
if (!control_flow_op_types.count(op_type)) continue;
int sub_block_idx = op_info->GetAttr<int32_t>("sub_block");
CHECK(sub_block_idx >= 0 && sub_block_idx < block_size);
// Initialize the unused nodes with all of the input and output nodes
std::unordered_map<std::string, Node *> in_vars2rm, out_vars2rm;
for (auto* var_node : op_node->inlinks) {
auto& var_name = var_node->AsArg().name;
if (!in_vars2rm.count(var_name)) {
in_vars2rm.insert(std::pair<std::string, Node*>(var_name, var_node));
}
}
for (auto* var_node : op_node->outlinks) {
auto& var_name = var_node->AsArg().name;
if (!out_vars2rm.count(var_name)) {
out_vars2rm.insert(std::pair<std::string, Node*>(var_name, var_node));
}
}
// Remove the nodes which used in subblock recursively, and the remaining
// nodes are the unused one.
CollectUnusedInputOutputNodes(sub_block_idx,
graphs_,
control_flow_op_types,
&in_vars2rm,
&out_vars2rm);
if (in_vars2rm.size() > 0 || out_vars2rm.size() > 0) {
// Remove the unused nodes from graph, and update the op info of the
// control flow op
RemoveNodesFromGraphAndUpdateOpInfo(
graph.get(), op_node, in_vars2rm, out_vars2rm);
}
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(
control_flow_op_unused_inputs_and_outputs_eliminate_pass,
paddle::lite::mir::ControlFlowOpUnusedInputsAndOutputsEliminatePass)
.BindTargets({TARGET(kNPU)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/mir/pass.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace mir {
class ControlFlowOpUnusedInputsAndOutputsEliminatePass : public mir::StmtPass {
public:
void Apply(const std::unique_ptr<SSAGraph> &graph) override;
void SetAllGraphs(std::vector<std::unique_ptr<mir::SSAGraph>> *graphs);
private:
std::vector<std::unique_ptr<mir::SSAGraph>> *graphs_;
};
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -16,6 +16,9 @@ lite_cc_library(fuse_var_conv_activation ...@@ -16,6 +16,9 @@ lite_cc_library(fuse_var_conv_activation
lite_cc_library(fuse_conv_bn lite_cc_library(fuse_conv_bn
SRCS conv_bn_fuser.cc SRCS conv_bn_fuser.cc
DEPS pattern_matcher_high_api) DEPS pattern_matcher_high_api)
lite_cc_library(fuse_conv_conv
SRCS conv_conv_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_elementwise_add_activation lite_cc_library(fuse_elementwise_add_activation
SRCS elementwise_add_activation_fuser.cc SRCS elementwise_add_activation_fuser.cc
DEPS pattern_matcher_high_api) DEPS pattern_matcher_high_api)
...@@ -42,6 +45,7 @@ set(mir_fusers ...@@ -42,6 +45,7 @@ set(mir_fusers
fuse_conv_activation fuse_conv_activation
fuse_var_conv_activation fuse_var_conv_activation
fuse_conv_bn fuse_conv_bn
fuse_conv_conv
fuse_quant_dequant fuse_quant_dequant
fuse_elementwise_add_activation fuse_elementwise_add_activation
fuse_transpose_softmax_transpose fuse_transpose_softmax_transpose
......
...@@ -326,6 +326,28 @@ class XPUMmdnnSearchAttentionFuser : public FuseBase { ...@@ -326,6 +326,28 @@ class XPUMmdnnSearchAttentionFuser : public FuseBase {
} }
}; };
// 4 inputs
// ========
//
// input_x
// input_y
// topk_row
// topk_col
//
// input_x ------- match_matrix_tensor ------- input_y
// |
// relu
// ________/ \________
// | |
// var_conv_2d |
// | |
// relu |
// |_______ _______|
// \ /
// sequence_concat
// |
// topk_row ---- sequence_topk_avg_pooling ----- topk_col
//
class XPUMmdnnMatchConvTopkFuser : public FuseBase { class XPUMmdnnMatchConvTopkFuser : public FuseBase {
public: public:
void BuildPattern() override { void BuildPattern() override {
...@@ -418,10 +440,156 @@ class XPUMmdnnMatchConvTopkFuser : public FuseBase { ...@@ -418,10 +440,156 @@ class XPUMmdnnMatchConvTopkFuser : public FuseBase {
auto* match_op_info = matched.at("match_matrix_tensor")->stmt()->op_info(); auto* match_op_info = matched.at("match_matrix_tensor")->stmt()->op_info();
op_desc.SetAttr<float>("input_w_max", op_desc.SetAttr<float>("input_w_max",
match_op_info->GetAttr<float>("w_max")); match_op_info->GetAttr<float>("__xpu__w_max"));
op_desc.SetAttr<int>("dim_t", match_op_info->GetAttr<int>("dim_t"));
auto* conv_op_info = matched.at("conv")->stmt()->op_info();
op_desc.SetAttr<float>("conv_w_max",
conv_op_info->GetAttr<float>("__xpu__w_max"));
op_desc.SetAttr<int>("output_channel",
conv_op_info->GetAttr<int>("OutputChannel"));
auto* topk_op_info = matched.at("topk")->stmt()->op_info();
op_desc.SetAttr<std::vector<int>>(
"topks", topk_op_info->GetAttr<std::vector<int>>("topks"));
op_desc.SetAttr<int>("channel_num",
topk_op_info->GetAttr<int>("channel_num"));
auto* new_stmt = matched.at("match_matrix_tensor")->stmt();
auto new_op = LiteOpRegistry::Global().Create(op_desc.Type());
new_op->Attach(op_desc, new_stmt->op()->scope());
new_op->SetValidPlaces(new_stmt->op()->valid_places());
auto kernels = new_op->CreateKernels(new_op->valid_places());
new_stmt->SetOp(new_op);
new_stmt->SetKernels(std::move(kernels));
// XXX(miaotianxiang): redundant links around |topk| are automatically
// removed as |topk| is marked intermediate.
// RemoveDirectedLink(matched.at("topk_col"), matched.at("topk"));
// RemoveDirectedLink(matched.at("topk_row"), matched.at("topk"));
std::vector<std::string> arg_names{"conv_w"};
for (auto name : arg_names) {
DirectedLink(matched.at(name), matched.at("match_matrix_tensor"));
}
std::vector<std::string> out_names{"topk_out"};
for (auto name : out_names) {
IR_OP_VAR_LINK(matched.at("match_matrix_tensor"), matched.at(name));
}
}
};
// 2 inputs
// ========
//
// input_x
// input_y
//
// input_x ------- match_matrix_tensor ------- input_y
// | | |
// | relu |
// | ________/ \________ |
// | | | |
// | var_conv_2d | |
// | | | |
// | relu | |
// | |_______ _______| |
// | \ / |
// | sequence_concat |
// | | |
// |--------- sequence_topk_avg_pooling -------|
//
class XPUMmdnnMatchConvTopkFuser2 : public FuseBase {
public:
void BuildPattern() override {
auto* input_x = VarNode("input_x")
->assert_is_op_input("match_matrix_tensor", "X")
->assert_is_op_input("sequence_topk_avg_pooling", "ROW")
->AsInput();
auto* input_y =
VarNode("input_y")
->assert_is_op_input("match_matrix_tensor", "Y")
->assert_is_op_input("sequence_topk_avg_pooling", "COLUMN")
->AsInput();
auto* input_w = VarNode("input_w")
->assert_is_op_input("match_matrix_tensor", "W")
->AsInput();
auto* match_matrix_tensor =
OpNode("match_matrix_tensor", "match_matrix_tensor");
auto* match_out = VarNode("match_out")
->assert_is_op_output("match_matrix_tensor", "Out")
->AsIntermediate();
auto* match_tmp = VarNode("match_tmp")
->assert_is_op_output("match_matrix_tensor", "Tmp")
->AsIntermediate();
auto* relu0 = OpNode("relu0", "relu")->AsIntermediate();
auto* relu0_out = VarNode("relu0_out")
->assert_is_op_output("relu", "Out")
->AsIntermediate();
auto* conv_w =
VarNode("conv_w")->assert_is_op_input("var_conv_2d", "W")->AsInput();
auto* conv = OpNode("conv", "var_conv_2d")->AsIntermediate();
auto* conv_out = VarNode("conv_out")
->assert_is_op_output("var_conv_2d", "Out")
->AsIntermediate();
auto* conv_col = VarNode("conv_col")
->assert_is_op_output("var_conv_2d", "Col")
->AsIntermediate();
auto* relu1 = OpNode("relu1", "relu")->AsIntermediate();
auto* relu1_out = VarNode("relu1_out")
->assert_is_op_output("relu", "Out")
->AsIntermediate();
auto* seq_concat =
OpNode("seq_concat", "sequence_concat")->AsIntermediate();
auto* seq_concat_out =
VarNode("seq_concat_out")
->assert_is_op_output("sequence_concat", "Out")
->assert_is_op_input("sequence_topk_avg_pooling", "X")
->AsIntermediate();
auto* topk = OpNode("topk", "sequence_topk_avg_pooling")->AsIntermediate();
auto* topk_out =
VarNode("topk_out")
->assert_is_op_output("sequence_topk_avg_pooling", "Out")
->AsOutput();
auto* topk_pos =
VarNode("topk_pos")
->assert_is_op_output("sequence_topk_avg_pooling", "pos")
->AsIntermediate();
*input_x >> *match_matrix_tensor;
*input_y >> *match_matrix_tensor;
*input_w >> *match_matrix_tensor;
*match_matrix_tensor >> *match_out >> *relu0 >> *relu0_out;
*match_matrix_tensor >> *match_tmp;
*relu0_out >> *conv >> *conv_out >> *relu1 >> *relu1_out;
*conv_w >> *conv;
*conv >> *conv_col;
*relu0_out >> *seq_concat;
*relu1_out >> *seq_concat;
*seq_concat >> *seq_concat_out >> *topk >> *topk_out;
*input_x >> *topk;
*input_y >> *topk;
*topk >> *topk_pos;
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("__xpu__mmdnn_match_conv_topk");
op_desc.SetInput("input_x", {matched.at("input_x")->arg()->name});
op_desc.SetInput("input_y", {matched.at("input_y")->arg()->name});
op_desc.SetInput("input_w", {matched.at("input_w")->arg()->name});
op_desc.SetInput("conv_w", {matched.at("conv_w")->arg()->name});
op_desc.SetOutput("topk_out", {matched.at("topk_out")->arg()->name});
auto* match_op_info = matched.at("match_matrix_tensor")->stmt()->op_info();
op_desc.SetAttr<float>("input_w_max",
match_op_info->GetAttr<float>("__xpu__w_max"));
op_desc.SetAttr<int>("dim_t", match_op_info->GetAttr<int>("dim_t")); op_desc.SetAttr<int>("dim_t", match_op_info->GetAttr<int>("dim_t"));
auto* conv_op_info = matched.at("conv")->stmt()->op_info(); auto* conv_op_info = matched.at("conv")->stmt()->op_info();
op_desc.SetAttr<float>("conv_w_max", conv_op_info->GetAttr<float>("w_max")); op_desc.SetAttr<float>("conv_w_max",
conv_op_info->GetAttr<float>("__xpu__w_max"));
op_desc.SetAttr<int>("output_channel",
conv_op_info->GetAttr<int>("OutputChannel"));
auto* topk_op_info = matched.at("topk")->stmt()->op_info(); auto* topk_op_info = matched.at("topk")->stmt()->op_info();
op_desc.SetAttr<std::vector<int>>( op_desc.SetAttr<std::vector<int>>(
"topks", topk_op_info->GetAttr<std::vector<int>>("topks")); "topks", topk_op_info->GetAttr<std::vector<int>>("topks"));
...@@ -437,8 +605,7 @@ class XPUMmdnnMatchConvTopkFuser : public FuseBase { ...@@ -437,8 +605,7 @@ class XPUMmdnnMatchConvTopkFuser : public FuseBase {
new_stmt->SetKernels(std::move(kernels)); new_stmt->SetKernels(std::move(kernels));
// XXX(miaotianxiang): redundant links around |topk| are automatically // XXX(miaotianxiang): redundant links around |topk| are automatically
// removed as |topk| is // removed as |topk| is marked intermediate.
// marked intermediate.
// RemoveDirectedLink(matched.at("topk_col"), matched.at("topk")); // RemoveDirectedLink(matched.at("topk_col"), matched.at("topk"));
// RemoveDirectedLink(matched.at("topk_row"), matched.at("topk")); // RemoveDirectedLink(matched.at("topk_row"), matched.at("topk"));
std::vector<std::string> arg_names{"conv_w"}; std::vector<std::string> arg_names{"conv_w"};
...@@ -624,6 +791,15 @@ class XPUMmdnnBidEmbAttFuser : public FuseBase { ...@@ -624,6 +791,15 @@ class XPUMmdnnBidEmbAttFuser : public FuseBase {
} }
}; };
// 5 outputs
// =========
//
// eltwise01_out
// seq_pool_right_out
// seq_pool_left_out
// seq_pool_2in1_out
// concat_3in1_out
//
class XPUMmdnnBidEmbGrnnAttFuser : public FuseBase { class XPUMmdnnBidEmbGrnnAttFuser : public FuseBase {
public: public:
void BuildPattern() override { void BuildPattern() override {
...@@ -818,17 +994,272 @@ class XPUMmdnnBidEmbGrnnAttFuser : public FuseBase { ...@@ -818,17 +994,272 @@ class XPUMmdnnBidEmbGrnnAttFuser : public FuseBase {
auto* grnn_fw_op_info = matched.at("grnn_left")->stmt()->op_info(); auto* grnn_fw_op_info = matched.at("grnn_left")->stmt()->op_info();
op_desc.SetAttr<std::vector<float>>( op_desc.SetAttr<std::vector<float>>(
"grnn_fw_wh_maxs", "grnn_fw_wh_maxs",
grnn_fw_op_info->GetAttr<std::vector<float>>("wh_max")); grnn_fw_op_info->GetAttr<std::vector<float>>("__xpu__wh_max"));
op_desc.SetAttr<std::vector<float>>( op_desc.SetAttr<std::vector<float>>(
"grnn_fw_wi_maxs", "grnn_fw_wi_maxs",
grnn_fw_op_info->GetAttr<std::vector<float>>("wi_max")); grnn_fw_op_info->GetAttr<std::vector<float>>("__xpu__wi_max"));
auto* grnn_rv_op_info = matched.at("grnn_right")->stmt()->op_info(); auto* grnn_rv_op_info = matched.at("grnn_right")->stmt()->op_info();
op_desc.SetAttr<std::vector<float>>( op_desc.SetAttr<std::vector<float>>(
"grnn_rv_wh_maxs", "grnn_rv_wh_maxs",
grnn_rv_op_info->GetAttr<std::vector<float>>("wh_max")); grnn_rv_op_info->GetAttr<std::vector<float>>("__xpu__wh_max"));
op_desc.SetAttr<std::vector<float>>( op_desc.SetAttr<std::vector<float>>(
"grnn_rv_wi_maxs", "grnn_rv_wi_maxs",
grnn_rv_op_info->GetAttr<std::vector<float>>("wi_max")); grnn_rv_op_info->GetAttr<std::vector<float>>("__xpu__wi_max"));
auto* att_fc_op_info = matched.at("att_2in1")->stmt()->op_info();
op_desc.SetAttr<float>("att_fc_w_max",
att_fc_op_info->GetAttr<float>("W_max"));
auto* new_stmt = matched.at("emb0")->stmt();
auto new_op = LiteOpRegistry::Global().Create(op_desc.Type());
new_op->Attach(op_desc, new_stmt->op()->scope());
new_op->SetValidPlaces(new_stmt->op()->valid_places());
auto kernels = new_op->CreateKernels(new_op->valid_places());
new_stmt->SetOp(new_op);
new_stmt->SetKernels(std::move(kernels));
std::vector<std::string> arg_names{
"input1",
"grnn_left_wh",
"grnn_left_wi",
"grnn_right_wh",
"grnn_right_wi",
"att_2in1_w",
"att_2in1_b",
};
for (auto name : arg_names) {
DirectedLink(matched.at(name), matched.at("emb0"));
}
std::vector<std::string> out_names{
"seq_pool_left_out",
"seq_pool_right_out",
"seq_pool_2in1_out",
"concat_3in1_out",
"eltwise01_out",
};
for (auto name : out_names) {
IR_OP_VAR_LINK(matched.at("emb0"), matched.at(name));
}
}
};
// 6 outputs
// =========
//
// emb0_out
// eltwise01_out
// seq_pool_right_out
// seq_pool_left_out
// seq_pool_2in1_out
// concat_3in1_out
//
class XPUMmdnnBidEmbGrnnAttFuser2 : public FuseBase {
public:
void BuildPattern() override {
auto* input0 = VarNode("input0")->AsInput();
auto* input1 = VarNode("input1")->AsInput();
auto* emb_tbl = VarNode("emb_tbl")->AsInput();
auto* emb0 = OpNode("emb0", "lookup_table");
auto* emb0_out = VarNode("emb0_out")
->assert_is_op_output("lookup_table", "Out")
->assert_is_op_input("search_seq_arithmetic", "X")
->AsOutput();
auto* emb1 = OpNode("emb1", "lookup_table")->AsIntermediate();
auto* emb1_out = VarNode("emb1_out")
->assert_is_op_output("lookup_table", "Out")
->assert_is_op_input("search_seq_arithmetic", "Y")
->AsIntermediate();
auto* eltwise01 =
OpNode("eltwise01", "search_seq_arithmetic")->AsIntermediate();
auto* eltwise01_out =
VarNode("eltwise01_out")
->assert_is_op_output("search_seq_arithmetic", "Out")
->AsOutput();
auto* seq_rev_right0 =
OpNode("seq_rev_right0", "sequence_reverse")->AsIntermediate();
auto* seq_rev_right0_out =
VarNode("seq_rev_right0_out")
->assert_is_op_output("sequence_reverse", "Y")
->AsIntermediate();
auto* grnn_right_wh = VarNode("grnn_right_wh")
->assert_is_op_input("search_grnn", "Wh")
->AsInput();
auto* grnn_right_wi = VarNode("grnn_right_wi")
->assert_is_op_input("search_grnn", "Wi")
->AsInput();
auto* grnn_right = OpNode("grnn_right", "search_grnn")->AsIntermediate();
auto* grnn_right_out = VarNode("grnn_right_out")
->assert_is_op_output("search_grnn", "Out")
->AsIntermediate();
auto* grnn_right_idx_sorted_by_width =
VarNode("grnn_right_idx_sorted_by_width")
->assert_is_op_output("search_grnn", "idx_sorted_by_width")
->AsIntermediate();
auto* grnn_right_layout_input =
VarNode("grnn_right_layout_input")
->assert_is_op_output("search_grnn", "layout_input")
->AsIntermediate();
auto* grnn_right_tmp_buffer =
VarNode("grnn_right_tmp_buffer")
->assert_is_op_output("search_grnn", "tmp_buffer")
->AsIntermediate();
auto* seq_rev_right1 =
OpNode("seq_rev_right1", "sequence_reverse")->AsIntermediate();
auto* seq_rev_right1_out =
VarNode("seq_rev_right1_out")
->assert_is_op_output("sequence_reverse", "Y")
->AsIntermediate();
auto* seq_pool_right =
OpNode("seq_pool_right", "sequence_pool")->AsIntermediate();
auto* seq_pool_right_out = VarNode("seq_pool_right_out")
->assert_is_op_output("sequence_pool", "Out")
->AsOutput();
auto* seq_pool_right_max_idx =
VarNode("seq_pool_right_max_idx")
->assert_is_op_output("sequence_pool", "MaxIndex")
->AsIntermediate();
auto* grnn_left_wh = VarNode("grnn_left_wh")
->assert_is_op_input("search_grnn", "Wh")
->AsInput();
auto* grnn_left_wi = VarNode("grnn_left_wi")
->assert_is_op_input("search_grnn", "Wi")
->AsInput();
auto* grnn_left = OpNode("grnn_left", "search_grnn")->AsIntermediate();
auto* grnn_left_out = VarNode("grnn_left_out")
->assert_is_op_output("search_grnn", "Out")
->AsIntermediate();
auto* grnn_left_idx_sorted_by_width =
VarNode("grnn_left_idx_sorted_by_width")
->assert_is_op_output("search_grnn", "idx_sorted_by_width")
->AsIntermediate();
auto* grnn_left_layout_input =
VarNode("grnn_left_layout_input")
->assert_is_op_output("search_grnn", "layout_input")
->AsIntermediate();
auto* grnn_left_tmp_buffer =
VarNode("grnn_left_tmp_buffer")
->assert_is_op_output("search_grnn", "tmp_buffer")
->AsIntermediate();
auto* seq_pool_left =
OpNode("seq_pool_left", "sequence_pool")->AsIntermediate();
auto* seq_pool_left_out = VarNode("seq_pool_left_out")
->assert_is_op_output("sequence_pool", "Out")
->AsOutput();
auto* seq_pool_left_max_idx =
VarNode("seq_pool_left_max_idx")
->assert_is_op_output("sequence_pool", "MaxIndex")
->AsIntermediate();
auto* concat_2in1 = OpNode("concat_2in1", "concat")->AsIntermediate();
auto* concat_2in1_out = VarNode("concat_2in1_out")
->assert_is_op_output("concat", "Out")
->AsIntermediate();
auto* att_2in1_w =
VarNode("att_2in1_w")
->assert_is_op_input("__xpu__mmdnn_search_attention", "W")
->AsInput();
auto* att_2in1_b =
VarNode("att_2in1_b")
->assert_is_op_input("__xpu__mmdnn_search_attention", "b")
->AsInput();
auto* att_2in1 =
OpNode("att_2in1", "__xpu__mmdnn_search_attention")->AsIntermediate();
auto* att_2in1_out =
VarNode("att_2in1_out")
->assert_is_op_output("__xpu__mmdnn_search_attention", "Out")
->AsIntermediate();
auto* seq_pool_2in1 =
OpNode("seq_pool_2in1", "sequence_pool")->AsIntermediate();
auto* seq_pool_2in1_out = VarNode("seq_pool_2in1_out")
->assert_is_op_output("sequence_pool", "Out")
->AsOutput();
auto* seq_pool_2in1_max_idx =
VarNode("seq_pool_2in1_max_idx")
->assert_is_op_output("sequence_pool", "MaxIndex")
->AsIntermediate();
auto* concat_3in1 = OpNode("concat_3in1", "concat")->AsIntermediate();
auto* concat_3in1_out = VarNode("concat_3in1_out")
->assert_is_op_output("concat", "Out")
->AsOutput();
*input0 >> *emb0 >> *emb0_out >> *eltwise01 >> *eltwise01_out;
*emb_tbl >> *emb0;
*input1 >> *emb1 >> *emb1_out >> *eltwise01;
*emb_tbl >> *emb1;
*eltwise01_out >> *seq_rev_right0 >> *seq_rev_right0_out >> *grnn_right >>
*grnn_right_out >> *seq_rev_right1 >> *seq_rev_right1_out;
*grnn_right_out >> *seq_pool_right >> *seq_pool_right_out;
*seq_pool_right >> *seq_pool_right_max_idx;
*grnn_right_wh >> *grnn_right;
*grnn_right_wi >> *grnn_right;
*grnn_right >> *grnn_right_idx_sorted_by_width;
*grnn_right >> *grnn_right_layout_input;
*grnn_right >> *grnn_right_tmp_buffer;
*eltwise01_out >> *grnn_left >> *grnn_left_out >> *seq_pool_left >>
*seq_pool_left_out;
*seq_pool_left >> *seq_pool_left_max_idx;
*grnn_left_wh >> *grnn_left;
*grnn_left_wi >> *grnn_left;
*grnn_left >> *grnn_left_idx_sorted_by_width;
*grnn_left >> *grnn_left_layout_input;
*grnn_left >> *grnn_left_tmp_buffer;
*seq_rev_right1_out >> *concat_2in1;
*grnn_left_out >> *concat_2in1;
*concat_2in1 >> *concat_2in1_out >> *att_2in1 >> *att_2in1_out >>
*seq_pool_2in1 >> *seq_pool_2in1_out;
*seq_pool_2in1 >> *seq_pool_2in1_max_idx;
*att_2in1_w >> *att_2in1;
*att_2in1_b >> *att_2in1;
*eltwise01_out >> *concat_3in1;
*seq_rev_right1_out >> *concat_3in1;
*grnn_left_out >> *concat_3in1;
*concat_3in1 >> *concat_3in1_out;
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("__xpu__mmdnn_bid_emb_grnn_att2");
op_desc.SetInput("id0", {matched.at("input0")->arg()->name});
op_desc.SetInput("id1", {matched.at("input1")->arg()->name});
op_desc.SetInput("emb_tbl", {matched.at("emb_tbl")->arg()->name});
op_desc.SetInput("grnn_fw_wh", {matched.at("grnn_left_wh")->arg()->name});
op_desc.SetInput("grnn_fw_wi", {matched.at("grnn_left_wi")->arg()->name});
op_desc.SetInput("grnn_rv_wh", {matched.at("grnn_right_wh")->arg()->name});
op_desc.SetInput("grnn_rv_wi", {matched.at("grnn_right_wi")->arg()->name});
op_desc.SetInput("att_fc_w", {matched.at("att_2in1_w")->arg()->name});
op_desc.SetInput("att_fc_b", {matched.at("att_2in1_b")->arg()->name});
op_desc.SetOutput("emb0_out", {matched.at("emb0_out")->arg()->name});
op_desc.SetOutput("grnn_fw_pool_out",
{matched.at("seq_pool_left_out")->arg()->name});
op_desc.SetOutput("grnn_rv_pool_out",
{matched.at("seq_pool_right_out")->arg()->name});
op_desc.SetOutput("att_pool_out",
{matched.at("seq_pool_2in1_out")->arg()->name});
op_desc.SetOutput("concat_3in1_out",
{matched.at("concat_3in1_out")->arg()->name});
op_desc.SetOutput("emb_fw_out", {matched.at("eltwise01_out")->arg()->name});
auto* grnn_fw_op_info = matched.at("grnn_left")->stmt()->op_info();
op_desc.SetAttr<std::vector<float>>(
"grnn_fw_wh_maxs",
grnn_fw_op_info->GetAttr<std::vector<float>>("__xpu__wh_max"));
op_desc.SetAttr<std::vector<float>>(
"grnn_fw_wi_maxs",
grnn_fw_op_info->GetAttr<std::vector<float>>("__xpu__wi_max"));
auto* grnn_rv_op_info = matched.at("grnn_right")->stmt()->op_info();
op_desc.SetAttr<std::vector<float>>(
"grnn_rv_wh_maxs",
grnn_rv_op_info->GetAttr<std::vector<float>>("__xpu__wh_max"));
op_desc.SetAttr<std::vector<float>>(
"grnn_rv_wi_maxs",
grnn_rv_op_info->GetAttr<std::vector<float>>("__xpu__wi_max"));
auto* att_fc_op_info = matched.at("att_2in1")->stmt()->op_info(); auto* att_fc_op_info = matched.at("att_2in1")->stmt()->op_info();
op_desc.SetAttr<float>("att_fc_w_max", op_desc.SetAttr<float>("att_fc_w_max",
att_fc_op_info->GetAttr<float>("W_max")); att_fc_op_info->GetAttr<float>("W_max"));
...@@ -868,6 +1299,9 @@ class XPUMmdnnBidEmbGrnnAttFuser : public FuseBase { ...@@ -868,6 +1299,9 @@ class XPUMmdnnBidEmbGrnnAttFuser : public FuseBase {
class XPUMmdnnMergeAllFuser : public FuseBase { class XPUMmdnnMergeAllFuser : public FuseBase {
public: public:
explicit XPUMmdnnMergeAllFuser(int n_concat_topk)
: n_concat_topk_(n_concat_topk) {}
void BuildPattern() override { void BuildPattern() override {
auto* concat_7in1_input0 = VarNode("concat_7in1_input0") auto* concat_7in1_input0 = VarNode("concat_7in1_input0")
->assert_is_op_nth_input("concat", "X", 0) ->assert_is_op_nth_input("concat", "X", 0)
...@@ -909,16 +1343,25 @@ class XPUMmdnnMergeAllFuser : public FuseBase { ...@@ -909,16 +1343,25 @@ class XPUMmdnnMergeAllFuser : public FuseBase {
->assert_is_op_output("relu", "Out") ->assert_is_op_output("relu", "Out")
->AsIntermediate(); ->AsIntermediate();
auto* concat_2in1_input0 = VarNode("concat_2in1_input0") auto* concat_topk_input0 = VarNode("concat_topk_input0")
->assert_is_op_nth_input("concat", "X", 0) ->assert_is_op_nth_input("concat", "X", 0)
->AsInput(); ->AsInput();
auto* concat_2in1_input1 = VarNode("concat_2in1_input1") auto* concat_topk_input1 = VarNode("concat_topk_input1")
->assert_is_op_nth_input("concat", "X", 1) ->assert_is_op_nth_input("concat", "X", 1)
->AsInput(); ->AsInput();
auto* concat_2in1 = OpNode("concat_2in1", "concat")->AsIntermediate(); auto* concat_topk = OpNode("concat_topk", "concat")->AsIntermediate();
auto* concat_2in1_out = VarNode("concat_2in1_out") auto* concat_topk_out = VarNode("concat_topk_out")
->assert_is_op_output("concat", "Out") ->assert_is_op_output("concat", "Out")
->AsIntermediate(); ->AsIntermediate();
for (int i = 2; i < n_concat_topk_; ++i) {
auto concat_topk_input_name =
paddle::lite::string_format("concat_topk_input%d", i);
auto* concat_topk_inputx = VarNode(concat_topk_input_name)
->assert_is_op_nth_input("concat", "X", i)
->AsInput();
*concat_topk_inputx >> *concat_topk;
}
auto* seq_rev = OpNode("seq_rev", "sequence_reverse")->AsIntermediate(); auto* seq_rev = OpNode("seq_rev", "sequence_reverse")->AsIntermediate();
auto* seq_rev_out = VarNode("seq_rev_out") auto* seq_rev_out = VarNode("seq_rev_out")
->assert_is_op_output("sequence_reverse", "Y") ->assert_is_op_output("sequence_reverse", "Y")
...@@ -1034,9 +1477,9 @@ class XPUMmdnnMergeAllFuser : public FuseBase { ...@@ -1034,9 +1477,9 @@ class XPUMmdnnMergeAllFuser : public FuseBase {
*search_fc0_w >> *search_fc0; *search_fc0_w >> *search_fc0;
*search_fc0_b >> *search_fc0; *search_fc0_b >> *search_fc0;
*concat_2in1_input0 >> *concat_2in1; *concat_topk_input0 >> *concat_topk;
*concat_2in1_input1 >> *concat_2in1; *concat_topk_input1 >> *concat_topk;
*concat_2in1 >> *concat_2in1_out >> *seq_rev >> *seq_rev_out; *concat_topk >> *concat_topk_out >> *seq_rev >> *seq_rev_out;
*seq_rev_out >> *grnn_rv >> *grnn_rv_out >> *seq_pool_rv >> *seq_rev_out >> *grnn_rv >> *grnn_rv_out >> *seq_pool_rv >>
*seq_pool_rv_out; *seq_pool_rv_out;
...@@ -1047,7 +1490,7 @@ class XPUMmdnnMergeAllFuser : public FuseBase { ...@@ -1047,7 +1490,7 @@ class XPUMmdnnMergeAllFuser : public FuseBase {
*grnn_rv >> *grnn_rv_layout_input; *grnn_rv >> *grnn_rv_layout_input;
*grnn_rv >> *grnn_rv_tmp_buffer; *grnn_rv >> *grnn_rv_tmp_buffer;
*concat_2in1_out >> *grnn_fw >> *grnn_fw_out >> *seq_pool_fw >> *concat_topk_out >> *grnn_fw >> *grnn_fw_out >> *seq_pool_fw >>
*seq_pool_fw_out; *seq_pool_fw_out;
*seq_pool_fw >> *seq_pool_fw_max_idx; *seq_pool_fw >> *seq_pool_fw_max_idx;
*grnn_fw_wh >> *grnn_fw; *grnn_fw_wh >> *grnn_fw;
...@@ -1075,8 +1518,8 @@ class XPUMmdnnMergeAllFuser : public FuseBase { ...@@ -1075,8 +1518,8 @@ class XPUMmdnnMergeAllFuser : public FuseBase {
op_desc.SetType("__xpu__mmdnn_merge_all"); op_desc.SetType("__xpu__mmdnn_merge_all");
auto* concat_7in1_op_info = matched.at("concat_7in1")->stmt()->op_info(); auto* concat_7in1_op_info = matched.at("concat_7in1")->stmt()->op_info();
op_desc.SetInput("concat_7in1_x", concat_7in1_op_info->Input("X")); op_desc.SetInput("concat_7in1_x", concat_7in1_op_info->Input("X"));
auto* concat_2in1_op_info = matched.at("concat_2in1")->stmt()->op_info(); auto* concat_topk_op_info = matched.at("concat_topk")->stmt()->op_info();
op_desc.SetInput("concat_2in1_x", concat_2in1_op_info->Input("X")); op_desc.SetInput("concat_topk_x", concat_topk_op_info->Input("X"));
op_desc.SetInput("grnn_fw_wh", {matched.at("grnn_fw_wh")->arg()->name}); op_desc.SetInput("grnn_fw_wh", {matched.at("grnn_fw_wh")->arg()->name});
op_desc.SetInput("grnn_fw_wi", {matched.at("grnn_fw_wi")->arg()->name}); op_desc.SetInput("grnn_fw_wi", {matched.at("grnn_fw_wi")->arg()->name});
op_desc.SetInput("grnn_rv_wh", {matched.at("grnn_rv_wh")->arg()->name}); op_desc.SetInput("grnn_rv_wh", {matched.at("grnn_rv_wh")->arg()->name});
...@@ -1093,23 +1536,26 @@ class XPUMmdnnMergeAllFuser : public FuseBase { ...@@ -1093,23 +1536,26 @@ class XPUMmdnnMergeAllFuser : public FuseBase {
auto* grnn_fw_op_info = matched.at("grnn_fw")->stmt()->op_info(); auto* grnn_fw_op_info = matched.at("grnn_fw")->stmt()->op_info();
op_desc.SetAttr<std::vector<float>>( op_desc.SetAttr<std::vector<float>>(
"grnn_fw_wh_maxs", "grnn_fw_wh_maxs",
grnn_fw_op_info->GetAttr<std::vector<float>>("wh_max")); grnn_fw_op_info->GetAttr<std::vector<float>>("__xpu__wh_max"));
op_desc.SetAttr<std::vector<float>>( op_desc.SetAttr<std::vector<float>>(
"grnn_fw_wi_maxs", "grnn_fw_wi_maxs",
grnn_fw_op_info->GetAttr<std::vector<float>>("wi_max")); grnn_fw_op_info->GetAttr<std::vector<float>>("__xpu__wi_max"));
auto* grnn_rv_op_info = matched.at("grnn_rv")->stmt()->op_info(); auto* grnn_rv_op_info = matched.at("grnn_rv")->stmt()->op_info();
op_desc.SetAttr<std::vector<float>>( op_desc.SetAttr<std::vector<float>>(
"grnn_rv_wh_maxs", "grnn_rv_wh_maxs",
grnn_rv_op_info->GetAttr<std::vector<float>>("wh_max")); grnn_rv_op_info->GetAttr<std::vector<float>>("__xpu__wh_max"));
op_desc.SetAttr<std::vector<float>>( op_desc.SetAttr<std::vector<float>>(
"grnn_rv_wi_maxs", "grnn_rv_wi_maxs",
grnn_rv_op_info->GetAttr<std::vector<float>>("wi_max")); grnn_rv_op_info->GetAttr<std::vector<float>>("__xpu__wi_max"));
auto* fc0_op_info = matched.at("search_fc0")->stmt()->op_info(); auto* fc0_op_info = matched.at("search_fc0")->stmt()->op_info();
op_desc.SetAttr<float>("fc0_w_max", fc0_op_info->GetAttr<float>("w_max")); op_desc.SetAttr<float>("fc0_w_max",
fc0_op_info->GetAttr<float>("__xpu__w_max"));
auto* fc1_op_info = matched.at("search_fc1")->stmt()->op_info(); auto* fc1_op_info = matched.at("search_fc1")->stmt()->op_info();
op_desc.SetAttr<float>("fc1_w_max", fc1_op_info->GetAttr<float>("w_max")); op_desc.SetAttr<float>("fc1_w_max",
fc1_op_info->GetAttr<float>("__xpu__w_max"));
auto* fc2_op_info = matched.at("search_fc2")->stmt()->op_info(); auto* fc2_op_info = matched.at("search_fc2")->stmt()->op_info();
op_desc.SetAttr<float>("fc2_w_max", fc2_op_info->GetAttr<float>("w_max")); op_desc.SetAttr<float>("fc2_w_max",
fc2_op_info->GetAttr<float>("__xpu__w_max"));
auto* new_stmt = matched.at("concat_7in1")->stmt(); auto* new_stmt = matched.at("concat_7in1")->stmt();
auto new_op = LiteOpRegistry::Global().Create(op_desc.Type()); auto new_op = LiteOpRegistry::Global().Create(op_desc.Type());
...@@ -1120,8 +1566,8 @@ class XPUMmdnnMergeAllFuser : public FuseBase { ...@@ -1120,8 +1566,8 @@ class XPUMmdnnMergeAllFuser : public FuseBase {
new_stmt->SetKernels(std::move(kernels)); new_stmt->SetKernels(std::move(kernels));
std::vector<std::string> arg_names{ std::vector<std::string> arg_names{
"concat_2in1_input0", "concat_topk_input0",
"concat_2in1_input1", "concat_topk_input1",
"grnn_fw_wh", "grnn_fw_wh",
"grnn_fw_wi", "grnn_fw_wi",
"grnn_rv_wh", "grnn_rv_wh",
...@@ -1133,6 +1579,11 @@ class XPUMmdnnMergeAllFuser : public FuseBase { ...@@ -1133,6 +1579,11 @@ class XPUMmdnnMergeAllFuser : public FuseBase {
"search_fc2_w", "search_fc2_w",
"search_fc2_b", "search_fc2_b",
}; };
for (int i = 2; i < n_concat_topk_; ++i) {
auto concat_topk_input_name =
paddle::lite::string_format("concat_topk_input%d", i);
arg_names.push_back(concat_topk_input_name);
}
for (auto name : arg_names) { for (auto name : arg_names) {
DirectedLink(matched.at(name), matched.at("concat_7in1")); DirectedLink(matched.at(name), matched.at("concat_7in1"));
} }
...@@ -1143,6 +1594,9 @@ class XPUMmdnnMergeAllFuser : public FuseBase { ...@@ -1143,6 +1594,9 @@ class XPUMmdnnMergeAllFuser : public FuseBase {
IR_OP_VAR_LINK(matched.at("concat_7in1"), matched.at(name)); IR_OP_VAR_LINK(matched.at("concat_7in1"), matched.at(name));
} }
} }
private:
int n_concat_topk_;
}; };
} // namespace fusion } // namespace fusion
...@@ -1158,15 +1612,21 @@ class XPUMmdnnFusePass : public ProgramPass { ...@@ -1158,15 +1612,21 @@ class XPUMmdnnFusePass : public ProgramPass {
search_att_fuser(graph.get()); search_att_fuser(graph.get());
fusion::XPUMmdnnMatchConvTopkFuser match_conv_topk_fuser; fusion::XPUMmdnnMatchConvTopkFuser match_conv_topk_fuser;
match_conv_topk_fuser(graph.get()); match_conv_topk_fuser(graph.get());
fusion::XPUMmdnnMatchConvTopkFuser2 match_conv_topk_fuser2;
match_conv_topk_fuser2(graph.get());
fusion::XPUMmdnnBidSeqRevEmbEltwiseFuser bi_seq_rev_emb_eltwise_fuser; fusion::XPUMmdnnBidSeqRevEmbEltwiseFuser bi_seq_rev_emb_eltwise_fuser;
bi_seq_rev_emb_eltwise_fuser(graph.get()); bi_seq_rev_emb_eltwise_fuser(graph.get());
fusion::XPUMmdnnBidEmbGrnnAttFuser bid_emb_grnn_att_fuser; fusion::XPUMmdnnBidEmbGrnnAttFuser bid_emb_grnn_att_fuser;
bid_emb_grnn_att_fuser(graph.get()); bid_emb_grnn_att_fuser(graph.get());
fusion::XPUMmdnnBidEmbGrnnAttFuser2 bid_emb_grnn_att_fuser2;
bid_emb_grnn_att_fuser2(graph.get());
fusion::XPUMmdnnBidEmbAttFuser bid_emb_att_fuser; fusion::XPUMmdnnBidEmbAttFuser bid_emb_att_fuser;
bid_emb_att_fuser(graph.get()); bid_emb_att_fuser(graph.get());
fusion::XPUMmdnnMergeAllFuser merge_all_fuser; for (int n_concat_topk : {3, 2}) {
merge_all_fuser(graph.get()); fusion::XPUMmdnnMergeAllFuser merge_all_fuser(n_concat_topk);
merge_all_fuser(graph.get());
}
} }
}; };
...@@ -1178,6 +1638,7 @@ REGISTER_MIR_PASS(__xpu__mmdnn_fuse_pass, paddle::lite::mir::XPUMmdnnFusePass) ...@@ -1178,6 +1638,7 @@ REGISTER_MIR_PASS(__xpu__mmdnn_fuse_pass, paddle::lite::mir::XPUMmdnnFusePass)
.BindTargets({TARGET(kXPU)}) .BindTargets({TARGET(kXPU)})
.BindKernel("__xpu__mmdnn_search_attention") .BindKernel("__xpu__mmdnn_search_attention")
.BindKernel("__xpu__mmdnn_bid_emb_grnn_att") .BindKernel("__xpu__mmdnn_bid_emb_grnn_att")
.BindKernel("__xpu__mmdnn_bid_emb_grnn_att2")
.BindKernel("__xpu__mmdnn_bid_emb_att") .BindKernel("__xpu__mmdnn_bid_emb_att")
.BindKernel("__xpu__mmdnn_match_conv_topk") .BindKernel("__xpu__mmdnn_match_conv_topk")
.BindKernel("__xpu__mmdnn_merge_all"); .BindKernel("__xpu__mmdnn_merge_all");
...@@ -383,10 +383,10 @@ class XPUSingleEncoderFuser : public FuseBase { ...@@ -383,10 +383,10 @@ class XPUSingleEncoderFuser : public FuseBase {
op_desc.SetAttr<std::string>("act_type", act_type_); op_desc.SetAttr<std::string>("act_type", act_type_);
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
// XXX: memleak? auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
auto sub_block_desc = new cpp::BlockDesc(); sub_program_desc->AddBlock<cpp::BlockDesc>();
static_cast<operators::SubgraphOp*>(fake_subgraph_op.get()) static_cast<operators::SubgraphOp*>(fake_subgraph_op.get())
->SetSubBlock(sub_block_desc); ->SetProgramDesc(sub_program_desc);
auto* single_encoder_stmt = matched.at("q_mul")->stmt(); auto* single_encoder_stmt = matched.at("q_mul")->stmt();
fake_subgraph_op->Attach(op_desc, single_encoder_stmt->op()->scope()); fake_subgraph_op->Attach(op_desc, single_encoder_stmt->op()->scope());
fake_subgraph_op->SetValidPlaces(single_encoder_stmt->op()->valid_places()); fake_subgraph_op->SetValidPlaces(single_encoder_stmt->op()->valid_places());
......
...@@ -373,10 +373,10 @@ class XPUResNetCbamBlock0Fuser : public FuseBase { ...@@ -373,10 +373,10 @@ class XPUResNetCbamBlock0Fuser : public FuseBase {
auto block0_stmt = matched.at("left_conv1")->stmt(); auto block0_stmt = matched.at("left_conv1")->stmt();
// block0_stmt->ResetOp(op_desc, graph->valid_places()); // block0_stmt->ResetOp(op_desc, graph->valid_places());
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
// XXX: memleak? auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
auto sub_block_desc = new cpp::BlockDesc(); sub_program_desc->AddBlock<cpp::BlockDesc>();
static_cast<operators::SubgraphOp*>(fake_subgraph_op.get()) static_cast<operators::SubgraphOp*>(fake_subgraph_op.get())
->SetSubBlock(sub_block_desc); ->SetProgramDesc(sub_program_desc);
fake_subgraph_op->Attach(op_desc, block0_stmt->op()->scope()); fake_subgraph_op->Attach(op_desc, block0_stmt->op()->scope());
fake_subgraph_op->SetValidPlaces(block0_stmt->op()->valid_places()); fake_subgraph_op->SetValidPlaces(block0_stmt->op()->valid_places());
block0_stmt->SetOp(fake_subgraph_op); block0_stmt->SetOp(fake_subgraph_op);
...@@ -693,10 +693,10 @@ class XPUResNetCbamBlock1Fuser : public FuseBase { ...@@ -693,10 +693,10 @@ class XPUResNetCbamBlock1Fuser : public FuseBase {
auto block1_stmt = matched.at("right_conv1")->stmt(); auto block1_stmt = matched.at("right_conv1")->stmt();
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
// XXX: memleak? auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
auto sub_block_desc = new cpp::BlockDesc(); sub_program_desc->AddBlock<cpp::BlockDesc>();
static_cast<operators::SubgraphOp*>(fake_subgraph_op.get()) static_cast<operators::SubgraphOp*>(fake_subgraph_op.get())
->SetSubBlock(sub_block_desc); ->SetProgramDesc(sub_program_desc);
fake_subgraph_op->Attach(op_desc, block1_stmt->op()->scope()); fake_subgraph_op->Attach(op_desc, block1_stmt->op()->scope());
fake_subgraph_op->SetValidPlaces(block1_stmt->op()->valid_places()); fake_subgraph_op->SetValidPlaces(block1_stmt->op()->valid_places());
block1_stmt->SetOp(fake_subgraph_op); block1_stmt->SetOp(fake_subgraph_op);
...@@ -932,10 +932,10 @@ class XPUResNetCbamBlock2Fuser : public FuseBase { ...@@ -932,10 +932,10 @@ class XPUResNetCbamBlock2Fuser : public FuseBase {
<< "Y of last fc must have been transposed"; << "Y of last fc must have been transposed";
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
// XXX: memleak? auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
auto sub_block_desc = new cpp::BlockDesc(); sub_program_desc->AddBlock<cpp::BlockDesc>();
static_cast<operators::SubgraphOp*>(fake_subgraph_op.get()) static_cast<operators::SubgraphOp*>(fake_subgraph_op.get())
->SetSubBlock(sub_block_desc); ->SetProgramDesc(sub_program_desc);
fake_subgraph_op->Attach(op_desc, scope); fake_subgraph_op->Attach(op_desc, scope);
fake_subgraph_op->SetValidPlaces(block2_stmt->op()->valid_places()); fake_subgraph_op->SetValidPlaces(block2_stmt->op()->valid_places());
block2_stmt->SetOp(fake_subgraph_op); block2_stmt->SetOp(fake_subgraph_op);
......
...@@ -315,10 +315,10 @@ class XPUResNetBlock0Fuser : public FuseBase { ...@@ -315,10 +315,10 @@ class XPUResNetBlock0Fuser : public FuseBase {
auto block0_stmt = matched.at("left_conv1")->stmt(); auto block0_stmt = matched.at("left_conv1")->stmt();
// block0_stmt->ResetOp(op_desc, graph->valid_places()); // block0_stmt->ResetOp(op_desc, graph->valid_places());
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
// XXX: memleak? auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
auto sub_block_desc = new cpp::BlockDesc(); sub_program_desc->AddBlock<cpp::BlockDesc>();
static_cast<operators::SubgraphOp*>(fake_subgraph_op.get()) static_cast<operators::SubgraphOp*>(fake_subgraph_op.get())
->SetSubBlock(sub_block_desc); ->SetProgramDesc(sub_program_desc);
fake_subgraph_op->Attach(op_desc, block0_stmt->op()->scope()); fake_subgraph_op->Attach(op_desc, block0_stmt->op()->scope());
fake_subgraph_op->SetValidPlaces(block0_stmt->op()->valid_places()); fake_subgraph_op->SetValidPlaces(block0_stmt->op()->valid_places());
block0_stmt->SetOp(fake_subgraph_op); block0_stmt->SetOp(fake_subgraph_op);
...@@ -577,10 +577,10 @@ class XPUResNetBlock1Fuser : public FuseBase { ...@@ -577,10 +577,10 @@ class XPUResNetBlock1Fuser : public FuseBase {
auto block1_stmt = matched.at("right_conv1")->stmt(); auto block1_stmt = matched.at("right_conv1")->stmt();
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
// XXX: memleak? auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
auto sub_block_desc = new cpp::BlockDesc(); sub_program_desc->AddBlock<cpp::BlockDesc>();
static_cast<operators::SubgraphOp*>(fake_subgraph_op.get()) static_cast<operators::SubgraphOp*>(fake_subgraph_op.get())
->SetSubBlock(sub_block_desc); ->SetProgramDesc(sub_program_desc);
fake_subgraph_op->Attach(op_desc, block1_stmt->op()->scope()); fake_subgraph_op->Attach(op_desc, block1_stmt->op()->scope());
fake_subgraph_op->SetValidPlaces(block1_stmt->op()->valid_places()); fake_subgraph_op->SetValidPlaces(block1_stmt->op()->valid_places());
block1_stmt->SetOp(fake_subgraph_op); block1_stmt->SetOp(fake_subgraph_op);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/conv_conv_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/conv_conv_fuser.h"
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// initialze fuser params
std::vector<bool> conv_has_bias_cases{true, false};
std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
bool has_arm = false;
for (auto& place : graph->valid_places()) {
if (place.target == TARGET(kARM) && place.precision == PRECISION(kFloat)) {
has_arm = true;
break;
}
}
if (!has_arm) {
return;
}
// only support fp32 fusion
for (auto conv_has_bias0 : conv_has_bias_cases) {
for (auto conv_has_bias1 : conv_has_bias_cases) {
for (auto conv_type0 : conv_type_cases) {
for (auto conv_type1 : conv_type_cases) {
VLOG(4) << "conv_has_bias0:" << conv_has_bias0
<< " conv_type0:" << conv_type0;
VLOG(4) << "conv_has_bias1:" << conv_has_bias1
<< " conv_type1:" << conv_type1;
fusion::ConvConvFuser fuser(
conv_type0, conv_type1, conv_has_bias0, conv_has_bias1);
fuser(graph.get());
}
}
}
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_conv_fuse_pass, paddle::lite::mir::ConvConvFusePass)
.BindTargets({TARGET(kARM)});
...@@ -14,18 +14,19 @@ ...@@ -14,18 +14,19 @@
#pragma once #pragma once
#include "lite/backends/xpu/xpu_header_sitter.h" #include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace mir {
namespace xpu {
struct XPUFreeDeleter { class ConvConvFusePass : public ProgramPass {
void operator()(void* p) const { xpu_free(p); } public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
}; };
} // namespace xpu } // namespace mir
} // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/conv_conv_fuser.h"
#include <memory>
#include <set>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ConvConvFuser::BuildPattern() {
auto* conv_input0 = VarNode("conv_input0")
->assert_is_op_input(conv_type0_, "Input")
->AsInput();
auto* conv_weight0 = VarNode("conv_weight0")
->assert_is_op_input(conv_type0_, "Filter")
->AsInput();
auto* conv0 = OpNode("conv2d0", conv_type0_)->assert_is_op(conv_type0_);
auto* conv_out0 = VarNode("conv_out0")
->assert_is_op_output(conv_type0_, "Output")
->assert_is_op_input(conv_type1_, "Input")
->AsIntermediate();
auto* conv_weight1 = VarNode("conv_weight1")
->assert_is_op_input(conv_type1_, "Filter")
->AsIntermediate();
auto* conv1 = OpNode("conv2d1", conv_type1_)
->assert_is_op(conv_type1_)
->assert_op_attr<int>("groups", 1)
->AsIntermediate();
auto* conv_out1 = VarNode("conv_out1")
->assert_is_op_output(conv_type1_, "Output")
->AsOutput();
if (conv_has_bias0_) {
if (conv_has_bias1_) {
auto* conv_bias0 = VarNode("conv_bias0")
->assert_is_op_input(conv_type0_, "Bias")
->AsIntermediate();
auto* conv_bias1 = VarNode("conv_bias1")
->assert_is_op_input(conv_type1_, "Bias")
->AsInput();
conv0->LinksFrom({conv_input0, conv_weight0, conv_bias0})
.LinksTo({conv_out0});
conv1->LinksFrom({conv_out0, conv_weight1, conv_bias1})
.LinksTo({conv_out1});
} else {
auto* conv_bias0 = VarNode("conv_bias0")
->assert_is_op_input(conv_type0_, "Bias")
->AsIntermediate();
conv0->LinksFrom({conv_input0, conv_weight0, conv_bias0})
.LinksTo({conv_out0});
conv1->LinksFrom({conv_out0, conv_weight1}).LinksTo({conv_out1});
}
} else {
conv0->LinksFrom({conv_input0, conv_weight0}).LinksTo({conv_out0});
if (conv_has_bias1_) {
auto* conv_bias1 = VarNode("conv_bias1")
->assert_is_op_input(conv_type1_, "Bias")
->AsInput();
conv1->LinksFrom({conv_out0, conv_weight1, conv_bias1})
.LinksTo({conv_out1});
} else {
conv1->LinksFrom({conv_out0, conv_weight1}).LinksTo({conv_out1});
}
}
}
void ConvConvFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto conv_instruct = matched.at("conv2d0")->stmt();
auto conv_op_desc = conv_instruct->mutable_op_info();
auto conv = conv_instruct->op();
auto* scope = conv->scope();
auto conv_op_desc1 = matched.at("conv2d1")->stmt()->mutable_op_info();
// conv0
auto weight0_t = scope->FindVar(matched.at("conv_weight0")->arg()->name)
->GetMutable<lite::Tensor>();
// conv1
auto weight1_t = scope->FindVar(matched.at("conv_weight1")->arg()->name)
->GetMutable<lite::Tensor>();
// auto groups0 = conv_op_desc->GetAttr<int>("groups");
auto groups1 = conv_op_desc1->GetAttr<int>("groups");
auto strides1 = conv_op_desc1->GetAttr<std::vector<int>>("strides");
auto paddings1 = conv_op_desc1->GetAttr<std::vector<int>>("paddings");
auto dilations1 = conv_op_desc1->GetAttr<std::vector<int>>("dilations");
bool enable0_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false;
bool enable1_int8 = conv_op_desc1->HasAttr("enable_int8") ? true : false;
int kw = weight1_t->dims()[2];
int kh = weight1_t->dims()[3];
if (!(kw == 1 && kh == 1)) {
return;
}
CHECK_EQ(enable0_int8, enable1_int8) << "The Conv compute type must be same";
CHECK_EQ(groups1, 1) << "The groups of weight1_dim must be 1";
CHECK_EQ(weight0_t->dims()[0], weight1_t->dims()[1])
<< "weight0_dims[0] == weight1_dim[1]";
for (int i = 0; i < strides1.size(); i++) {
CHECK_EQ(strides1[i], 1) << "strides[" << i << "]: " << strides1[i]
<< " must be 1";
}
for (int i = 0; i < paddings1.size(); i++) {
CHECK_EQ(paddings1[i], 0) << "paddings1[" << i << "]: " << paddings1[i]
<< " must be 0";
}
for (int i = 0; i < dilations1.size(); i++) {
CHECK_EQ(dilations1[i], 1) << "dilations1[" << i << "]: " << dilations1[i]
<< " must be 1";
}
// comupte new_wight and new bias
///////////////////////////////////////////////////////////////////////////////
// Compute ConvConvFuser
// Before fusion
//
// conv(x) = conv(x) = kx + z = y
// conv(y) = ay + b
//
// After fusion:
//
// conv(conv(x)) = a(kx + z) + b = akx + az + b
//
// new_weights = ak
// new_bias = az + b
///////////////////////////////////////////////////////////////////////////////
if (enable0_int8) {
LOG(FATAL) << "it doesn't support";
return;
} else {
// compute new conv_weight
Tensor weight_tensor;
auto in_dims = weight0_t->dims();
auto weight_dims = weight1_t->dims();
const float* din = weight0_t->data<float>();
const float* weights = weight1_t->data<float>();
int oc0 = in_dims[0];
int ic = in_dims[1];
int ih = in_dims[2];
int iw = in_dims[3];
int oc = weight_dims[0];
weight_tensor.Resize({oc, ic, ih, iw});
float* dout = weight_tensor.mutable_data<float>();
ComputeNewWeight(dout, din, weights, oc0, ic, ih, iw, oc);
weight0_t->CopyDataFrom(weight_tensor);
}
// compute new conv_bias
if (conv_has_bias0_ && conv_op_desc->HasInput("Bias") &&
conv_op_desc->Input("Bias").size() > 0) {
auto bias_t0 = scope->FindVar(matched.at("conv_bias0")->arg()->name)
->GetMutable<lite::Tensor>();
if (conv_has_bias1_ && conv_op_desc1->HasInput("Bias") &&
conv_op_desc1->Input("Bias").size() > 0) {
auto bias_t1 = scope->FindVar(matched.at("conv_bias1")->arg()->name)
->GetMutable<lite::Tensor>();
Tensor bias;
bias.CopyDataFrom(*bias_t1);
auto bias_data = bias.mutable_data<float>();
ComputeNewBias(bias_data, bias_t0, weight1_t, bias_t1);
bias_t1->CopyDataFrom(bias);
conv_op_desc->SetInput(
"Bias", {matched.at("conv_bias1")->arg()->name}); // conv_bias
IR_NODE_LINK_TO(matched.at("conv_bias1"), matched.at("conv2d0"));
} else {
Tensor bias;
auto weight_dims = weight1_t->dims();
bias.Resize({weight_dims[0]});
auto bias_d = bias.mutable_data<float>();
ComputeNewBias(bias_d, bias_t0, weight1_t, nullptr);
bias_t0->CopyDataFrom(bias);
conv_op_desc->SetInput(
"Bias", {matched.at("conv_bias0")->arg()->name}); // conv_bias
}
} else {
if (conv_has_bias1_ && conv_op_desc1->HasInput("Bias") &&
conv_op_desc1->Input("Bias").size() > 0) {
conv_op_desc->SetInput(
"Bias", {matched.at("conv_bias1")->arg()->name}); // conv_bias
IR_NODE_LINK_TO(matched.at("conv_bias1"), matched.at("conv2d0"));
}
}
conv_op_desc->SetType(conv_type0_);
conv_op_desc->SetInput("Input", {matched.at("conv_input0")->arg()->name});
conv_op_desc->SetInput("Filter", {matched.at("conv_weight0")->arg()->name});
conv_op_desc->SetOutput("Output", {matched.at("conv_out1")->arg()->name});
auto update_conv_desc = *conv_instruct->mutable_op_info();
conv_instruct->ResetOp(update_conv_desc, graph->valid_places());
IR_OP_VAR_LINK(matched.at("conv2d0"), matched.at("conv_out1"));
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ConvConvFuser : public FuseBase {
public:
explicit ConvConvFuser(const std::string& conv_type0,
const std::string& conv_type1,
const bool conv_has_bias0,
const bool conv_has_bias1)
: conv_type0_(conv_type0),
conv_type1_(conv_type1),
conv_has_bias0_(conv_has_bias0),
conv_has_bias1_(conv_has_bias1) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
void ComputeNewWeight(float* dout,
const float* din,
const float* weights,
int oc0,
int ic,
int ih,
int iw,
int oc1) {
// input conv_weight0_t weights conv_weight1_t
// output weight_tensor
// ksize = 1
int in_size = ih * iw;
int in_channel_size = ic * in_size;
// out = w1[j, i, ih, iw] * w2[k, j, kw, kh]
// out_dim = [oc1, ic, kh, kw], din_dim = [oc0, ic, kh, kw]
// weight_dim = [oc1, oc0, kh, kw]
for (int k = 0; k < oc1; k++) {
const float* weights_ptr = weights + k * oc0;
float* out_ptr = dout + k * in_channel_size;
for (int c = 0; c < ic; c++) {
float* out_ptr_channel = out_ptr + c * in_size;
const float* din_ptr = din + c * in_size;
for (int i = 0; i < in_size; i++) {
float sum = 0.f;
for (int j = 0; j < oc0; j++) {
sum += din_ptr[j * in_channel_size] * weights_ptr[j];
}
*out_ptr_channel++ = sum;
}
}
}
}
void ComputeNewBias(float* dout,
Tensor* bias0_tensor,
Tensor* weight_tensor,
Tensor* bias1_tensor) {
// input bias0_tensor weight_tensor bias1_tensor
// output bias_tensor
auto in_dims = bias0_tensor->dims();
auto weight_dims = weight_tensor->dims();
const float* din = bias0_tensor->data<float>();
const float* weights = weight_tensor->data<float>();
int ic = in_dims[0];
int oc = weight_dims[0];
// out_k = b0[num, j, 1, 1] * w2[k, j, 1, 1]
if (bias1_tensor) {
const float* din2 = bias1_tensor->data<float>();
for (int k = 0; k < oc; k++) {
const float* weights_ptr = weights + k * ic;
float sum = 0.f;
for (int j = 0; j < ic; j++) {
sum += din[j] * weights_ptr[j];
}
dout[k] = sum + din2[k];
}
} else {
for (int k = 0; k < oc; k++) {
const float* weights_ptr = weights + k * ic;
float sum = 0.f;
for (int j = 0; j < ic; j++) {
sum += din[j] * weights_ptr[j];
}
dout[k] = sum;
}
}
}
private:
std::string conv_type0_{"conv2d"};
std::string conv_type1_{"conv2d"};
bool conv_has_bias0_{false};
bool conv_has_bias1_{false};
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -175,7 +175,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -175,7 +175,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
for (int i = 0; i < weight_scale_size; i++) { for (int i = 0; i < weight_scale_size; i++) {
weight_scale.push_back(whole_weight_scale); weight_scale.push_back(whole_weight_scale);
} }
op_desc.SetAttr("enable_int8", true);
// Arm CPU does not support conv2d_transpose
if (quantized_op_type_ != "conv2d_transpose") {
op_desc.SetAttr("enable_int8", true);
}
op_desc.SetInputScale(weight_name, weight_scale); op_desc.SetInputScale(weight_name, weight_scale);
// change the weight from the float type to int8 type. // change the weight from the float type to int8 type.
...@@ -280,6 +284,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -280,6 +284,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
op_desc.SetInput("X", {quantized_op_input->arg()->name}); op_desc.SetInput("X", {quantized_op_input->arg()->name});
op_desc.SetOutput("Out", {dequant_op_out->arg()->name}); op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
} }
// Arm CPU does not support conv2d_transpose
if (quantized_op_type_ != "conv2d_transpose") { if (quantized_op_type_ != "conv2d_transpose") {
op_desc.SetAttr("enable_int8", true); op_desc.SetAttr("enable_int8", true);
} }
......
...@@ -39,6 +39,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -39,6 +39,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
nodes_in_order = graph->StmtTopologicalOrder(); nodes_in_order = graph->StmtTopologicalOrder();
} }
insts_.emplace_back();
for (auto& item : nodes_in_order) { for (auto& item : nodes_in_order) {
if (item->IsStmt()) { if (item->IsStmt()) {
auto& stmt = item->AsStmt(); auto& stmt = item->AsStmt();
...@@ -57,7 +58,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -57,7 +58,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
.SetSyncStreams(stmt.sync_streams_); .SetSyncStreams(stmt.sync_streams_);
} }
#endif #endif
insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front())); insts_.back().emplace_back(stmt.op(), std::move(stmt.kernels().front()));
} }
} }
} }
......
...@@ -42,7 +42,7 @@ class GenerateProgramPass : public ProgramPass { ...@@ -42,7 +42,7 @@ class GenerateProgramPass : public ProgramPass {
} }
private: private:
std::vector<Instruction> insts_; std::vector<std::vector<Instruction>> insts_;
}; };
} // namespace mir } // namespace mir
......
...@@ -284,13 +284,19 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph, ...@@ -284,13 +284,19 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
head_node->AsArg().name, head_node->AsArg().name,
cur_node->AsArg().name); cur_node->AsArg().name);
// for subgraph op, modify the BlockDesc // for subgraph op, modify the BlockDesc
auto* sub_block_desc = dynamic_cast<paddle::lite::operators::SubgraphOp*>( auto sub_program_desc = dynamic_cast<paddle::lite::operators::SubgraphOp*>(
inst_node->AsStmt().op().get()) inst_node->AsStmt().op().get())
->GetSubBlock(); ->GetProgramDesc();
for (size_t i = 0; i < sub_block_desc->OpsSize(); ++i) { CHECK(sub_program_desc);
auto* sub_block_op_desc = sub_block_desc->GetOp<cpp::OpDesc>(i); int sub_block_idx =
UpdateInputTo( inst_node->AsStmt().op()->op_info()->GetAttr<int32_t>("sub_block");
sub_block_op_desc, head_node->AsArg().name, cur_node->AsArg().name); auto* sub_block_desc =
sub_program_desc->GetBlock<cpp::BlockDesc>(sub_block_idx);
for (size_t sub_op_idx = 0; sub_op_idx < sub_block_desc->OpsSize();
++sub_op_idx) {
auto* sub_op_desc = const_cast<cpp::OpDesc*>(
sub_block_desc->GetOp<cpp::OpDesc>(sub_op_idx));
UpdateInputTo(sub_op_desc, head_node->AsArg().name, cur_node->AsArg().name);
} }
// recreate the op // recreate the op
...@@ -444,21 +450,27 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph, ...@@ -444,21 +450,27 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph,
tail_node->AsArg().name, tail_node->AsArg().name,
cur_node->AsArg().name); cur_node->AsArg().name);
// for subgraph op, modify the BlockDesc // for subgraph op, modify the BlockDesc
auto* sub_block_desc = dynamic_cast<paddle::lite::operators::SubgraphOp*>( auto sub_program_desc = dynamic_cast<paddle::lite::operators::SubgraphOp*>(
inst_node->AsStmt().op().get()) inst_node->AsStmt().op().get())
->GetSubBlock(); ->GetProgramDesc();
for (size_t i = 0; i < sub_block_desc->OpsSize(); ++i) { CHECK(sub_program_desc);
auto* sub_block_op_desc = sub_block_desc->GetOp<cpp::OpDesc>(i); int sub_block_idx =
inst_node->AsStmt().op()->op_info()->GetAttr<int32_t>("sub_block");
auto* sub_block_desc =
sub_program_desc->GetBlock<cpp::BlockDesc>(sub_block_idx);
for (size_t sub_op_idx = 0; sub_op_idx < sub_block_desc->OpsSize();
++sub_op_idx) {
auto* sub_op_desc = const_cast<cpp::OpDesc*>(
sub_block_desc->GetOp<cpp::OpDesc>(sub_op_idx));
UpdateOutputTo( UpdateOutputTo(
sub_block_op_desc, tail_node->AsArg().name, cur_node->AsArg().name); sub_op_desc, tail_node->AsArg().name, cur_node->AsArg().name);
/* graph like this /* graph like this
* subgraph_op_0 * subgraph_op_0
* / \ * / \
* / \ * / \
* subgraph_op_1 host_op * subgraph_op_1 host_op
*/ */
UpdateInputTo( UpdateInputTo(sub_op_desc, tail_node->AsArg().name, cur_node->AsArg().name);
sub_block_op_desc, tail_node->AsArg().name, cur_node->AsArg().name);
} }
// recreate the op // recreate the op
...@@ -482,15 +494,22 @@ void MLUPostprocessPass::RecreateOp(Node* inst_node, SSAGraph* graph) { ...@@ -482,15 +494,22 @@ void MLUPostprocessPass::RecreateOp(Node* inst_node, SSAGraph* graph) {
} }
} }
bool MLUPostprocessPass::IsFirstConvInSubgraph(Node* arg_node, Node* inst) { bool MLUPostprocessPass::IsFirstConvInSubgraph(Node* arg_node,
auto* block_desc = Node* inst_node) {
static_cast<operators::SubgraphOp*>(inst->AsStmt().op().get()) auto sub_program_desc = dynamic_cast<paddle::lite::operators::SubgraphOp*>(
->GetSubBlock(); inst_node->AsStmt().op().get())
for (size_t op_idx = 0; op_idx < block_desc->OpsSize(); op_idx++) { ->GetProgramDesc();
auto op_desc = block_desc->GetOp<cpp::OpDesc>(op_idx); CHECK(sub_program_desc);
CHECK(op_desc); int sub_block_idx =
if (op_desc->Type() == "conv2d") { inst_node->AsStmt().op()->op_info()->GetAttr<int32_t>("sub_block");
for (auto& names : op_desc->inputs()) { auto* sub_block_desc =
sub_program_desc->GetBlock<cpp::BlockDesc>(sub_block_idx);
for (size_t sub_op_idx = 0; sub_op_idx < sub_block_desc->OpsSize();
sub_op_idx++) {
auto sub_op_desc = sub_block_desc->GetOp<cpp::OpDesc>(sub_op_idx);
CHECK(sub_op_desc);
if (sub_op_desc->Type() == "conv2d") {
for (auto& names : sub_op_desc->inputs()) {
if (std::find(names.second.begin(), if (std::find(names.second.begin(),
names.second.end(), names.second.end(),
arg_node->AsArg().name) != names.second.end()) { arg_node->AsArg().name) != names.second.end()) {
...@@ -746,19 +765,23 @@ std::pair<bool, std::string> CheckOutputAndInsert( ...@@ -746,19 +765,23 @@ std::pair<bool, std::string> CheckOutputAndInsert(
// insert cast op on mlu, to avoid cast on cpu // insert cast op on mlu, to avoid cast on cpu
void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node, void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
const Type* subgraph_type) { const Type* subgraph_type) {
auto subgraph_op = subgraph_node->AsStmt().op(); CHECK_EQ(subgraph_node->AsStmt().op()->Type(), "subgraph");
CHECK_EQ(subgraph_op->Type(), "subgraph"); auto subgraph_op =
auto op = dynamic_cast<operators::SubgraphOp*>(subgraph_op.get()); dynamic_cast<operators::SubgraphOp*>(subgraph_node->AsStmt().op().get());
CHECK(op); CHECK(subgraph_op);
auto block_desc = op->GetSubBlock(); auto sub_program_desc = subgraph_op->GetProgramDesc();
CHECK(sub_program_desc);
int sub_block_idx = subgraph_op->op_info()->GetAttr<int32_t>("sub_block");
auto* sub_block_desc = const_cast<cpp::BlockDesc*>(
sub_program_desc->GetBlock<cpp::BlockDesc>(sub_block_idx));
// create a new block desc to keep op sequence correct // create a new block desc to keep op sequence correct
cpp::BlockDesc* new_block_desc = new cpp::BlockDesc(); cpp::BlockDesc new_block_desc;
new_block_desc->ClearOps(); new_block_desc.ClearOps();
new_block_desc->ClearVars(); new_block_desc.ClearVars();
new_block_desc->SetIdx(block_desc->Idx()); new_block_desc.SetIdx(sub_block_desc->Idx());
new_block_desc->SetParentIdx(block_desc->ParentIdx()); new_block_desc.SetParentIdx(sub_block_desc->ParentIdx());
new_block_desc->SetForwardBlockIdx(block_desc->ForwardBlockIdx()); new_block_desc.SetForwardBlockIdx(sub_block_desc->ForwardBlockIdx());
// find all IO that is not weight or persist // find all IO that is not weight or persist
std::list<std::string> i_names, o_names; std::list<std::string> i_names, o_names;
...@@ -769,8 +792,8 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node, ...@@ -769,8 +792,8 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
auto input_name = input->AsArg().name; auto input_name = input->AsArg().name;
if (!(input->AsArg().is_weight || input->AsArg().is_persist)) { if (!(input->AsArg().is_weight || input->AsArg().is_persist)) {
i_names.emplace_back(input_name); i_names.emplace_back(input_name);
auto ret = CheckInputAndInsert(op->scope(), auto ret = CheckInputAndInsert(subgraph_op->scope(),
new_block_desc, &new_block_desc,
input_name, input_name,
input->AsArg().type, input->AsArg().type,
subgraph_type); subgraph_type);
...@@ -783,8 +806,8 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node, ...@@ -783,8 +806,8 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
auto output_name = output->AsArg().name; auto output_name = output->AsArg().name;
if (!(output->AsArg().is_weight || output->AsArg().is_persist)) { if (!(output->AsArg().is_weight || output->AsArg().is_persist)) {
o_names.emplace_back(output_name); o_names.emplace_back(output_name);
auto ret = CheckOutputAndInsert(op->scope(), auto ret = CheckOutputAndInsert(subgraph_op->scope(),
block_desc, sub_block_desc,
output_name, output_name,
output->AsArg().type, output->AsArg().type,
subgraph_type); subgraph_type);
...@@ -795,46 +818,48 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node, ...@@ -795,46 +818,48 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
} }
// update input and output // update input and output
for (size_t op_idx = 0; op_idx < block_desc->OpsSize(); ++op_idx) { for (size_t sub_op_idx = 0; sub_op_idx < sub_block_desc->OpsSize();
auto desc = block_desc->GetOp<cpp::OpDesc>(op_idx); ++sub_op_idx) {
auto new_desc = new_block_desc->AddOp<cpp::OpDesc>(); auto sub_op_desc = sub_block_desc->GetOp<cpp::OpDesc>(sub_op_idx);
*new_desc = *desc; auto new_op_desc = new_block_desc.AddOp<cpp::OpDesc>();
*new_op_desc = *sub_op_desc;
if (desc->Type() != "layout" && desc->Type() != "cast") {
auto op_input_args = new_desc->InputArgumentNames(); if (sub_op_desc->Type() != "layout" && sub_op_desc->Type() != "cast") {
auto op_input_args = new_op_desc->InputArgumentNames();
for (auto& input_arg : op_input_args) { for (auto& input_arg : op_input_args) {
auto op_input = new_desc->Input(input_arg); auto op_input = new_op_desc->Input(input_arg);
for (auto& it : i_names) { for (auto& it : i_names) {
auto index = std::find(op_input.begin(), op_input.end(), it); auto index = std::find(op_input.begin(), op_input.end(), it);
if (index != op_input.end() && if (index != op_input.end() &&
node_replace.find(it) != node_replace.end()) { node_replace.find(it) != node_replace.end()) {
index = op_input.erase(index); index = op_input.erase(index);
op_input.emplace(index, node_replace.at(it)); op_input.emplace(index, node_replace.at(it));
VLOG(4) << new_desc->Type() << "] change input from " << it VLOG(4) << new_op_desc->Type() << "] change input from " << it
<< " to " << node_replace.at(it); << " to " << node_replace.at(it);
} }
} }
new_desc->SetInput(input_arg, op_input); new_op_desc->SetInput(input_arg, op_input);
} }
auto op_output_args = new_desc->OutputArgumentNames(); auto op_output_args = new_op_desc->OutputArgumentNames();
for (auto& output_arg : op_output_args) { for (auto& output_arg : op_output_args) {
auto op_output = new_desc->Output(output_arg); auto op_output = new_op_desc->Output(output_arg);
for (auto& it : o_names) { for (auto& it : o_names) {
auto index = std::find(op_output.begin(), op_output.end(), it); auto index = std::find(op_output.begin(), op_output.end(), it);
if (index != op_output.end() && if (index != op_output.end() &&
node_replace.find(it) != node_replace.end()) { node_replace.find(it) != node_replace.end()) {
index = op_output.erase(index); index = op_output.erase(index);
op_output.emplace(index, node_replace.at(it)); op_output.emplace(index, node_replace.at(it));
VLOG(4) << new_desc->Type() << "] change output from " << it VLOG(4) << new_op_desc->Type() << "] change output from " << it
<< " to " << node_replace.at(it); << " to " << node_replace.at(it);
} }
} }
new_desc->SetOutput(output_arg, op_output); new_op_desc->SetOutput(output_arg, op_output);
} }
} }
} }
op->SetSubBlock(new_block_desc);
*sub_block_desc = new_block_desc;
} }
void ModifyValidPlaces(SSAGraph* graph, bool use_mlu_cast) { void ModifyValidPlaces(SSAGraph* graph, bool use_mlu_cast) {
......
...@@ -153,60 +153,61 @@ Node *SSAGraph::GraphCreateInstructNode( ...@@ -153,60 +153,61 @@ Node *SSAGraph::GraphCreateInstructNode(
} }
void SSAGraph::Build(const Program &program, void SSAGraph::Build(const Program &program,
const std::vector<Place> &valid_places) { const std::vector<Place> &valid_places,
int block_idx) {
CHECK(node_storage_.empty()); CHECK(node_storage_.empty());
auto weights_name = program.weights(); auto weights = program.weights();
auto is_weights = [&](const std::string &name) -> bool { auto is_weight = [&](const std::string &name) -> bool {
auto it = std::find(weights_name.begin(), weights_name.end(), name); auto it = std::find(weights.begin(), weights.end(), name);
if (it == weights_name.end()) return false; if (it == weights.end()) return false;
return true; return true;
}; };
std::map<std::string, PrecisionType> var_types = program.var_data_type(); auto var_type_map = program.var_type_map();
std::map<std::string, mir::Node *> arg_update_node_map;
std::map<std::string, mir::Node *> arg_update_node_map_; for (auto &op : program.ops(block_idx)) {
for (auto &op : program.ops()) {
VLOG(3) << op->op_info()->Type(); VLOG(3) << op->op_info()->Type();
auto *op_node = GraphCreateInstructNode(op, valid_places); auto *op_node = GraphCreateInstructNode(op, valid_places);
for (const std::string &name : op->op_info()->input_names()) { auto *op_info = op->op_info();
const auto &op_type = op_info->Type();
for (const auto &var_name : op_info->input_names()) {
mir::Node *arg_node = nullptr; mir::Node *arg_node = nullptr;
if (arg_update_node_map_.count(name)) { if (arg_update_node_map.count(var_name)) {
arg_node = arg_update_node_map_.at(name); arg_node = arg_update_node_map.at(var_name);
} else { } else {
node_storage_.emplace_back(); node_storage_.emplace_back();
arg_node = &node_storage_.back(); arg_node = &node_storage_.back();
arg_node->AsArg(name, node_storage_.size() - 1); arg_node->AsArg(var_name, node_storage_.size() - 1);
arg_update_node_map_[name] = arg_node; arg_update_node_map[var_name] = arg_node;
} }
if (var_types.count(name)) { if (var_type_map.count(var_name)) {
if (!arg_node->arg()->type) { if (!arg_node->arg()->type) {
arg_node->arg()->type = LiteType::GetTensorTy( arg_node->arg()->type = var_type_map[var_name];
TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
} }
// Store the original data type of the output tensors for // Store the original data type of the output tensors for
// type_precision_cast_pass, to keep the consistency between the // type_precision_cast_pass, to keep the consistency between the
// output types of original graph and optimized graph's // output types of original graph and optimized graph's
if (op->op_info()->Type() == "fetch") { if (op_type == "fetch") {
op->mutable_op_info()->SetAttr<int>( op->mutable_op_info()->SetAttr<int>(
"data_type", static_cast<int>(var_types[name])); "data_type",
static_cast<int>(var_type_map[var_name]->precision()));
} }
} }
if (is_weights(name)) arg_node->AsArg().is_weight = true; if (is_weight(var_name)) arg_node->AsArg().is_weight = true;
CHECK(arg_node->IsRoleSet()); CHECK(arg_node->IsRoleSet());
DirectedLink(arg_node, op_node); DirectedLink(arg_node, op_node);
} }
for (const std::string &name : op->op_info()->output_names()) { for (const auto &var_name : op->op_info()->output_names()) {
node_storage_.emplace_back(); node_storage_.emplace_back();
auto *arg_node = &node_storage_.back(); auto *arg_node = &node_storage_.back();
arg_node->AsArg(name, node_storage_.size() - 1); arg_node->AsArg(var_name, node_storage_.size() - 1);
arg_update_node_map_[name] = arg_node; arg_update_node_map[var_name] = arg_node;
if (var_types.count(name) && !arg_node->arg()->type) { if (var_type_map.count(var_name) && !arg_node->arg()->type) {
arg_node->arg()->type = LiteType::GetTensorTy( arg_node->arg()->type = var_type_map[var_name];
TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
} }
if (is_weights(name)) arg_node->AsArg().is_weight = true; if (is_weight(var_name)) arg_node->AsArg().is_weight = true;
CHECK(arg_node->IsRoleSet()); CHECK(arg_node->IsRoleSet());
DirectedLink(op_node, arg_node); DirectedLink(op_node, arg_node);
} }
......
...@@ -35,9 +35,13 @@ class GraphBase {}; ...@@ -35,9 +35,13 @@ class GraphBase {};
class SSAGraph : GraphBase { class SSAGraph : GraphBase {
public: public:
// @param program: the op program // @param program: the target program with vars and ops
// @param valid_places: the valid places user set for the system. // @param valid_places: the valid places user set for the system.
void Build(const Program &program, const std::vector<Place> &valid_places); // @param block_idx: the block index in the target program, default is 0(main
// block)
void Build(const Program &program,
const std::vector<Place> &valid_places,
int block_idx = kRootBlockIdx);
void RemoveNode(const mir::Node *node); void RemoveNode(const mir::Node *node);
std::vector<mir::Node *> StmtTopologicalOrder(); std::vector<mir::Node *> StmtTopologicalOrder();
......
...@@ -411,16 +411,17 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, ...@@ -411,16 +411,17 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
cpp::OpDesc subgraph_op_desc; cpp::OpDesc subgraph_op_desc;
subgraph_op_desc.SetType("subgraph"); subgraph_op_desc.SetType("subgraph");
// Create a new sub block desc for storing all of Ops and Vars of the target // Create a program desc and a block desc for storing all of Ops and Vars of
// subgraph and sub_block_idx is set as a attribute of subgraph op, // the target subgraph and sub_block_idx is set as a attribute of subgraph op,
// sub_block_idx < 0 means it's a new subgraph op // sub_block_idx = 0 means it's a new subgraph op
int sub_block_idx = -(subgraph_idx + 1); auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
auto sub_block_desc = new cpp::BlockDesc(); int sub_block_idx = 0;
auto sub_block_desc = sub_program_desc->AddBlock<cpp::BlockDesc>();
sub_block_desc->ClearOps(); sub_block_desc->ClearOps();
sub_block_desc->ClearVars(); sub_block_desc->ClearVars();
for (auto &op_node : subgraph_nodes) { for (auto &op_node : subgraph_nodes) {
auto sub_block_op_desc = sub_block_desc->AddOp<cpp::OpDesc>(); auto sub_op_desc = sub_block_desc->AddOp<cpp::OpDesc>();
*sub_block_op_desc = *op_node->AsStmt().op_info(); *sub_op_desc = *op_node->AsStmt().op_info();
} }
subgraph_op_desc.SetAttr<int32_t>("sub_block", sub_block_idx); subgraph_op_desc.SetAttr<int32_t>("sub_block", sub_block_idx);
...@@ -437,13 +438,13 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, ...@@ -437,13 +438,13 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
&local_var_nodes, &local_var_nodes,
&unused_var_nodes); &unused_var_nodes);
// A simplified model without the original weight/local/unused nodes on the // A simplified model without the original weight/local/unused nodes on the
// subgraph ops will be saved only if 'SUBGRAPH_DISABLE_ONLINE_MODE' is set to // subgraph ops will be saved only if 'SUBGRAPH_ONLINE_MODE' is set to
// true and Predictor->Run(...), Predictor->Save(...) is called. // true(default) and Predictor->Run(...), Predictor->Save(...) is called.
std::set<Node *> input_var_nodes(idata_var_nodes.begin(), std::set<Node *> input_var_nodes(idata_var_nodes.begin(),
idata_var_nodes.end()); idata_var_nodes.end());
std::set<Node *> output_var_nodes(odata_var_nodes.begin(), std::set<Node *> output_var_nodes(odata_var_nodes.begin(),
odata_var_nodes.end()); odata_var_nodes.end());
if (!GetBoolFromEnv(SUBGRAPH_DISABLE_ONLINE_MODE)) { if (GetBoolFromEnv(SUBGRAPH_ONLINE_MODE, true)) {
input_var_nodes.insert(weight_var_nodes.begin(), weight_var_nodes.end()); input_var_nodes.insert(weight_var_nodes.begin(), weight_var_nodes.end());
output_var_nodes.insert(local_var_nodes.begin(), local_var_nodes.end()); output_var_nodes.insert(local_var_nodes.begin(), local_var_nodes.end());
output_var_nodes.insert(unused_var_nodes.begin(), unused_var_nodes.end()); output_var_nodes.insert(unused_var_nodes.begin(), unused_var_nodes.end());
...@@ -476,7 +477,7 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, ...@@ -476,7 +477,7 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
subgraph_op_desc.SetOutput("Outputs", output_var_names); subgraph_op_desc.SetOutput("Outputs", output_var_names);
auto subgraph_op = LiteOpRegistry::Global().Create("subgraph"); auto subgraph_op = LiteOpRegistry::Global().Create("subgraph");
static_cast<operators::SubgraphOp *>(subgraph_op.get()) static_cast<operators::SubgraphOp *>(subgraph_op.get())
->SetSubBlock(sub_block_desc); ->SetProgramDesc(sub_program_desc);
auto any_op = (*subgraph_nodes.begin())->AsStmt().op(); auto any_op = (*subgraph_nodes.begin())->AsStmt().op();
subgraph_op->Attach(subgraph_op_desc, any_op->scope()); subgraph_op->Attach(subgraph_op_desc, any_op->scope());
......
...@@ -141,12 +141,11 @@ std::vector<std::string> AddFetchDesc( ...@@ -141,12 +141,11 @@ std::vector<std::string> AddFetchDesc(
} }
TEST(Subgraph, detect_simple_model) { TEST(Subgraph, detect_simple_model) {
cpp::ProgramDesc program_desc; auto program_desc = std::make_shared<cpp::ProgramDesc>();
std::vector<Place> valid_places{{TARGET(kHost), PRECISION(kFloat)}}; std::vector<Place> valid_places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
// Build a simple network // Build a simple network
program_desc.ClearBlocks(); auto* block_desc = program_desc->AddBlock<cpp::BlockDesc>();
auto* block_desc = program_desc.AddBlock<cpp::BlockDesc>();
block_desc->ClearOps(); block_desc->ClearOps();
block_desc->ClearVars(); block_desc->ClearVars();
auto* var_desc = block_desc->AddVar<cpp::VarDesc>(); auto* var_desc = block_desc->AddVar<cpp::VarDesc>();
...@@ -181,13 +180,13 @@ TEST(Subgraph, detect_custom_model) { ...@@ -181,13 +180,13 @@ TEST(Subgraph, detect_custom_model) {
"the path of model files."; "the path of model files.";
return; return;
} }
cpp::ProgramDesc program_desc; auto program_desc = std::make_shared<cpp::ProgramDesc>();
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
LoadModelPb(FLAGS_model_dir, LoadModelPb(FLAGS_model_dir,
FLAGS_model_file, FLAGS_model_file,
FLAGS_params_file, FLAGS_params_file,
scope.get(), scope.get(),
&program_desc, program_desc.get(),
!FLAGS_model_file.empty() && !FLAGS_params_file.empty(), !FLAGS_model_file.empty() && !FLAGS_params_file.empty(),
false); false);
std::vector<Place> valid_places({ std::vector<Place> valid_places({
......
...@@ -36,14 +36,20 @@ void UpdateInputsForSubgraph(OpLite* op, ...@@ -36,14 +36,20 @@ void UpdateInputsForSubgraph(OpLite* op,
op_desc->GetAttr<std::vector<std::string>>("input_data_names"); op_desc->GetAttr<std::vector<std::string>>("input_data_names");
std::replace(input_data_names.begin(), input_data_names.end(), from, to); std::replace(input_data_names.begin(), input_data_names.end(), from, to);
op_desc->SetAttr("input_data_names", input_data_names); op_desc->SetAttr("input_data_names", input_data_names);
auto* subblock_desc = static_cast<operators::SubgraphOp*>(op)->GetSubBlock(); auto sub_program_desc =
CHECK(subblock_desc); static_cast<operators::SubgraphOp*>(op)->GetProgramDesc();
for (size_t i = 0; i < subblock_desc->OpsSize(); i++) { CHECK(sub_program_desc);
auto* subblock_op_desc = subblock_desc->GetOp<cpp::OpDesc>(i); int sub_block_idx = op_desc->GetAttr<int32_t>("sub_block");
for (auto& subblock_op_input : *subblock_op_desc->mutable_inputs()) { auto sub_block_desc =
for (auto& subblock_var_name : subblock_op_input.second) { sub_program_desc->GetBlock<cpp::BlockDesc>(sub_block_idx);
if (subblock_var_name == from) { for (size_t sub_op_idx = 0; sub_op_idx < sub_block_desc->OpsSize();
subblock_var_name = to; sub_op_idx++) {
auto sub_op_desc = const_cast<cpp::OpDesc*>(
sub_block_desc->GetOp<cpp::OpDesc>(sub_op_idx));
for (auto& sub_op_input : *sub_op_desc->mutable_inputs()) {
for (auto& sub_var_name : sub_op_input.second) {
if (sub_var_name == from) {
sub_var_name = to;
} }
} }
} }
......
...@@ -59,25 +59,46 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -59,25 +59,46 @@ class VariablePlaceInferencePass : public DebugPass {
} }
// Set the type of the weight // Set the type of the weight
void SetWeightType(Node* w, void SetWeightType(Node* weight_node,
const LiteType& type, const LiteType& type,
const std::map<std::string, bool>& lite_with_targets) { const std::map<std::string, bool>& with_targets) {
VLOG(4) << "type.precision():" << PrecisionRepr(type.precision()); VLOG(4) << "type.precision():" << PrecisionRepr(type.precision());
if (lite_with_targets.at("kFPGA")) { if (with_targets.at("kFPGA")) {
w->AsArg().type = LiteType::GetTensorTy( weight_node->AsArg().type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
} else if (lite_with_targets.at("kOpenCL")) { } else if (with_targets.at("kOpenCL")) {
w->AsArg().type = LiteType::GetTensorTy( weight_node->AsArg().type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
} else if (lite_with_targets.at("kCUDA")) { } else if (with_targets.at("kCUDA")) {
w->AsArg().type = LiteType::GetTensorTy( weight_node->AsArg().type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
} else { } else {
w->AsArg().type = LiteType::GetTensorTy( weight_node->AsArg().type = LiteType::GetTensorTy(
TARGET(kHost), type.precision(), DATALAYOUT(kNCHW)); TARGET(kHost), type.precision(), DATALAYOUT(kNCHW));
} }
} }
// Update a's kUnk fields from b's fields.
void UpdateTypeFrom(const Type** a, const Type* b) {
auto target = (*a)->target();
auto precision = (*a)->precision();
auto layout = (*a)->layout();
if (target == TARGET(kUnk)) {
target = b->target();
}
if (precision == PRECISION(kUnk)) {
precision = b->precision();
}
if (layout == DATALAYOUT(kUnk)) {
layout = b->layout();
}
if ((*a)->IsTensor() && b->IsTensor()) {
*a = LiteType::GetTensorTy(target, precision, layout);
} else if ((*a)->IsTensorList() && b->IsTensorList()) {
*a = LiteType::GetTensorListTy(target, precision, layout);
}
}
void InferenceArgumentPlace(SSAGraph* graph) { void InferenceArgumentPlace(SSAGraph* graph) {
auto& valid_places = graph->valid_places(); auto& valid_places = graph->valid_places();
auto valid_places_has_target = [&](TargetType t) -> bool { auto valid_places_has_target = [&](TargetType t) -> bool {
...@@ -88,122 +109,90 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -88,122 +109,90 @@ class VariablePlaceInferencePass : public DebugPass {
} }
return false; return false;
}; };
std::map<std::string, bool> lite_with_targets{ std::map<std::string, bool> with_targets{
{"kOpenCL", valid_places_has_target(TARGET(kOpenCL))}, {"kOpenCL", valid_places_has_target(TARGET(kOpenCL))},
{"kCUDA", valid_places_has_target(TARGET(kCUDA))}, {"kCUDA", valid_places_has_target(TARGET(kCUDA))},
{"kFPGA", valid_places_has_target(TARGET(kFPGA))}}; {"kFPGA", valid_places_has_target(TARGET(kFPGA))}};
VLOG(4) << "lite_with_targets['kOpenCL']:" << lite_with_targets["kOpenCL"]; VLOG(4) << "with_targets['kOpenCL']:" << with_targets["kOpenCL"];
VLOG(4) << "lite_with_targets['kFPGA']:" << lite_with_targets["kFPGA"]; VLOG(4) << "with_targets['kFPGA']:" << with_targets["kFPGA"];
VLOG(3) << "param-type-registry:\n" << ParamTypeRegistry::Global(); VLOG(3) << "param-type-registry:\n" << ParamTypeRegistry::Global();
for (auto& x : graph->StmtTopologicalOrder()) { for (auto& node : graph->StmtTopologicalOrder()) {
auto& inst = x->AsStmt(); auto& inst = node->AsStmt();
const auto* op_info = inst.op_info();
const auto& op_type = op_info->Type();
auto& kernel = inst.picked_kernel();
// The IoCopyOp is a tool operator, it won't support the type inference. // The IoCopyOp is a tool operator, it won't support the type inference.
// in fpga, we has io_copy+cali+layout tool ops, so we need type inference // in fpga, we has io_copy+cali+layout tool ops, so we need type inference
// for // for tool operator
// tool operator if ((!with_targets["kFPGA"]) && (!with_targets["kOpenCL"])) {
if ((!lite_with_targets["kFPGA"]) && (!lite_with_targets["kOpenCL"])) { VLOG(3) << "skip 'io_copy' if target is FPGA and OpenCL";
VLOG(3) << "inst.op_type() == 'io_copy', continue"; if (op_type == "io_copy") continue;
if (inst.op_type() == "io_copy") continue;
} }
// deal with inputs
VLOG(4) << "Infering op " << inst.op_info()->Repr();
// TODO(zhaolong): Add check if the node's name in op's arguments.
auto get_argname = [&]( // Infering the input and output variable's place according to the
const std::string& node_name, // declaration of I/O arguments of the picked kernel of the op
const std::map<std::string, std::vector<std::string>>& argname_map) VLOG(4) << "Op " << op_info->Repr();
-> std::string { for (auto* in_node : node->inlinks) {
for (auto& ele : argname_map) { auto& var = in_node->AsArg();
auto it = const auto& var_name = var.name;
std::find(ele.second.begin(), ele.second.end(), node_name); auto* var_type = &var.type;
if (it != ele.second.end()) return ele.first; std::string arg_name;
} CHECK(op_info->GetInputArgname(var_name, &arg_name))
return ""; << "Can not find the input argument for var " << var_name;
}; VLOG(4) << " - input arg name:" << arg_name << " var name:" << var_name;
const auto* decl_type = kernel.GetInputDeclType(arg_name);
for (auto* x_in : x->inlinks) { if (!(*var_type)) {
std::string node_name = x_in->AsArg().name; VLOG(4) << "set type " << *decl_type << " " << var_name;
std::string arg_name = get_argname(node_name, inst.op_info()->inputs()); if (var.is_weight) {
CHECK(arg_name.size() > 0) << "can not found op arguments for node " SetWeightType(in_node, *decl_type, with_targets);
<< node_name;
VLOG(4) << "-- input arg_name:" << arg_name << " "
<< "-- node name:" << node_name;
auto type = inst.picked_kernel().GetInputDeclType(arg_name);
if (!x_in->AsArg().type) {
VLOG(4) << "set type " << *type << " " << x_in->AsArg().name;
if (x_in->AsArg().is_weight) {
SetWeightType(x_in, *type, lite_with_targets);
} else { } else {
x_in->AsArg().type = type; *var_type = decl_type;
} }
} else if (x_in->AsArg().type->target() == TARGET(kUnk) && } else if (!(*var_type)->place().is_valid()) {
x_in->AsArg().type->precision() != PRECISION(kUnk) &&
x_in->AsArg().type->layout() == DATALAYOUT(kUnk)) {
// If is quantization, infer the Int8 type. // If is quantization, infer the Int8 type.
if (type->precision() == PRECISION(kInt8)) { if (decl_type->precision() == PRECISION(kInt8)) {
x_in->AsArg().type = type; *var_type = decl_type;
} else { } else {
PrecisionType tmp_ptype = x_in->AsArg().type->precision(); UpdateTypeFrom(var_type, decl_type);
x_in->AsArg().type = LiteType::GetTensorTy(
type->target(), tmp_ptype, type->layout());
} }
} }
} }
for (auto* out_node : node->outlinks) {
VLOG(4) << "inst " << inst.op_info()->Repr(); auto& var = out_node->AsArg();
for (auto* x_out : x->outlinks) { const auto& var_name = var.name;
std::string node_name = x_out->AsArg().name; auto* var_type = &var.type;
std::string arg_name = std::string arg_name;
get_argname(node_name, inst.op_info()->outputs()); CHECK(op_info->GetOutputArgname(var_name, &arg_name))
CHECK(arg_name.size() > 0) << "can not found op arguments for node " << "Can not find the output argument for var " << var_name;
<< node_name << " in Inst " VLOG(4) << " - output arg name:" << arg_name
<< inst.op_type(); << " var name:" << var_name;
VLOG(4) << "-- output arg_name " << arg_name; const auto* decl_type = kernel.GetOutputDeclType(arg_name);
auto type = inst.picked_kernel().GetOutputDeclType(arg_name); if (!(*var_type)) {
if (!x_out->AsArg().type) { VLOG(4) << "set type " << *decl_type << " " << var_name;
VLOG(4) << "set type " << *type << " " << x_out->AsArg().name; if (var.is_weight) {
if (x_out->AsArg().is_weight) { SetWeightType(out_node, *decl_type, with_targets);
SetWeightType(x_out, *type, lite_with_targets);
} else { } else {
x_out->AsArg().type = type; *var_type = decl_type;
} }
} else if (x_out->AsArg().type->target() == TARGET(kUnk) && } else if (!(*var_type)->place().is_valid()) {
x_out->AsArg().type->precision() != PRECISION(kUnk) &&
x_out->AsArg().type->layout() == DATALAYOUT(kUnk)) {
// If is quantization, infer the Int8 type. // If is quantization, infer the Int8 type.
if (type->precision() == PRECISION(kInt8)) { if (decl_type->precision() == PRECISION(kInt8) ||
x_out->AsArg().type = type; (decl_type->precision() == PRECISION(kFP16) &&
} else if (type->precision() == PRECISION(kFP16) && decl_type->target() != TARGET(kOpenCL))) {
type->target() != TARGET(kOpenCL)) { *var_type = decl_type;
x_out->AsArg().type = type;
} else { } else {
PrecisionType tmp_ptype = x_out->AsArg().type->precision(); UpdateTypeFrom(var_type, decl_type);
x_out->AsArg().type = LiteType::GetTensorTy(
type->target(), tmp_ptype, type->layout());
} }
} }
} }
} }
} }
// Update me's kUnk fields by other's fields.
void UpdatePlace(Place* me, const Place& other) {
CHECK(other.is_valid());
if (me->target == TARGET(kUnk)) {
me->target = other.target;
}
if (me->precision == PRECISION(kUnk)) {
me->precision = other.precision;
}
if (me->layout == DATALAYOUT(kUnk)) {
me->layout = other.layout;
}
}
private: private:
// The default target for arguments, e.g. load weights to CPU memory for CUDA // The default target for arguments, e.g. load weights to CPU memory for
// computation by default. // CUDA computation by default.
TargetType argument_default_target_{TARGET(kHost)}; TargetType argument_default_target_{TARGET(kHost)};
}; };
......
...@@ -99,7 +99,7 @@ class OpLite : public Registry { ...@@ -99,7 +99,7 @@ class OpLite : public Registry {
std::vector<std::unique_ptr<KernelBase>> CreateKernels( std::vector<std::unique_ptr<KernelBase>> CreateKernels(
const std::vector<Place> &places, const std::string &kernel_type = ""); const std::vector<Place> &places, const std::string &kernel_type = "");
lite::Scope *scope() { return scope_; } Scope *scope() { return scope_; }
// Assign op param to kernel. // Assign op param to kernel.
virtual void AttachKernel(KernelBase *kernel) = 0; virtual void AttachKernel(KernelBase *kernel) = 0;
...@@ -169,7 +169,7 @@ class OpLite : public Registry { ...@@ -169,7 +169,7 @@ class OpLite : public Registry {
} }
protected: protected:
lite::Scope *scope_{nullptr}; Scope *scope_{nullptr};
std::unique_ptr<KernelBase> kernel_; std::unique_ptr<KernelBase> kernel_;
std::string op_type_; std::string op_type_;
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/mir/elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.h"
#include "lite/core/mir/generate_program_pass.h" #include "lite/core/mir/generate_program_pass.h"
#include "lite/core/mir/pass_manager.h" #include "lite/core/mir/pass_manager.h"
#include "lite/core/mir/pass_utils.h" #include "lite/core/mir/pass_utils.h"
...@@ -36,6 +37,9 @@ namespace lite { ...@@ -36,6 +37,9 @@ namespace lite {
* lite::Optimizer optimize a program. It utilize the mir passes to analysis the * lite::Optimizer optimize a program. It utilize the mir passes to analysis the
* program and export an optimized program. * program and export an optimized program.
*/ */
// TODO(hong1986032) Support the following passes for the subblocks
const std::set<std::string> kSubblockUnsupportedPasses(
{"memory_optimize_pass"});
class Optimizer { class Optimizer {
public: public:
Optimizer() {} Optimizer() {}
...@@ -60,14 +64,20 @@ class Optimizer { ...@@ -60,14 +64,20 @@ class Optimizer {
program_ = &program; program_ = &program;
valid_places_ = valid_places; valid_places_ = valid_places;
CHECK(!valid_places.empty()) << "At least one valid_place should be set"; CHECK(!valid_places.empty()) << "At least one valid_place should be set";
CHECK(!graph_) << "duplicate optimize found"; CHECK(graphs_.empty()) << "duplicate optimize found";
graph_.reset(new mir::SSAGraph); auto block_size = program.block_size();
graph_->Build(program, valid_places); for (size_t block_idx = 0; block_idx < block_size; ++block_idx) {
graph_->SetValidPlaces(valid_places); std::unique_ptr<mir::SSAGraph> graph;
graph.reset(new mir::SSAGraph);
graph->Build(program, valid_places, block_idx);
graph->SetValidPlaces(valid_places);
graphs_.emplace_back(std::move(graph));
}
SpecifyKernelPickTactic(kernel_pick_factor); SpecifyKernelPickTactic(kernel_pick_factor);
InitTargetTypeTransformPass(); InitTargetTypeTransformPass();
InitControlFlowOpUnusedInputsAndOutputsEliminatePass();
if (passes.empty() || passes.size() == 1) { if (passes.empty() || passes.size() == 1) {
std::vector<std::string> passes_local{ std::vector<std::string> passes_local{
...@@ -76,6 +86,7 @@ class Optimizer { ...@@ -76,6 +86,7 @@ class Optimizer {
"lite_conv_elementwise_fuse_pass", // conv-elemwise-bn "lite_conv_elementwise_fuse_pass", // conv-elemwise-bn
"lite_conv_bn_fuse_pass", // "lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_fuse_pass", // conv-bn-elemwise "lite_conv_elementwise_fuse_pass", // conv-bn-elemwise
"lite_conv_conv_fuse_pass", //
// TODO(Superjomn) Refine the fusion related design to select fusion // TODO(Superjomn) Refine the fusion related design to select fusion
// kernels for devices automatically. // kernels for devices automatically.
"lite_conv_activation_fuse_pass", // "lite_conv_activation_fuse_pass", //
...@@ -111,6 +122,7 @@ class Optimizer { ...@@ -111,6 +122,7 @@ class Optimizer {
"apu_subgraph_pass", "apu_subgraph_pass",
"rknpu_subgraph_pass", "rknpu_subgraph_pass",
"mlu_subgraph_pass", "mlu_subgraph_pass",
"control_flow_op_unused_inputs_and_outputs_eliminate_pass",
"static_kernel_pick_pass", // pick original kernel from graph "static_kernel_pick_pass", // pick original kernel from graph
"remove_tf_redundant_ops_pass", "remove_tf_redundant_ops_pass",
...@@ -175,62 +187,15 @@ class Optimizer { ...@@ -175,62 +187,15 @@ class Optimizer {
exec_scope_ = program.exec_scope(); exec_scope_ = program.exec_scope();
} }
const lite::Scope* exec_scope() const { return exec_scope_; } const Scope* exec_scope() const { return exec_scope_; }
// Set shape(dims) infos of var descs to scope var.
// developer can write pass using input / output tensor dims of op.
//
// Example: If you have node `Node* softmax_node`,
// you can get dims of output tensor in passes:
//
// auto* scope = softmax_node->AsStmt().op()->scope();
// auto softmax_out_arg_name =
// softmax_node->outlinks.front()->AsArg().name;
// auto softmax_out_tensor =
// scope->FindVar(softmax_out_arg_name)->Get<lite::Tensor>();
// softmax_out_dims = softmax_out_tensor.dims();
void SetVarDescShapeToScopeVar() {
auto dims_to_str_func = [](std::vector<int64_t> shape) -> std::string {
std::string str_res;
for (size_t i = 0; i < shape.size(); ++i) {
str_res += std::to_string(shape[i]);
if (i != shape.size() - 1) {
str_res += "x";
}
}
return str_res;
};
auto* program_desc = program_->program_desc();
VLOG(5) << "program_desc->BlocksSize():" << program_desc->BlocksSize();
auto blocks_desc = program_desc->GetBlocks();
for (size_t bidx = 0; bidx < blocks_desc.size(); ++bidx) {
auto block_desc = blocks_desc[bidx];
auto vars_desc = block_desc.GetVars();
for (size_t vidx = 0; vidx < vars_desc.size(); ++vidx) {
auto var_desc = vars_desc[vidx];
VLOG(5) << var_desc.Name() << " "
<< dims_to_str_func(var_desc.GetShape());
if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue;
auto* var = program_->exec_scope()->FindVar(var_desc.Name());
auto tensor = var->GetMutable<lite::Tensor>();
if (tensor->dims().size() == 0 && var_desc.GetShape().size() != 0) {
VLOG(5) << "var_desc.Name():" << var_desc.Name()
<< " shape:" << dims_to_str_func(var_desc.GetShape());
tensor->Resize(var_desc.GetShape());
}
VLOG(5) << "var_desc.Name():" << var_desc.Name()
<< " shape:" << dims_to_str_func(var_desc.GetShape())
<< " tensor:" << tensor->dims();
}
}
}
// Generate a new program based on the mir graph. // Generate a new program based on the mir graph.
std::unique_ptr<RuntimeProgram> GenRuntimeProgram() { std::unique_ptr<RuntimeProgram> GenRuntimeProgram() {
auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>( auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
"generate_program_pass"); "generate_program_pass");
pass->Apply(graph_); for (auto& graph : graphs_) {
pass->Apply(graph);
}
auto program = pass->GenProgram(); auto program = pass->GenProgram();
CHECK(exec_scope_); CHECK(exec_scope_);
program->set_exec_scope(exec_scope_); program->set_exec_scope(exec_scope_);
...@@ -246,27 +211,38 @@ class Optimizer { ...@@ -246,27 +211,38 @@ class Optimizer {
pass->SetValidPlaces(valid_places_); pass->SetValidPlaces(valid_places_);
} }
void InitControlFlowOpUnusedInputsAndOutputsEliminatePass() {
auto* pass =
mir::PassManager::Global()
.LookUp<mir::ControlFlowOpUnusedInputsAndOutputsEliminatePass>(
"control_flow_op_unused_inputs_and_outputs_eliminate_pass");
CHECK(pass);
CHECK(!graphs_.empty());
pass->SetAllGraphs(&graphs_);
}
// Generate C++ code which combines the inference program, model and weights. // Generate C++ code which combines the inference program, model and weights.
void GenCode(const std::string& code_dir); void GenCode(const std::string& code_dir);
const mir::SSAGraph& ssa_graph() const { const mir::SSAGraph& ssa_graph(int block_idx = kRootBlockIdx) const {
CHECK(graph_); CHECK(!graphs_.empty());
return *graph_; CHECK(graphs_[block_idx]);
return *graphs_[block_idx];
} }
mir::SSAGraph* mutable_ssa_graph() { mir::SSAGraph* mutable_ssa_graph(int block_idx = kRootBlockIdx) {
CHECK(graph_); CHECK(!graphs_.empty());
return graph_.get(); CHECK(graphs_[block_idx]);
return graphs_[block_idx].get();
} }
lite::Scope* exec_scope() { return exec_scope_; } Scope* exec_scope() { return exec_scope_; }
protected: protected:
void SpecifyKernelPickTactic(core::KernelPickFactor factor); void SpecifyKernelPickTactic(core::KernelPickFactor factor);
// Specify the passes and run them. // Specify the passes and run them.
void RunPasses(const std::vector<std::string>& passes) { void RunPasses(const std::vector<std::string>& passes) {
SetVarDescShapeToScopeVar();
for (auto& x : passes) { for (auto& x : passes) {
LOG(INFO) << "== Running pass: " << x; LOG(INFO) << "== Running pass: " << x;
mir::Pass* pass = mir::PassManager::Global().LookUp(x); mir::Pass* pass = mir::PassManager::Global().LookUp(x);
...@@ -284,16 +260,23 @@ class Optimizer { ...@@ -284,16 +260,23 @@ class Optimizer {
LOG(INFO) << " - Skip " << x LOG(INFO) << " - Skip " << x
<< " because the target or kernel does not match."; << " because the target or kernel does not match.";
} else { } else {
pass->Apply(graph_); // Check the pass whether it is supported for processing subblocks
if (kSubblockUnsupportedPasses.count(x)) {
pass->Apply(graphs_[kRootBlockIdx]);
} else {
for (auto& graph : graphs_) {
pass->Apply(graph);
}
}
LOG(INFO) << "== Finished running: " << x; LOG(INFO) << "== Finished running: " << x;
} }
} }
} }
private: private:
std::unique_ptr<mir::SSAGraph> graph_; std::vector<std::unique_ptr<mir::SSAGraph>> graphs_;
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
lite::Scope* exec_scope_{}; Scope* exec_scope_{};
Program* program_{}; Program* program_{};
}; };
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "lite/core/program.h" #include "lite/core/program.h"
#include <algorithm> #include <algorithm>
#include <map> #include <map>
#include <set>
#include "lite/model_parser/cpp_desc.h" #include "lite/model_parser/cpp_desc.h"
#include "lite/operators/conditional_block_op.h" #include "lite/operators/conditional_block_op.h"
#include "lite/operators/subgraph_op.h" #include "lite/operators/subgraph_op.h"
...@@ -26,121 +27,219 @@ ...@@ -26,121 +27,219 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
void RuntimeProgram::SaveOpInfosToProgram(cpp::ProgramDesc* desc) { void RuntimeProgram::SaveToProgram(
CHECK(desc); std::shared_ptr<cpp::ProgramDesc> program_desc) {
// NOTE: RuntimeProgram do not has all meta info, so save model just update CHECK(program_desc);
// upon origin model auto block_size = program_desc->BlocksSize();
CHECK(desc->BlocksSize()); CHECK_GT(block_size, 0) << "No block found!";
auto main_block = desc->GetBlock<cpp::BlockDesc>(0); // TODD(hong19860320) Only support updating the block desc which already
main_block->ClearOps(); // exists in the origin program desc
for (auto& node : instructions_) { CHECK_LE(block_size, instructions_.size())
auto op_type = node.op()->op_info()->Type(); << "Invalid block size, expected (0," << instructions_.size()
if (op_type == "subgraph") { << "] but got " << block_size;
auto subgraph_op = const_cast<operators::SubgraphOp*>( for (size_t block_idx = 0; block_idx < block_size; ++block_idx) {
static_cast<const operators::SubgraphOp*>(node.op())); auto block_desc = program_desc->GetBlock<cpp::BlockDesc>(block_idx);
int sub_block_idx = subgraph_op->op_info()->GetAttr<int32_t>("sub_block"); // Record all of the origin vars in the origin block
if (sub_block_idx < 0) { std::map<std::string, cpp::VarDesc> origin_var_maps;
// It's a new subgraph op when its sub_block_idx < 0, Now we add its auto var_size = block_desc->VarsSize();
for (size_t var_idx = 0; var_idx < var_size; ++var_idx) {
auto v = block_desc->GetVar<cpp::VarDesc>(var_idx);
origin_var_maps.emplace(v->Name(), *v);
}
// Update the ops and vars for each block according to the instructions
block_desc->ClearVars();
block_desc->ClearOps();
std::set<std::string> already_added_vars;
for (auto& inst : instructions_[block_idx]) {
auto* op = const_cast<OpLite*>(inst.op());
auto* op_info = op->op_info();
auto op_type = op_info->Type();
auto* kernel = inst.mutable_kernel();
auto* scope = op->scope();
// Update the origin vars which are referred by the instructions
// Add the new vars which are created in the passes and referred by the
// instructions
auto var_names = op_info->input_names();
auto out_names = op_info->output_names();
// Combine input and output vars and delete the duplicates
var_names.insert(var_names.end(), out_names.begin(), out_names.end());
std::stable_sort(var_names.begin(), var_names.end());
var_names.erase(std::unique(var_names.begin(), var_names.end()),
var_names.end());
for (auto& var_name : var_names) {
if (already_added_vars.count(var_name)) continue;
auto* v = block_desc->AddVar<cpp::VarDesc>();
v->SetName(var_name);
auto it = origin_var_maps.find(var_name);
if (it != origin_var_maps.end()) {
v->SetType(it->second.GetType());
v->SetPersistable(it->second.Persistable());
if (var_name != "feed" && var_name != "fetch") {
v->SetShape(it->second.GetShape());
v->SetDataType(it->second.GetDataType());
}
} else {
std::string arg_name;
const Type* decl_type;
if (op_info->GetInputArgname(var_name, &arg_name)) {
decl_type = kernel->GetInputDeclType(arg_name);
} else {
op_info->GetOutputArgname(var_name, &arg_name);
decl_type = kernel->GetOutputDeclType(arg_name);
}
if (decl_type->IsTensor()) {
v->SetType(cpp::VarDesc::Type::LOD_TENSOR);
auto tensor = scope->FindVar(var_name)->GetMutable<Tensor>();
v->SetPersistable(tensor->persistable());
if (var_name != "feed" && var_name != "fetch") {
v->SetShape(tensor->dims().data());
auto precision = tensor->precision();
switch (precision) {
#define SET_DATATYPE(precision__, data_type) \
case PrecisionType::precision__: \
v->SetDataType(data_type); \
LOG(INFO) << "Update var " << var_name << " done"; \
break
SET_DATATYPE(kBool, VarDescAPI::VarDataType::BOOL);
SET_DATATYPE(kFloat, VarDescAPI::VarDataType::FP32);
SET_DATATYPE(kFP16, VarDescAPI::VarDataType::FP16);
SET_DATATYPE(kInt8, VarDescAPI::VarDataType::INT8);
SET_DATATYPE(kInt16, VarDescAPI::VarDataType::INT16);
SET_DATATYPE(kInt32, VarDescAPI::VarDataType::INT32);
SET_DATATYPE(kInt64, VarDescAPI::VarDataType::INT64);
#undef SET_DATATYPE
default:
LOG(WARNING) << "Unknown precision type "
<< PrecisionToStr(precision) << " for var "
<< var_name << " in op " << op_type;
}
}
} else if (decl_type->IsTensorList()) {
// Set persistable=false for tensor array
v->SetType(cpp::VarDesc::Type::LOD_TENSOR_ARRAY);
v->SetPersistable(false);
} else {
CHECK(false) << "Unsupported decl type " << *decl_type
<< " for var " << var_name << " in op " << op_type;
}
}
already_added_vars.insert(var_name);
}
// Replace all of origin ops with the instructions
auto op_desc = block_desc->AddOp<cpp::OpDesc>();
*op_desc = *op_info;
op_desc->SetAttr(kKernelTypeAttr, kernel->SerializedKernelType());
if (op_type == "subgraph" && !op_info->GetAttr<int32_t>("sub_block")) {
// It's a new subgraph op when its sub_block_idx = 0, Now we add its
// subblock desc to the program desc, Then update its sub_block_idx to // subblock desc to the program desc, Then update its sub_block_idx to
// the index of block desc of the program desc. // the index of block desc of the program desc.
sub_block_idx = desc->BlocksSize(); auto subgraph_op = static_cast<operators::SubgraphOp*>(op);
auto sub_block_desc = subgraph_op->GetSubBlock(); auto sub_program_desc = subgraph_op->GetProgramDesc();
CHECK(sub_block_desc); CHECK(sub_program_desc);
auto new_block_desc = desc->AddBlock<cpp::BlockDesc>(); auto sub_block_desc = program_desc->AddBlock<cpp::BlockDesc>();
*new_block_desc = *sub_block_desc; *sub_block_desc = *sub_program_desc->GetBlock<cpp::BlockDesc>(0);
delete sub_block_desc; subgraph_op->SetProgramDesc(program_desc);
subgraph_op->mutable_op_info()->SetAttr<int32_t>("sub_block", op_desc->SetAttr<int32_t>("sub_block", program_desc->BlocksSize() - 1);
sub_block_idx); // Attach op and kernel again to update the new block_idx and
subgraph_op->SetSubBlock(new_block_desc); // program_desc
// Update main block desc after a new subblock desc is added subgraph_op->Attach(*op_desc, scope);
main_block = desc->GetBlock<cpp::BlockDesc>(0); subgraph_op->AttachKernel(kernel);
// Update the pointer of block desc after a new subblock desc is added
block_desc = program_desc->GetBlock<cpp::BlockDesc>(block_idx);
} }
} }
auto op = main_block->AddOp<cpp::OpDesc>();
*op = *node.op()->op_info();
op->SetAttr(kKernelTypeAttr, node.kernel()->SerializedKernelType());
} }
} }
// `UpdateVarsOfProgram` will remove unused var_descs and add new created // Create runtime program from sub_block desc according to block_idx and
// vars' descs in the block 0. Now, the type of a new created var can only // program_desc, which is used for while/conditional_block/subgraph op.
// be LOD_TENSOR. RuntimeProgram::RuntimeProgram(
void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { const std::shared_ptr<const cpp::ProgramDesc>& program_desc,
CHECK(desc); Scope* exec_scope,
CHECK(desc->BlocksSize()); int block_idx)
std::map<std::string, cpp::VarDesc> origin_var_maps; : exec_scope_(exec_scope) {
auto& main_block = *desc->GetBlock<cpp::BlockDesc>(0); #ifdef LITE_WITH_OPENCL
auto var_size = main_block.VarsSize(); using OpenCLContext = Context<TargetType::kOpenCL>;
for (size_t i = 0; i < var_size; i++) { std::unique_ptr<KernelContext> local_ctx(new KernelContext());
auto v = main_block.GetVar<cpp::VarDesc>(i); local_ctx->As<OpenCLContext>().InitOnce();
auto name = v->Name(); #endif
origin_var_maps.emplace(name, *v); CHECK(program_desc);
} auto block_size = program_desc->BlocksSize();
CHECK(block_size) << "No block found!";
main_block.ClearVars(); CHECK(block_idx >= 0 && block_idx < block_size)
for (auto& node : instructions_) { << "Invalid block index, expected [0," << (block_size - 1) << "] but got "
auto* op = const_cast<lite::OpLite*>(node.op()); << block_idx;
auto* kernel = node.kernel(); auto block_desc = program_desc->GetBlock<cpp::BlockDesc>(block_idx);
auto* scope = op->scope(); instructions_.resize(kRootBlockIdx + 1);
auto in_names = op->op_info()->input_names(); auto op_size = block_desc->OpsSize();
auto out_names = op->op_info()->output_names(); for (size_t op_idx = 0; op_idx < op_size; op_idx++) {
in_names.insert(in_names.end(), out_names.begin(), out_names.end()); auto op_desc = block_desc->GetOp<cpp::OpDesc>(op_idx);
std::stable_sort(in_names.begin(), in_names.end()); CHECK(op_desc);
in_names.erase(std::unique(in_names.begin(), in_names.end()), std::string op_type = op_desc->Type();
in_names.end()); // if (op_type == "feed" || op_type == "fetch") continue;
for (auto& in_name : in_names) { // Create op and pick up the best kernel
auto it = origin_var_maps.find(in_name); auto op = LiteOpRegistry::Global().Create(op_type);
if (it != origin_var_maps.end()) { CHECK(op) << "no Op found for " << op_type;
auto* v = main_block.AddVar<cpp::VarDesc>(); if (op_type == "while") {
v->SetName((it->second).Name()); static_cast<operators::WhileOp*>(op.get())->SetProgramDesc(program_desc);
v->SetType((it->second).GetType()); } else if (op_type == "conditional_block") {
v->SetPersistable((it->second).Persistable()); static_cast<operators::ConditionalBlockOp*>(op.get())->SetProgramDesc(
if ((it->second).Name() != "feed" && (it->second).Name() != "fetch") { program_desc);
v->SetShape((it->second).GetShape()); } else if (op_type == "subgraph") {
v->SetDataType((it->second).GetDataType()); static_cast<operators::SubgraphOp*>(op.get())->SetProgramDesc(
} program_desc);
}
op->Attach(*op_desc, exec_scope_);
std::unique_ptr<KernelBase> kernel;
if (op_desc->HasAttr(kKernelTypeAttr)) {
// Create op and pick up the best kernel according to the
// kKernelTypeAttr attribute
auto kernel_type = op_desc->GetAttr<std::string>(kKernelTypeAttr);
std::string alias;
Place place;
KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place);
VLOG(3) << "Found the attr '" << kKernelTypeAttr << "': " << kernel_type
<< " for " << op_type;
auto kernels = op->CreateKernels({place});
CHECK_GT(kernels.size(), 0) << "No kernels found for " << op_type;
auto it = std::find_if(
kernels.begin(), kernels.end(), [&](std::unique_ptr<KernelBase>& it) {
return it->alias() == alias;
});
CHECK(it != kernels.end());
kernel = std::move(*it);
} else {
// TODO(hong19860320) add kernel picking according to the type of input
// and output tensors
VLOG(3) << "The attr '" << kKernelTypeAttr
<< "' not found, pick the first kernel for " << op_type;
std::vector<std::unique_ptr<KernelBase>> kernels;
#if defined(LITE_WITH_ARM)
kernels = op->CreateKernels({Place{TARGET(kARM)}, Place{TARGET(kHost)}});
#elif defined(LITE_WITH_X86)
kernels = op->CreateKernels({Place{TARGET(kX86)}, Place{TARGET(kHost)}});
#endif
if (kernels.size() > 0) {
kernel = std::move(kernels.front());
} else { } else {
// New created vars must be LOD_TENSOR LOG(WARNING) << "No kernels found for " << op_type;
auto* v = main_block.AddVar<cpp::VarDesc>();
v->SetName(in_name);
v->SetType(cpp::VarDesc::Type::LOD_TENSOR);
std::string in_arg_name;
const Type* type;
if (op->op_info()->GetInputArgname(in_name, &in_arg_name)) {
type = kernel->GetInputDeclType(in_arg_name);
} else {
op->op_info()->GetOutputArgname(in_name, &in_arg_name);
type = kernel->GetOutputDeclType(in_arg_name);
}
if (type->IsTensor()) {
auto tensor = scope->FindVar(in_name)->GetMutable<Tensor>();
v->SetPersistable(tensor->persistable());
if (in_name != "feed" && in_name != "fetch") {
v->SetShape(tensor->dims().data());
switch (tensor->precision()) {
#define SET_DATATYPE(precision__, data_type) \
case PrecisionType::precision__: \
v->SetDataType(data_type); \
LOG(INFO) << "update var" << (it->second).Name() << "done"; \
break
SET_DATATYPE(kBool, VarDescAPI::VarDataType::BOOL);
SET_DATATYPE(kFloat, VarDescAPI::VarDataType::FP32);
SET_DATATYPE(kFP16, VarDescAPI::VarDataType::FP16);
SET_DATATYPE(kInt8, VarDescAPI::VarDataType::INT8);
SET_DATATYPE(kInt16, VarDescAPI::VarDataType::INT16);
SET_DATATYPE(kInt32, VarDescAPI::VarDataType::INT32);
SET_DATATYPE(kInt64, VarDescAPI::VarDataType::INT64);
#undef SET_DATATYPE
default:
VLOG(4) << "warning! unknown precision type";
}
}
} else {
CHECK(false) << "unsupported var type";
}
} }
} }
#ifdef LITE_WITH_OPENCL
if (kernel->target() == TARGET(kOpenCL)) {
std::unique_ptr<KernelContext> ctx(new KernelContext());
(*local_ctx).As<OpenCLContext>().CopySharedTo(&ctx->As<OpenCLContext>());
kernel->SetContext(std::move(ctx));
} else {
kernel->SetContext(
ContextScheduler::Global().NewContext(kernel->target()));
}
#else
kernel->SetContext(ContextScheduler::Global().NewContext(kernel->target()));
#endif
instructions_[kRootBlockIdx].emplace_back(std::move(op), std::move(kernel));
} }
Init();
} }
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
...@@ -167,7 +266,8 @@ void RuntimeProgram::Run() { ...@@ -167,7 +266,8 @@ void RuntimeProgram::Run() {
} }
#endif #endif
int idx = -1; int idx = -1;
for (auto& inst : instructions_) { auto& insts = instructions_[kRootBlockIdx];
for (auto& inst : insts) {
++idx; ++idx;
#ifndef LITE_WITH_FPGA #ifndef LITE_WITH_FPGA
if (inst.is_feed_fetch_op()) continue; if (inst.is_feed_fetch_op()) continue;
...@@ -200,58 +300,50 @@ void RuntimeProgram::Run() { ...@@ -200,58 +300,50 @@ void RuntimeProgram::Run() {
#endif #endif
} }
void Program::Build(const cpp::ProgramDesc& prog) { void Program::Build(const std::shared_ptr<cpp::ProgramDesc>& program_desc) {
CHECK(ops_.empty()) << "Executor duplicate Build found"; CHECK(ops_.empty()) << "Executor duplicate Build found";
// Create operators. // Create operators.
auto& program = prog; auto block_size = program_desc->BlocksSize();
CHECK(program.BlocksSize()); CHECK(block_size);
auto& main_block = *program.GetBlock<cpp::BlockDesc>(0); ops_.resize(block_size);
for (size_t i = 0; i < main_block.OpsSize(); ++i) { for (size_t block_idx = 0; block_idx < block_size; ++block_idx) {
auto& op_desc = *main_block.GetOp<cpp::OpDesc>(i); auto* block_desc = program_desc->GetBlock<cpp::BlockDesc>(block_idx);
auto op_type = op_desc.Type(); auto op_size = block_desc->OpsSize();
// if (op_type == "feed" || op_type == "fetch") continue; for (size_t op_idx = 0; op_idx < op_size; ++op_idx) {
VLOG(4) << "create Op [" << op_type << "]"; auto* op_desc = block_desc->GetOp<cpp::OpDesc>(op_idx);
auto op = LiteOpRegistry::Global().Create(op_type); auto op_type = op_desc->Type();
CHECK(op) << "no Op found for " << op_type; VLOG(4) << "create Op [" << op_type << "]";
if (op_type == "while" || op_type == "conditional_block" || auto op = LiteOpRegistry::Global().Create(op_type);
op_type == "subgraph") { CHECK(op) << "no Op found for " << op_type;
auto sub_block_idx = op_desc.GetAttr<int32_t>("sub_block");
CHECK(sub_block_idx >= 0 &&
sub_block_idx < static_cast<int>(program.BlocksSize()))
<< "Invalid attribute sub_block(" << sub_block_idx << ") for "
<< op_type;
auto sub_block_desc =
const_cast<cpp::ProgramDesc&>(prog).GetBlock<cpp::BlockDesc>(
sub_block_idx);
CHECK(sub_block_desc);
if (op_type == "while") { if (op_type == "while") {
static_cast<operators::WhileOpLite*>(op.get())->SetSubBlock( static_cast<operators::WhileOp*>(op.get())->SetProgramDesc(
sub_block_desc); program_desc);
} else if (op_type == "conditional_block") { } else if (op_type == "conditional_block") {
static_cast<operators::ConditionalBlockOpLite*>(op.get())->SetSubBlock( static_cast<operators::ConditionalBlockOp*>(op.get())->SetProgramDesc(
sub_block_desc); program_desc);
} else if (op_type == "subgraph") { } else if (op_type == "subgraph") {
static_cast<operators::SubgraphOp*>(op.get())->SetSubBlock( static_cast<operators::SubgraphOp*>(op.get())->SetProgramDesc(
sub_block_desc); program_desc);
} }
op->Attach(*op_desc, exec_scope_);
ops_[block_idx].emplace_back(std::move(op));
} }
ops_.emplace_back(std::move(op));
ops_.back()->Attach(op_desc, exec_scope_);
} }
} }
void Program::PrepareWorkspace(const cpp::ProgramDesc& prog, void Program::PrepareWorkspace(
const std::vector<std::string>& var_names) { const std::shared_ptr<cpp::ProgramDesc>& program_desc,
const std::vector<std::string>& vars_to_clone) {
CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found"; CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found";
exec_scope_ = &scope_->NewScope(); exec_scope_ = &scope_->NewScope();
// Create Feed and Fetch var. // Create Feed and Fetch var.
scope_->Var("feed")->GetMutable<std::vector<lite::Tensor>>(); scope_->Var("feed")->GetMutable<std::vector<lite::Tensor>>();
scope_->Var("fetch")->GetMutable<std::vector<lite::Tensor>>(); scope_->Var("fetch")->GetMutable<std::vector<lite::Tensor>>();
tmp_vars_.push_back("feed"); vars_.push_back("feed");
tmp_vars_.push_back("fetch"); vars_.push_back("fetch");
auto VarPrecision2KernlPrecision = auto VarDescType2PrecisionType =
[](const lite::VarDescAPI::Type& type) -> PrecisionType { [](const lite::VarDescAPI::Type& type) -> PrecisionType {
switch (type) { switch (type) {
case lite::VarDescAPI::Type::FP32: case lite::VarDescAPI::Type::FP32:
...@@ -267,44 +359,60 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog, ...@@ -267,44 +359,60 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog,
case lite::VarDescAPI::Type::INT64: case lite::VarDescAPI::Type::INT64:
return PRECISION(kInt64); return PRECISION(kInt64);
default: default:
// LOG(FATAL) << "not supported type: " << static_cast<int>(type); LOG(WARNING) << "Unable to convert var desc type("
<< static_cast<int>(type) << ") to precision type!";
return PRECISION(kUnk); return PRECISION(kUnk);
} }
}; };
auto& program = prog; auto block_size = program_desc->BlocksSize();
CHECK(program.BlocksSize()); CHECK(block_size);
for (size_t b = 0; b < program.BlocksSize(); ++b) { for (size_t block_idx = 0; block_idx < block_size; ++block_idx) {
auto& main_block = *program.GetBlock<cpp::BlockDesc>(b); auto* block_desc = program_desc->GetBlock<cpp::BlockDesc>(block_idx);
for (size_t i = 0; i < main_block.VarsSize(); ++i) { auto var_size = block_desc->VarsSize();
auto& var_desc = *main_block.GetVar<cpp::VarDesc>(i); for (size_t var_idx = 0; var_idx < var_size; ++var_idx) {
if (!var_desc.Persistable()) { auto* var_desc = block_desc->GetVar<cpp::VarDesc>(var_idx);
if (var_desc.GetType() == lite::VarDescAPI::Type::LOD_TENSOR && const auto& var_name = var_desc->Name();
VarPrecision2KernlPrecision(var_desc.GetDataType()) != const auto& var_type = var_desc->GetType();
PRECISION(kUnk)) { if (!var_desc->Persistable()) {
var_data_type_[var_desc.Name()] = vars_.push_back(var_name);
VarPrecision2KernlPrecision(var_desc.GetDataType()); auto* var = exec_scope_->Var(var_name);
} VLOG(4) << "Var " << var_name << " in block " << block_idx;
tmp_vars_.push_back(var_desc.Name()); VLOG(4) << " - type " << static_cast<int>(var_type);
VLOG(4) << "var name: " << var_desc.Name() << " type is " if (var_type == lite::VarDescAPI::Type::LOD_TENSOR) {
<< static_cast<int>(var_desc.GetType()) << " data type is " const auto& var_data_type =
<< static_cast<int>(var_desc.GetDataType()); VarDescType2PrecisionType(var_desc->GetDataType());
exec_scope_->Var(var_desc.Name()); if (var_data_type != PRECISION(kUnk)) {
if (b > 0) { var_type_map_[var_name] = LiteType::GetTensorTy(
VLOG(4) << "var: " << var_desc.Name(); TARGET(kUnk), var_data_type, DATALAYOUT(kUnk));
}
VLOG(4) << " - data type " << static_cast<int>(var_data_type);
// Create the tensor with the shape from var desc, it's convenient to
// the graph analysis in the passes, but you should resize the tensor
// with the real shape before accessing its data, because the
// var_shape may be [-1,3,224,224]
const auto& var_shape = var_desc->GetShape();
auto* tensor = var->GetMutable<lite::Tensor>();
if (tensor->dims().empty() && !var_shape.empty()) {
tensor->Resize(var_shape);
VLOG(4) << " - dims " << tensor->dims().repr();
}
} else if (var_type == lite::VarDescAPI::Type::LOD_TENSOR_ARRAY) {
var_type_map_[var_name] = LiteType::GetTensorListTy(
TARGET(kUnk), PRECISION(kUnk), DATALAYOUT(kUnk));
} }
} else { } else {
if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue; if (var_name == "feed" || var_name == "fetch") continue;
weights_.push_back(var_desc.Name()); weights_.push_back(var_name);
if (var_desc.Persistable()) scope_->Var(var_desc.Name()); scope_->Var(var_name);
} }
} }
} }
for (auto i : var_names) { for (auto var_name : vars_to_clone) {
exec_scope_->LocalVar(i); exec_scope_->LocalVar(var_name);
auto* tensor = scope_->Var(i)->GetMutable<lite::Tensor>(); auto* tensor = scope_->Var(var_name)->GetMutable<Tensor>();
auto* sub_tensor = exec_scope_->Var(i)->GetMutable<lite::Tensor>(); auto* sub_tensor = exec_scope_->Var(var_name)->GetMutable<Tensor>();
sub_tensor->CopyDataFrom(*tensor); sub_tensor->CopyDataFrom(*tensor);
} }
} }
......
...@@ -41,61 +41,72 @@ static const char kKernelTypeAttr[] = "__@kernel_type_attr@__"; ...@@ -41,61 +41,72 @@ static const char kKernelTypeAttr[] = "__@kernel_type_attr@__";
// - scope: which contains all the weights // - scope: which contains all the weights
struct Program { struct Program {
public: public:
explicit Program(const std::shared_ptr<Scope>& root) { scope_ = root; } explicit Program(const std::shared_ptr<Scope>& root_scope) {
Program(const cpp::ProgramDesc& desc, scope_ = root_scope;
const std::shared_ptr<Scope>& root, }
Program(const std::shared_ptr<cpp::ProgramDesc>& program_desc,
const std::shared_ptr<Scope>& root_scope,
const std::vector<Place>& valid_places, const std::vector<Place>& valid_places,
const std::vector<std::string>& var_names = {}) const std::vector<std::string>& vars_to_clone = {})
: scope_(root), valid_places_(valid_places) { : scope_(root_scope),
desc_.CopyFrom(desc); valid_places_(valid_places),
program_desc_(program_desc) {
CHECK(scope_) << "scope should be init first"; CHECK(scope_) << "scope should be init first";
VLOG(4) << "prepare work"; VLOG(4) << "prepare work";
PrepareWorkspace(desc, var_names); PrepareWorkspace(program_desc_, vars_to_clone);
VLOG(4) << "build desc"; VLOG(4) << "build desc";
Build(desc); Build(program_desc_);
VLOG(4) << "build desc finished"; VLOG(4) << "build desc finished";
} }
std::unique_ptr<Program> Clone() const { std::unique_ptr<Program> Clone() const {
std::unique_ptr<Program> res(new Program(desc_, scope_, valid_places_)); return std::unique_ptr<Program>(
return res; new Program(program_desc_, scope_, valid_places_));
} }
const std::list<std::string>& weights() const { return weights_; } const std::list<std::string>& weights() const { return weights_; }
const std::list<std::string>& tmp_vars() const { return tmp_vars_; } const std::list<std::string>& vars() const { return vars_; }
std::list<std::string>* mutable_weights() { return &weights_; } std::list<std::string>* mutable_weights() { return &weights_; }
std::list<std::string>* mutable_tmp_vars() { return &tmp_vars_; } std::list<std::string>* mutable_vars() { return &vars_; }
const std::list<std::shared_ptr<OpLite>>& ops() const { return ops_; } const std::list<std::shared_ptr<OpLite>>& ops(
std::list<std::shared_ptr<OpLite>>* mutable_ops() { return &ops_; } int block_idx = kRootBlockIdx) const {
return ops_[block_idx];
}
std::list<std::shared_ptr<OpLite>>* mutable_ops(
int block_idx = kRootBlockIdx) {
return &ops_[block_idx];
}
lite::Scope* exec_scope() { return exec_scope_; } size_t block_size() { return ops_.size(); }
lite::Scope* scope() { return scope_.get(); }
cpp::ProgramDesc* program_desc() { return &desc_; } Scope* exec_scope() { return exec_scope_; }
Scope* scope() { return scope_.get(); }
const std::map<std::string, PrecisionType>& var_data_type() const { cpp::ProgramDesc* program_desc() { return program_desc_.get(); }
return var_data_type_;
const std::map<std::string, const Type*>& var_type_map() const {
return var_type_map_;
} }
private: private:
// Build from a program and scope. // Build from a program and scope.
void Build(const cpp::ProgramDesc& program); void Build(const std::shared_ptr<cpp::ProgramDesc>& program_desc);
// Create temporary variables. // Create temporary variables.
void PrepareWorkspace(const cpp::ProgramDesc& program, void PrepareWorkspace(const std::shared_ptr<cpp::ProgramDesc>& program_desc,
const std::vector<std::string>& var_names = {}); const std::vector<std::string>& vars_to_clone = {});
private: private:
std::map<std::string, PrecisionType> var_data_type_; std::map<std::string, const Type*> var_type_map_;
std::list<std::string> tmp_vars_; std::list<std::string> vars_;
std::list<std::string> weights_; std::list<std::string> weights_;
std::list<std::shared_ptr<OpLite>> ops_; std::vector<std::list<std::shared_ptr<OpLite>>> ops_;
// the scope to run the kernels, NOTE this is the execution scope. // the scope to run the kernels, NOTE this is the execution scope.
std::shared_ptr<lite::Scope> scope_; std::shared_ptr<Scope> scope_;
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
// Runtime scope. // Runtime scope.
lite::Scope* exec_scope_{}; Scope* exec_scope_{};
cpp::ProgramDesc desc_; std::shared_ptr<cpp::ProgramDesc> program_desc_;
}; };
struct Instruction { struct Instruction {
...@@ -179,8 +190,22 @@ struct Instruction { ...@@ -179,8 +190,22 @@ struct Instruction {
*/ */
class LITE_API RuntimeProgram { class LITE_API RuntimeProgram {
public: public:
explicit RuntimeProgram(std::vector<Instruction>&& insts) explicit RuntimeProgram(std::vector<std::vector<Instruction>>&& insts)
: instructions_(std::move(insts)) { : instructions_(std::move(insts)) {
Init();
}
explicit RuntimeProgram(
const std::shared_ptr<const cpp::ProgramDesc>& program_desc,
Scope* exec_scope,
int block_idx = kRootBlockIdx);
~RuntimeProgram() {
#ifdef LITE_WITH_PROFILE
LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kCreate);
LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch);
#endif // LITE_WITH_PROFILE
}
void Init() {
if (instructions_.empty()) { if (instructions_.empty()) {
LOG(FATAL) << "no instructions"; LOG(FATAL) << "no instructions";
} }
...@@ -189,7 +214,7 @@ class LITE_API RuntimeProgram { ...@@ -189,7 +214,7 @@ class LITE_API RuntimeProgram {
#endif #endif
#ifdef LITE_WITH_NVTX #ifdef LITE_WITH_NVTX
const NVTXAnnotator& annotator = NVTXAnnotator::Global(); const NVTXAnnotator& annotator = NVTXAnnotator::Global();
for (auto& inst : instructions_) { for (auto& inst : instructions_[kRootBlockIdx]) {
NVTXRangeAnnotation annotation = annotator.AnnotateBlock(); NVTXRangeAnnotation annotation = annotator.AnnotateBlock();
register_layer_names_.push_back(annotator.RegisterString( register_layer_names_.push_back(annotator.RegisterString(
const_cast<paddle::lite::OpLite*>(inst.op())->Type().c_str())); const_cast<paddle::lite::OpLite*>(inst.op())->Type().c_str()));
...@@ -197,30 +222,27 @@ class LITE_API RuntimeProgram { ...@@ -197,30 +222,27 @@ class LITE_API RuntimeProgram {
register_layer_names_.push_back(annotator.RegisterString("one_loop")); register_layer_names_.push_back(annotator.RegisterString("one_loop"));
#endif #endif
} }
~RuntimeProgram() {
#ifdef LITE_WITH_PROFILE
LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kCreate);
LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch);
#endif // LITE_WITH_PROFILE
}
void Run(); void Run();
void set_exec_scope(lite::Scope* x) { exec_scope_ = x; } void set_exec_scope(Scope* x) { exec_scope_ = x; }
lite::Scope* exec_scope() { return exec_scope_; } Scope* exec_scope() { return exec_scope_; }
size_t num_instructions() const { return instructions_.size(); } const std::vector<Instruction>& instructions(
int block_idx = kRootBlockIdx) const {
return instructions_[block_idx];
}
const std::vector<Instruction>& instructions() const { return instructions_; } std::vector<Instruction>* mutable_instructions(
int block_idx = kRootBlockIdx) {
return &instructions_[block_idx];
}
// `SaveOpInfosToProgram` will update the op list(ops_) of the block 0 size_t block_size() { return instructions_.size(); }
// in ProgramDesc.
void SaveOpInfosToProgram(cpp::ProgramDesc* desc);
// `UpdateVarsOfProgram` will update the var list(vars_) of the block 0 in // Update the ops and vars of all of blocks to the given program_desc
// ProgramDesc. Namely, if a new var created in some passes, its var_desc will // according to the instructions
// be added in vars_. void SaveToProgram(std::shared_ptr<cpp::ProgramDesc> program_desc);
void UpdateVarsOfProgram(cpp::ProgramDesc* desc);
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
// UpdateCudaContext will update the exec stream and io stream of all kernels // UpdateCudaContext will update the exec stream and io stream of all kernels
...@@ -230,14 +252,14 @@ class LITE_API RuntimeProgram { ...@@ -230,14 +252,14 @@ class LITE_API RuntimeProgram {
private: private:
RuntimeProgram(const RuntimeProgram&) = delete; RuntimeProgram(const RuntimeProgram&) = delete;
std::vector<Instruction> instructions_; std::vector<std::vector<Instruction>> instructions_;
lite::Scope* exec_scope_{}; Scope* exec_scope_{};
#ifdef LITE_WITH_PROFILE #ifdef LITE_WITH_PROFILE
profile::Profiler profiler_; profile::Profiler profiler_;
void set_profiler() { void set_profiler() {
for (auto i = instructions_.begin(); i != instructions_.end(); ++i) { for (auto& inst : instructions_[kRootBlockIdx]) {
i->set_profiler(&profiler_); inst.set_profiler(&profiler_);
} }
} }
#endif #endif
......
...@@ -37,7 +37,7 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -37,7 +37,7 @@ bool SubgraphEngine::BuildDeviceProgram() {
subgraph::apu::Graph graph; subgraph::apu::Graph graph;
int neuron_errCode = NeuronModel_create(&model_); int neuron_errCode = NeuronModel_create(&model_);
if (NEURON_NO_ERROR != neuron_errCode) { if (NEURON_NO_ERROR != neuron_errCode) {
LOG(WARNING) << "Fail to create model"; LOG(WARNING) << "[APU] Failed to create the neuron model!";
return false; return false;
} }
graph.set_model(model_); graph.set_model(model_);
...@@ -46,11 +46,12 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -46,11 +46,12 @@ bool SubgraphEngine::BuildDeviceProgram() {
// Convert all of ops and their input vars and weights and added into the APU // Convert all of ops and their input vars and weights and added into the APU
// NIR graph // NIR graph
if (origin_program_.empty()) { if (!origin_program_) {
BuildOriginProgram(); BuildOriginProgram();
} }
const auto& bridges = subgraph::Registry::Instance(); const auto& bridges = subgraph::Registry::Instance();
for (auto& inst : origin_program_) { const auto& insts = origin_program_->instructions(kRootBlockIdx);
for (auto& inst : insts) {
auto op = const_cast<OpLite*>(inst.op()); auto op = const_cast<OpLite*>(inst.op());
CHECK(op); CHECK(op);
op->CheckShape(); op->CheckShape();
...@@ -70,55 +71,38 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -70,55 +71,38 @@ bool SubgraphEngine::BuildDeviceProgram() {
} }
} }
// Get input tensor // Get the index of input tensors
std::vector<uint32_t> ins; std::vector<uint32_t> input_indices;
origin_itensors_.resize(input_names_.size());
origin_idims_.resize(input_names_.size());
for (int i = 0; i < input_names_.size(); i++) { for (int i = 0; i < input_names_.size(); i++) {
origin_itensors_[i] = scope_->FindMutableTensor(input_names_[i]); CHECK(graph.Has(input_names_[i])) << "[APU] Failed to find input node "
CHECK(origin_itensors_[i]); << input_names_[i];
origin_idims_[i] = origin_itensors_[i]->dims(); auto index = graph.Get(input_names_[i])->index();
VLOG(3) << "subgraph input name: " << i << ", " << input_names_[i] << ":" input_indices.push_back(index);
<< origin_idims_[i].production(); VLOG(3) << "[APU] Input[" << i << "] name " << input_names_[i] << " dims "
// Get input index << origin_itensors_[i]->dims() << " index " << index;
int idx;
if (graph.Has(input_names_[i])) {
ins.push_back(graph.Get(input_names_[i])->index());
VLOG(3) << "input idx: " << graph.Get(input_names_[i])->index();
} else {
LOG(WARNING) << "Fail to find input: " << input_names_[i];
return false;
}
} }
// Get output tensor // Get the index of output tensors
std::vector<uint32_t> outs; std::vector<uint32_t> output_indices;
origin_otensors_.resize(output_names_.size());
origin_odims_.resize(output_names_.size());
for (int i = 0; i < output_names_.size(); i++) { for (int i = 0; i < output_names_.size(); i++) {
origin_otensors_[i] = scope_->FindMutableTensor(output_names_[i]); CHECK(graph.Has(output_names_[i])) << "[APU] Failed to find output node "
CHECK(origin_otensors_[i]); << output_names_[i];
origin_odims_[i] = origin_otensors_[i]->dims();
VLOG(3) << "subgraph output name: " << i << ", " << output_names_[i] << ":"
<< origin_odims_[i].production();
origin_otensors_[i]->mutable_data<int8_t>(); origin_otensors_[i]->mutable_data<int8_t>();
// Get input index auto index = graph.Get(output_names_[i])->index();
if (graph.Has(output_names_[i])) { output_indices.push_back(index);
outs.push_back(graph.Get(output_names_[i])->index()); VLOG(3) << "[APU] Output[" << i << "] name " << output_names_[i] << " dims "
VLOG(3) << "output idx: " << graph.Get(output_names_[i])->index(); << origin_otensors_[i]->dims() << " index " << index;
} else {
LOG(WARNING) << "Fail to find output: " << output_names_[i];
return false;
}
} }
VLOG(3) << "ins size: " << ins.size() << " outs size:" << outs.size(); // Indentify the input and output tensors of the neuron model
// Set subgraph input/output NeuronModel_identifyInputsAndOutputs(model_,
NeuronModel_identifyInputsAndOutputs( input_indices.size(),
model_, ins.size(), &ins[0], outs.size(), &outs[0]); &input_indices[0],
output_indices.size(),
&output_indices[0]);
neuron_errCode = NeuronModel_finish(model_); neuron_errCode = NeuronModel_finish(model_);
if (NEURON_NO_ERROR != neuron_errCode) { if (NEURON_NO_ERROR != neuron_errCode) {
LOG(WARNING) << "Fail to create NIR model:" << neuron_errCode; LOG(WARNING) << "[APU] Fail to create NIR model:" << neuron_errCode;
return false; return false;
} }
VLOG(3) << "[APU] APU NIR model created!"; VLOG(3) << "[APU] APU NIR model created!";
...@@ -207,11 +191,11 @@ SubgraphEngine::~SubgraphEngine() { ...@@ -207,11 +191,11 @@ SubgraphEngine::~SubgraphEngine() {
void SubgraphCompute::PrepareForRun() { void SubgraphCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
engine_.reset(new SubgraphEngine(ctx_.get(), engine_.reset(new SubgraphEngine(ctx_.get(),
param.sub_block_idx, param.block_idx,
param.sub_block_desc, param.program_desc,
param.exec_scope,
param.input_data_names, param.input_data_names,
param.output_data_names, param.output_data_names));
param.scope));
CHECK(engine_); CHECK(engine_);
} }
......
...@@ -31,12 +31,16 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -31,12 +31,16 @@ class SubgraphEngine : public subgraph::Engine {
public: public:
SubgraphEngine(KernelContext *ctx, SubgraphEngine(KernelContext *ctx,
int block_idx, int block_idx,
cpp::BlockDesc *block_desc, const std::shared_ptr<const cpp::ProgramDesc> &program_desc,
Scope *exec_scope,
const std::vector<std::string> &input_names, const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names, const std::vector<std::string> &output_names)
Scope *scope) : subgraph::Engine(ctx,
: subgraph::Engine( block_idx,
ctx, block_idx, block_desc, input_names, output_names, scope) {} program_desc,
exec_scope,
input_names,
output_names) {}
~SubgraphEngine(); ~SubgraphEngine();
......
...@@ -75,7 +75,6 @@ add_kernel(generate_proposals_compute_arm ARM extra SRCS generate_proposals_comp ...@@ -75,7 +75,6 @@ add_kernel(generate_proposals_compute_arm ARM extra SRCS generate_proposals_comp
add_kernel(roi_align_compute_arm ARM extra SRCS roi_align_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(roi_align_compute_arm ARM extra SRCS roi_align_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(box_clip_compute_arm ARM extra SRCS box_clip_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(box_clip_compute_arm ARM extra SRCS box_clip_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(assign_value_compute_arm ARM basic SRCS assign_value_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(assign_value_compute_arm ARM basic SRCS assign_value_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(conditional_block_compute_arm ARM extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(distribute_fpn_proposals_compute_arm ARM extra SRCS distribute_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(distribute_fpn_proposals_compute_arm ARM extra SRCS distribute_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(clip_compute_arm ARM extra SRCS clip_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(clip_compute_arm ARM extra SRCS clip_compute.cc DEPS ${lite_kernel_deps} math_arm)
...@@ -87,7 +86,6 @@ add_kernel(beam_search_decode_compute_arm ARM extra SRCS beam_search_decode_comp ...@@ -87,7 +86,6 @@ add_kernel(beam_search_decode_compute_arm ARM extra SRCS beam_search_decode_comp
add_kernel(lookup_table_compute_arm ARM extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(lookup_table_compute_arm ARM extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(lookup_table_dequant_compute_arm ARM extra SRCS lookup_table_dequant_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(lookup_table_dequant_compute_arm ARM extra SRCS lookup_table_dequant_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_softmax_compute_arm ARM extra SRCS sequence_softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(sequence_softmax_compute_arm ARM extra SRCS sequence_softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(while_compute_arm ARM extra SRCS while_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(topk_compute_arm ARM extra SRCS topk_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(topk_compute_arm ARM extra SRCS topk_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(increment_compute_arm ARM extra SRCS increment_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(increment_compute_arm ARM extra SRCS increment_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(beam_search_compute_arm ARM extra SRCS beam_search_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(beam_search_compute_arm ARM extra SRCS beam_search_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/program.h"
#include "lite/operators/conditional_block_op.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/basic_profiler.h"
#include "lite/core/profile/precision_profiler.h"
#include "lite/core/profile/profiler.h"
#endif
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class CondExecutor {
typedef std::shared_ptr<OpLite> OpPtr;
public:
CondExecutor(cpp::BlockDesc *block, Scope *scope, Place place)
: scope_(scope), place_(place) {
int32_t op_size = block->OpsSize();
for (int32_t i = 0; i < op_size; ++i) {
auto &op_desc = *block->template GetOp<cpp::OpDesc>(i);
auto op_type = op_desc.Type();
auto op_handler = lite::LiteOpRegistry::Global().Create(op_desc.Type());
op_handler->Attach(op_desc, scope);
auto hostplace = place_;
hostplace.target = TARGET(kHost);
auto kernels = op_handler->CreateKernels({place_, hostplace});
CHECK_GT(kernels.size(), 0) << "cannot create kernel";
op_handler->AttachKernel(kernels[0].get());
op_handler->SetKernel(kernels);
ops_of_block_.push_back(op_handler);
}
}
void Run() {
#ifdef LITE_WITH_PROFILE
#ifdef LITE_WITH_PRECISION_PROFILE
lite::profile::Profiler profiler;
#endif // LITE_WITH_PRECISION_PROFILE
#endif // LITE_WITH_PROFILE
for (auto &op_handler : ops_of_block_) {
op_handler->CheckShape();
op_handler->InferShape();
#ifdef LITE_WITH_PROFILE
#ifdef LITE_WITH_PRECISION_PROFILE
std::unique_ptr<KernelBase> kernel(op_handler->GetKernel());
Instruction inst(op_handler, std::move(kernel));
inst.set_profiler(&profiler);
#endif // LITE_WITH_PRECISION_PROFILE
#endif // LITE_WITH_PROFILE
op_handler->Run();
#ifdef LITE_WITH_PROFILE
#ifdef LITE_WITH_PRECISION_PROFILE
LITE_PRECISION_PROFILE(inst)
#endif // LITE_WITH_PRECISION_PROFILE
#endif // LITE_WITH_PROFILE
}
}
private:
Scope *scope_;
Place place_;
std::vector<OpPtr> ops_of_block_;
};
class ConditionalBlockCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::ConditionalBlockParam;
void PrepareForRun() override;
void Run() override;
virtual ~ConditionalBlockCompute() = default;
private:
std::shared_ptr<CondExecutor> executor_;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -202,17 +202,13 @@ void ElementwiseMulCompute<T, PType>::Run() { ...@@ -202,17 +202,13 @@ void ElementwiseMulCompute<T, PType>::Run() {
} }
} }
template <> template <typename T, PrecisionType PType>
void ElementwiseMulCompute<int64_t, PRECISION(kInt64)>::Run() { void ElementwiseMulActivationCompute<T, PType>::Run() {
auto& param = this->template Param<operators::ElementwiseParam>(); auto& param =
lite::arm::math::elementwise_compute_basic<int64_t>(param, "mul", ""); this->template Param<operators::FusionElementwiseActivationParam>();
} auto* x_data = param.X->template data<T>();
auto* y_data = param.Y->template data<T>();
void ElementwiseMulActivationCompute::Run() { auto* out_data = param.Out->template mutable_data<T>();
auto& param = Param<operators::FusionElementwiseActivationParam>();
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
float* out_data = param.Out->mutable_data<float>();
int axis = param.axis; int axis = param.axis;
std::string act_type = param.act_type; std::string act_type = param.act_type;
auto x_dims = param.X->dims(); auto x_dims = param.X->dims();
...@@ -221,21 +217,21 @@ void ElementwiseMulActivationCompute::Run() { ...@@ -221,21 +217,21 @@ void ElementwiseMulActivationCompute::Run() {
if (x_dims.size() < y_dims.size() && if (x_dims.size() < y_dims.size() &&
is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) { is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") { if (act_type == "relu") {
lite::arm::math::elementwise_mul_relu_broadcast<float>( lite::arm::math::elementwise_mul_relu_broadcast<T>(
y_data, x_data, out_data, pre, n, post); y_data, x_data, out_data, pre, n, post);
} else { } else {
LOG(FATAL) << "unsupported Activation type: " << act_type; LOG(FATAL) << "unsupported Activation type: " << act_type;
} }
} else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") { if (act_type == "relu") {
lite::arm::math::elementwise_mul_relu_broadcast( lite::arm::math::elementwise_mul_relu_broadcast<T>(
x_data, y_data, out_data, pre, n, post); x_data, y_data, out_data, pre, n, post);
} else { } else {
LOG(FATAL) << "unsupported Activation type: " << act_type; LOG(FATAL) << "unsupported Activation type: " << act_type;
} }
} else { } else {
if (act_type == "relu") { if (act_type == "relu") {
lite::arm::math::elementwise_mul_relu( lite::arm::math::elementwise_mul_relu<T>(
x_data, y_data, out_data, x_dims.production()); x_data, y_data, out_data, x_dims.production());
} else { } else {
LOG(FATAL) << "unsupported Activation type: " << act_type; LOG(FATAL) << "unsupported Activation type: " << act_type;
...@@ -426,46 +422,60 @@ REGISTER_LITE_KERNEL( ...@@ -426,46 +422,60 @@ REGISTER_LITE_KERNEL(
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
using elementwise_mul_float = using elementwise_mul_float_t =
paddle::lite::kernels::arm::ElementwiseMulCompute<float, PRECISION(kFloat)>; paddle::lite::kernels::arm::ElementwiseMulCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
elementwise_mul, kARM, kFloat, kNCHW, elementwise_mul_float, def) elementwise_mul, kARM, kFloat, kNCHW, elementwise_mul_float_t, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
using elementwise_mul_int32 = using elementwise_mul_int32_t =
paddle::lite::kernels::arm::ElementwiseMulCompute<int, PRECISION(kInt32)>; paddle::lite::kernels::arm::ElementwiseMulCompute<int, PRECISION(kInt32)>;
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
elementwise_mul, kARM, kInt32, kNCHW, elementwise_mul_int32, def) elementwise_mul, kARM, kInt32, kNCHW, elementwise_mul_int32_t, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.Finalize(); .Finalize();
using elementwise_mul_int64 = using elementwise_mul_int64_t =
paddle::lite::kernels::arm::ElementwiseMulCompute<int64_t, paddle::lite::kernels::arm::ElementwiseMulCompute<int64_t,
PRECISION(kInt64)>; PRECISION(kInt64)>;
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
elementwise_mul, kARM, kInt64, kNCHW, elementwise_mul_int64, def) elementwise_mul, kARM, kInt64, kNCHW, elementwise_mul_int64_t, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL( using fusion_elementwise_mul_activation_float_t = paddle::lite::kernels::arm::
fusion_elementwise_mul_activation, ElementwiseMulActivationCompute<float, PRECISION(kFloat)>;
kARM, REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
kFloat, kARM,
kNCHW, kFloat,
paddle::lite::kernels::arm::ElementwiseMulActivationCompute, kNCHW,
def) fusion_elementwise_mul_activation_float_t,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
using fusion_elementwise_mul_activation_int64_t = paddle::lite::kernels::arm::
ElementwiseMulActivationCompute<int64_t, PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
kARM,
kInt64,
kNCHW,
fusion_elementwise_mul_activation_int64_t,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_max, REGISTER_LITE_KERNEL(elementwise_max,
kARM, kARM,
kFloat, kFloat,
...@@ -489,22 +499,22 @@ REGISTER_LITE_KERNEL( ...@@ -489,22 +499,22 @@ REGISTER_LITE_KERNEL(
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
using elementwise_div_fp32 = using elementwise_div_fp32_t =
paddle::lite::kernels::arm::ElementwiseDivCompute<float, PRECISION(kFloat)>; paddle::lite::kernels::arm::ElementwiseDivCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
elementwise_div, kARM, kFloat, kNCHW, elementwise_div_fp32, def) elementwise_div, kARM, kFloat, kNCHW, elementwise_div_fp32_t, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
using elementwise_div_int64 = using elementwise_div_int64_t =
paddle::lite::kernels::arm::ElementwiseDivCompute<int64_t, paddle::lite::kernels::arm::ElementwiseDivCompute<int64_t,
PRECISION(kInt64)>; PRECISION(kInt64)>;
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
elementwise_div, kARM, kInt64, kNCHW, elementwise_div_int64, def) elementwise_div, kARM, kInt64, kNCHW, elementwise_div_int64_t, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
...@@ -522,11 +532,11 @@ REGISTER_LITE_KERNEL( ...@@ -522,11 +532,11 @@ REGISTER_LITE_KERNEL(
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
using elementwise_mod_int64 = using elementwise_mod_int64_t =
paddle::lite::kernels::arm::ElementwiseModCompute<int64_t, paddle::lite::kernels::arm::ElementwiseModCompute<int64_t,
PRECISION(kInt64)>; PRECISION(kInt64)>;
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
elementwise_mod, kARM, kInt64, kNCHW, elementwise_mod_int64, def) elementwise_mod, kARM, kInt64, kNCHW, elementwise_mod_int64_t, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
......
...@@ -62,8 +62,8 @@ class ElementwiseMulCompute : public KernelLite<TARGET(kARM), PType> { ...@@ -62,8 +62,8 @@ class ElementwiseMulCompute : public KernelLite<TARGET(kARM), PType> {
virtual ~ElementwiseMulCompute() = default; virtual ~ElementwiseMulCompute() = default;
}; };
class ElementwiseMulActivationCompute template <typename T, PrecisionType PType>
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> { class ElementwiseMulActivationCompute : public KernelLite<TARGET(kARM), PType> {
public: public:
void Run() override; void Run() override;
......
...@@ -533,13 +533,15 @@ TEST(fusion_elementwise_mul_activation_arm, retrive_op) { ...@@ -533,13 +533,15 @@ TEST(fusion_elementwise_mul_activation_arm, retrive_op) {
} }
TEST(fusion_elementwise_mul_activation_arm, init) { TEST(fusion_elementwise_mul_activation_arm, init) {
ElementwiseMulActivationCompute fusion_elementwise_mul_activation; ElementwiseMulActivationCompute<float, PRECISION(kFloat)>
fusion_elementwise_mul_activation;
ASSERT_EQ(fusion_elementwise_mul_activation.precision(), PRECISION(kFloat)); ASSERT_EQ(fusion_elementwise_mul_activation.precision(), PRECISION(kFloat));
ASSERT_EQ(fusion_elementwise_mul_activation.target(), TARGET(kARM)); ASSERT_EQ(fusion_elementwise_mul_activation.target(), TARGET(kARM));
} }
TEST(fusion_elementwise_mul_activation_arm, compute) { TEST(fusion_elementwise_mul_activation_arm, compute) {
ElementwiseMulActivationCompute fusion_elementwise_mul_activation; ElementwiseMulActivationCompute<float, PRECISION(kFloat)>
fusion_elementwise_mul_activation;
operators::FusionElementwiseActivationParam param; operators::FusionElementwiseActivationParam param;
lite::Tensor x, y, output, output_ref; lite::Tensor x, y, output, output_ref;
......
...@@ -20,44 +20,45 @@ namespace lite { ...@@ -20,44 +20,45 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
template <typename T> template <typename IndexType, typename DataType>
void GatherFunc(const operators::GatherParam& param) { void GatherFunc(const operators::GatherParam& param) {
auto src_dims = param.X->dims(); auto src_dims = param.X->dims();
auto index_size = param.Index->dims()[0]; auto index_size = param.Index->dims()[0];
auto* p_src = param.X->data<T>(); auto* p_src = param.X->data<DataType>();
const int* p_index = param.Index->data<int>(); const IndexType* p_index = param.Index->data<IndexType>();
auto* p_output = param.Out->mutable_data<T>(); auto* p_output = param.Out->mutable_data<DataType>();
int slice_size = 1; int slice_size = 1;
for (size_t i = 1; i < src_dims.size(); ++i) { for (size_t i = 1; i < src_dims.size(); ++i) {
slice_size *= src_dims[i]; slice_size *= src_dims[i];
} }
for (int i = 0; i < index_size; ++i) { for (int i = 0; i < index_size; ++i) {
int index_ = p_index[i]; IndexType index_ = p_index[i];
memcpy(p_output + i * slice_size, memcpy(p_output + i * slice_size,
p_src + index_ * slice_size, p_src + index_ * slice_size,
slice_size * sizeof(T)); slice_size * sizeof(DataType));
} }
} }
void GatherCompute::Run() { template <typename IndexType>
auto& param = this->Param<operators::GatherParam>(); void GatherCompute<IndexType>::Run() {
auto& param = this->template Param<operators::GatherParam>();
switch (param.X->precision()) { switch (param.X->precision()) {
case PRECISION(kFloat): case PRECISION(kFloat):
GatherFunc<float>(param); GatherFunc<IndexType, float>(param);
break; break;
case PRECISION(kInt8): case PRECISION(kInt8):
GatherFunc<int8_t>(param); GatherFunc<IndexType, int8_t>(param);
break; break;
case PRECISION(kInt16): case PRECISION(kInt16):
GatherFunc<int16_t>(param); GatherFunc<IndexType, int16_t>(param);
break; break;
case PRECISION(kInt32): case PRECISION(kInt32):
GatherFunc<int32_t>(param); GatherFunc<IndexType, int32_t>(param);
break; break;
case PRECISION(kInt64): case PRECISION(kInt64):
GatherFunc<int64_t>(param); GatherFunc<IndexType, int64_t>(param);
break; break;
default: default:
LOG(FATAL) << "Gather does not implement for the " LOG(FATAL) << "Gather does not implement for the "
...@@ -70,9 +71,26 @@ void GatherCompute::Run() { ...@@ -70,9 +71,26 @@ void GatherCompute::Run() {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(gather,
gather, kARM, kAny, kNCHW, paddle::lite::kernels::arm::GatherCompute, def) kARM,
kAny,
kNCHW,
paddle::lite::kernels::arm::GatherCompute<int32_t>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("Index", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("Index",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.Finalize();
REGISTER_LITE_KERNEL(gather,
kARM,
kAny,
kNCHW,
paddle::lite::kernels::arm::GatherCompute<int64_t>,
def_int64_idx)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("Index",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.Finalize(); .Finalize();
...@@ -23,6 +23,7 @@ namespace lite { ...@@ -23,6 +23,7 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
template <typename IndexType>
class GatherCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> { class GatherCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public: public:
void Run() override; void Run() override;
......
...@@ -102,10 +102,14 @@ void SequenceConvCompute::Run() { ...@@ -102,10 +102,14 @@ void SequenceConvCompute::Run() {
1, 1,
1, // stride_h, stride_w, dilation_h, dilation_w 1, // stride_h, stride_w, dilation_h, dilation_w
tmp_data); tmp_data);
local_naive_transpose(tmp_data, int cols = kernel_size * hidden_dim;
sub_col_data, int rows = input_row_end - input_row_begin;
kernel_size * hidden_dim, if (cols % 4 == 0 && rows % 4 == 0) {
input_row_end - input_row_begin); paddle::lite::arm::math::local_transpose(
tmp_data, sub_col_data, cols, rows);
} else {
local_naive_transpose(tmp_data, sub_col_data, cols, rows);
}
} }
} }
......
...@@ -28,36 +28,17 @@ namespace lite { ...@@ -28,36 +28,17 @@ namespace lite {
namespace kernels { namespace kernels {
namespace bm { namespace bm {
bool SubgraphEngine::PrepareWorkspaceForDeviceProgram() {
// Obtain the origin input tensors, and create the origin output
// tensors(Don't try to access them before launch the device program or the
// origin program)
PrepareWorkspaceForOriginProgram();
// Create the device input and output tensors, but don't initialize them
// with the dimensions
device_inputs_.resize(input_names_.size());
for (int i = 0; i < input_names_.size(); i++) {
device_inputs_[i].reset(new hiai::AiTensor);
CHECK(device_inputs_[i]);
}
device_outputs_.resize(output_names_.size());
for (int i = 0; i < output_names_.size(); i++) {
device_outputs_[i].reset(new hiai::AiTensor);
CHECK(device_outputs_[i]);
}
return true;
}
bool SubgraphEngine::BuildDeviceProgram() { bool SubgraphEngine::BuildDeviceProgram() {
int status = 0; int status = 0;
subgraph::bm::Graph graph; subgraph::bm::Graph graph;
const auto& bridges = subgraph::Registry::Instance(); const auto& bridges = subgraph::Registry::Instance();
graph.CreateCompilerHandle(); graph.CreateCompilerHandle();
auto& ctx = this->ctx_->template As<BMContext>(); auto& ctx = this->ctx_->template As<BMContext>();
if (origin_program_.empty()) { if (!origin_program_) {
BuildOriginProgram(); BuildOriginProgram();
} }
for (auto& inst : origin_program_) { const auto& insts = origin_program_->instructions(kRootBlockIdx);
for (auto& inst : insts) {
auto op = const_cast<OpLite*>(inst.op()); auto op = const_cast<OpLite*>(inst.op());
CHECK(op); CHECK(op);
op->CheckShape(); op->CheckShape();
...@@ -93,13 +74,11 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -93,13 +74,11 @@ bool SubgraphEngine::BuildDeviceProgram() {
net_info_ = bmrt_get_network_info(bmrt_hd_, net_names_[0]); net_info_ = bmrt_get_network_info(bmrt_hd_, net_names_[0]);
auto& stage = net_info_->stages[0]; auto& stage = net_info_->stages[0];
// input // input
origin_idims_.resize(input_names_.size());
origin_itensors_.resize(input_names_.size());
device_inputs_.resize(input_names_.size()); device_inputs_.resize(input_names_.size());
for (size_t i = 0; i < input_names_.size(); i++) { for (size_t i = 0; i < input_names_.size(); i++) {
origin_itensors_[i] = scope_->FindMutableTensor(net_info_->input_names[i]); origin_itensors_[i] =
exec_scope_->FindMutableTensor(net_info_->input_names[i]);
CHECK(origin_itensors_[i]); CHECK(origin_itensors_[i]);
origin_idims_[i] = origin_itensors_[i]->dims();
bm_device_mem_t* p_mem = bm_device_mem_t* p_mem =
static_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t))); static_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t)));
CHECK(p_mem != nullptr); CHECK(p_mem != nullptr);
...@@ -112,8 +91,6 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -112,8 +91,6 @@ bool SubgraphEngine::BuildDeviceProgram() {
stage.input_shapes[i]); stage.input_shapes[i]);
} }
// output // output
origin_odims_.resize(output_names_.size());
origin_otensors_.resize(output_names_.size());
device_outputs_.resize(net_info_->output_num); device_outputs_.resize(net_info_->output_num);
int out_index = 0; int out_index = 0;
for (int i = 0; i < output_names_.size(); i++) { for (int i = 0; i < output_names_.size(); i++) {
...@@ -121,14 +98,13 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -121,14 +98,13 @@ bool SubgraphEngine::BuildDeviceProgram() {
} }
for (int i = 0; i < net_info_->output_num; i++) { for (int i = 0; i < net_info_->output_num; i++) {
Tensor* t_cur = scope_->FindMutableTensor(net_info_->output_names[i]); Tensor* t_cur = exec_scope_->FindMutableTensor(net_info_->output_names[i]);
CHECK(t_cur != nullptr); CHECK(t_cur != nullptr);
bm_device_mem_t* p_mem = bm_device_mem_t* p_mem =
static_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t))); static_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t)));
CHECK(p_mem != nullptr); CHECK(p_mem != nullptr);
if (outname_map_.find(net_info_->output_names[i]) != outname_map_.end()) { if (outname_map_.find(net_info_->output_names[i]) != outname_map_.end()) {
origin_otensors_[out_index] = t_cur; origin_otensors_[out_index] = t_cur;
origin_odims_[out_index] = origin_otensors_[out_index]->dims();
origin_otensors_[out_index]->mutable_data<float>(); origin_otensors_[out_index]->mutable_data<float>();
out_index += 1; out_index += 1;
} }
...@@ -173,11 +149,11 @@ bool SubgraphEngine::LaunchDeviceProgram() { ...@@ -173,11 +149,11 @@ bool SubgraphEngine::LaunchDeviceProgram() {
void SubgraphCompute::PrepareForRun() { void SubgraphCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
engine_.reset(new SubgraphEngine(ctx_.get(), engine_.reset(new SubgraphEngine(ctx_.get(),
param.sub_block_idx, param.block_idx,
param.sub_block_desc, param.program_desc,
param.exec_scope,
param.input_data_names, param.input_data_names,
param.output_data_names, param.output_data_names));
param.scope));
CHECK(engine_); CHECK(engine_);
} }
......
...@@ -36,15 +36,18 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -36,15 +36,18 @@ class SubgraphEngine : public subgraph::Engine {
public: public:
SubgraphEngine(KernelContext *ctx, SubgraphEngine(KernelContext *ctx,
int block_idx, int block_idx,
cpp::BlockDesc *block_desc, const std::shared_ptr<const cpp::ProgramDesc> &program_desc,
Scope *exec_scope,
const std::vector<std::string> &input_names, const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names, const std::vector<std::string> &output_names)
Scope *scope) : subgraph::Engine(ctx,
: subgraph::Engine( block_idx,
ctx, block_idx, block_desc, input_names, output_names, scope) {} program_desc,
exec_scope,
input_names,
output_names) {}
protected: protected:
bool PrepareWorkspaceForDeviceProgram() override;
bool BuildDeviceProgram() override; bool BuildDeviceProgram() override;
bool LaunchDeviceProgram() override; bool LaunchDeviceProgram() override;
......
...@@ -18,6 +18,9 @@ add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute. ...@@ -18,6 +18,9 @@ add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute.
add_kernel(assign_compute_host Host extra SRCS assign_compute.cc DEPS ${lite_kernel_deps}) add_kernel(assign_compute_host Host extra SRCS assign_compute.cc DEPS ${lite_kernel_deps})
add_kernel(retinanet_detection_output_compute_host Host extra SRCS retinanet_detection_output_compute.cc DEPS ${lite_kernel_deps}) add_kernel(retinanet_detection_output_compute_host Host extra SRCS retinanet_detection_output_compute.cc DEPS ${lite_kernel_deps})
add_kernel(where_index_compute_host Host extra SRCS where_index_compute.cc DEPS ${lite_kernel_deps}) add_kernel(where_index_compute_host Host extra SRCS where_index_compute.cc DEPS ${lite_kernel_deps})
add_kernel(print_compute_host Host extra SRCS print_compute.cc DEPS ${lite_kernel_deps})
add_kernel(while_compute_host Host extra SRCS while_compute.cc DEPS ${lite_kernel_deps} program)
add_kernel(conditional_block_compute_host Host extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} program)
add_kernel(activation_grad_compute_host Host train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps}) add_kernel(activation_grad_compute_host Host train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps})
if(LITE_BUILD_EXTRA) if(LITE_BUILD_EXTRA)
......
...@@ -51,3 +51,19 @@ REGISTER_LITE_KERNEL( ...@@ -51,3 +51,19 @@ REGISTER_LITE_KERNEL(
PRECISION(kAny), PRECISION(kAny),
DATALAYOUT(kAny))}) DATALAYOUT(kAny))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(assign,
kHost,
kAny,
kAny,
paddle::lite::kernels::host::AssignCompute,
def_tensor_array)
.BindInput("X",
{LiteType::GetTensorListTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorListTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
...@@ -12,28 +12,21 @@ ...@@ -12,28 +12,21 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/conditional_block_compute.h" #include "lite/kernels/host/conditional_block_compute.h"
#include <memory>
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace host {
void ConditionalBlockCompute::PrepareForRun() { void ConditionalBlockCompute::PrepareForRun() {
auto& param = Param<operators::ConditionalBlockParam>(); auto& param = this->Param<param_t>();
auto cur_scope = param.scope; program_.reset(new RuntimeProgram(
param.program_desc, param.exec_scope, param.block_idx));
executor_ =
std::make_shared<CondExecutor>(param.sub_block, cur_scope, place());
} }
void ConditionalBlockCompute::Run() { void ConditionalBlockCompute::Run() {
auto& param = Param<operators::ConditionalBlockParam>(); auto& param = this->Param<param_t>();
for (auto& out : param.outs) { for (auto& out : param.outs) {
out->clear(); out->clear();
} }
...@@ -43,32 +36,40 @@ void ConditionalBlockCompute::Run() { ...@@ -43,32 +36,40 @@ void ConditionalBlockCompute::Run() {
auto* cond_data = cond->data<bool>(); auto* cond_data = cond->data<bool>();
need_run = cond_data[0]; need_run = cond_data[0];
} else { } else {
auto x = param.x; for (auto input : param.inputs) {
for (auto pt : x) { if (input == nullptr || !input->IsInitialized() ||
if (pt == nullptr || !pt->IsInitialized() || pt->dims().empty()) { input->dims().empty()) {
need_run = false; need_run = false;
break; break;
} }
} }
} }
if (need_run) { if (need_run) {
executor_->Run(); program_->Run();
} }
} }
} // namespace arm } // namespace host
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(conditional_block, REGISTER_LITE_KERNEL(conditional_block,
kARM, kHost,
kFloat, kAny,
kNCHW, kAny,
paddle::lite::kernels::arm::ConditionalBlockCompute, paddle::lite::kernels::host::ConditionalBlockCompute,
def) def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Input",
.BindInput("Cond", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) {LiteType::GetTensorListTy(
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindOutput("Scope", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Cond",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorListTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindOutput("Scope",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.Finalize(); .Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/program.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class ConditionalBlockCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
using param_t = operators::ConditionalBlockParam;
void PrepareForRun() override;
void Run() override;
private:
std::unique_ptr<RuntimeProgram> program_;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/print_compute.h"
#include <mutex> // NOLINT
#include <string>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
const char kForward[] = "FORWARD";
const char kBackward[] = "BACKWARD";
const char kBoth[] = "BOTH";
class TensorFormatter {
public:
TensorFormatter() {}
std::string Format(const Tensor& print_tensor,
const std::string& tensor_name = "",
const std::string& message = "") {
std::stringstream log_stream;
if (!tensor_name.empty()) {
log_stream << "Variable: " << tensor_name << std::endl;
}
if (!message.empty()) {
log_stream << " - message: " << message << std::endl;
}
if (print_tensor_lod_) {
log_stream << " - lod: {";
const LoD& lod = print_tensor.lod();
for (auto level : lod) {
log_stream << "{";
bool is_first = true;
for (auto i : level) {
if (is_first) {
log_stream << i;
is_first = false;
} else {
log_stream << ", " << i;
}
}
log_stream << "}";
}
log_stream << "}" << std::endl;
}
log_stream << " - place: " << TargetToStr(print_tensor.target())
<< std::endl; // TODO(hong19860320) always kHost
if (print_tensor_shape_) {
log_stream << " - shape: " << print_tensor.dims().repr() << std::endl;
}
if (print_tensor_layout_) {
log_stream << " - layout: "
<< DataLayoutToStr(
DATALAYOUT(kNCHW)) // TODO(hong19860320) Query the data
// layout from target tensor
<< std::endl;
}
auto dtype = print_tensor.precision();
if (print_tensor_type_) {
log_stream << " - dtype: " << PrecisionToStr(dtype) << std::endl;
}
if (dtype == PRECISION(kBool)) {
FormatData<bool>(print_tensor, log_stream);
} else if (dtype == PRECISION(kInt8)) {
FormatData<int8_t>(print_tensor, log_stream);
} else if (dtype == PRECISION(kInt16)) {
FormatData<int16_t>(print_tensor, log_stream);
} else if (dtype == PRECISION(kInt32)) {
FormatData<int32_t>(print_tensor, log_stream);
} else if (dtype == PRECISION(kInt64)) {
FormatData<int64_t>(print_tensor, log_stream);
} else if (dtype == PRECISION(kFloat)) {
FormatData<float>(print_tensor, log_stream);
} else {
log_stream << "\tdata: unprintable type: " << PrecisionToStr(dtype)
<< std::endl;
}
return log_stream.str();
}
void Print(const Tensor& print_tensor,
const std::string& tensor_name = "",
const std::string& message = "") {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
std::cout << Format(print_tensor, tensor_name, message);
}
void SetPrintTensorType(bool print_tensor_type) {
print_tensor_type_ = print_tensor_type;
}
void SetPrintTensorShape(bool print_tensor_shape) {
print_tensor_shape_ = print_tensor_shape;
}
void SetPrintTensorLod(bool print_tensor_lod) {
print_tensor_lod_ = print_tensor_lod;
}
void SetPrintTensorLayout(bool print_tensor_layout) {
print_tensor_layout_ = print_tensor_layout;
}
void SetSummarize(int64_t summarize) { summarize_ = summarize; }
private:
template <typename T>
void FormatData(const Tensor& print_tensor, std::stringstream& log_stream) {
int64_t print_size = summarize_ == -1
? print_tensor.numel()
: std::min(summarize_, print_tensor.numel());
const T* data = print_tensor.data<T>(); // Always kHost, so unnessary to
// copy the data from device
log_stream << " - data: [";
if (print_size > 0) {
log_stream << data[0];
for (int64_t i = 1; i < print_size; ++i) {
log_stream << " " << data[i];
}
}
log_stream << "]" << std::endl;
}
int64_t summarize_ = -1;
bool print_tensor_type_ = true;
bool print_tensor_shape_ = true;
bool print_tensor_lod_ = true;
bool print_tensor_layout_ = true;
};
void PrintCompute::Run() {
auto& param = Param<param_t>();
param.out->CopyDataFrom(*param.in);
if ((param.is_forward && param.print_phase == kBackward) ||
(!param.is_forward && param.print_phase == kForward)) {
return;
}
int first_n = param.first_n;
if (first_n > 0 && ++times_ > first_n) return;
TensorFormatter formatter;
const std::string& name = param.print_tensor_name ? param.name : "";
formatter.SetPrintTensorType(param.print_tensor_type);
formatter.SetPrintTensorShape(param.print_tensor_shape);
formatter.SetPrintTensorLod(param.print_tensor_lod);
formatter.SetPrintTensorLayout(param.print_tensor_layout);
formatter.SetSummarize(static_cast<int64_t>(param.summarize));
formatter.Print(*param.in, name, param.message);
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
print, kHost, kAny, kAny, paddle::lite::kernels::host::PrintCompute, def)
.BindInput("In",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class PrintCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
using param_t = operators::PrintParam;
void Run() override;
virtual ~PrintCompute() = default;
private:
mutable int times_{0};
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -12,44 +12,44 @@ ...@@ -12,44 +12,44 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/while_compute.h" #include "lite/kernels/host/while_compute.h"
#include <memory> #include <unordered_map>
#include <string> #include <utility>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace host {
void WhileCompute::PrepareForRun() { void WhileCompute::PrepareForRun() {
auto &param = Param<operators::WhileParam>(); auto &param = this->Param<param_t>();
auto cur_scope = param.scope; program_.reset(new RuntimeProgram(
param.program_desc, param.exec_scope, param.block_idx));
executor_ =
std::make_shared<StepExecutor>(param.sub_block, cur_scope, place());
} }
void WhileCompute::Run() { void WhileCompute::Run() {
auto &param = Param<operators::WhileParam>(); auto &param = this->Param<param_t>();
while (param.cond->data<bool>()[0]) { while (param.cond->data<bool>()[0]) {
executor_->Run(); program_->Run();
} }
} }
} // namespace arm } // namespace host
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
while, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::WhileCompute, def) while, kHost, kAny, kAny, paddle::lite::kernels::host::WhileCompute, def)
.BindInput("X", {LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("X",
{LiteType::GetTensorListTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindInput("Condition", .BindInput("Condition",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) {LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.BindOutput("Out", .BindOutput("Out",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kAny))}) {LiteType::GetTensorListTy(
.BindOutput("StepScopes", {LiteType::GetTensorTy(TARGET(kARM))}) TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindOutput("StepScopes",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.Finalize(); .Finalize();
...@@ -15,56 +15,19 @@ ...@@ -15,56 +15,19 @@
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <string>
#include <vector> #include <vector>
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/operators/while_op.h" #include "lite/core/program.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace host {
class StepExecutor { class WhileCompute
typedef std::shared_ptr<OpLite> OpPtr; : public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
StepExecutor(cpp::BlockDesc *block, Scope *scope, Place place)
: scope_(scope), place_(place) {
int32_t op_size = block->OpsSize();
for (int32_t i = 0; i < op_size; ++i) {
auto &op_desc = *block->template GetOp<cpp::OpDesc>(i);
auto op_type = op_desc.Type();
auto op_handler = lite::LiteOpRegistry::Global().Create(op_desc.Type());
// VLOG(4) << "while: creating Op [" << op_type << "]";
op_handler->Attach(op_desc, scope);
auto hostplace = place_;
hostplace.target = TARGET(kHost);
auto kernels = op_handler->CreateKernels({place_, hostplace});
CHECK_GT(kernels.size(), 0) << "cannot create kernel";
op_handler->AttachKernel(kernels[0].get());
op_handler->SetKernel(kernels);
ops_of_block_.push_back(op_handler);
}
}
void Run() {
for (auto &op_handler : ops_of_block_) {
// VLOG(4) << op_handler->op_info()->Repr();
op_handler->InferShape();
// VLOG(4) << "while: infered shape";
op_handler->Run();
}
}
private:
Scope *scope_;
Place place_;
std::vector<OpPtr> ops_of_block_;
};
class WhileCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public: public:
using param_t = operators::WhileParam; using param_t = operators::WhileParam;
...@@ -74,10 +37,10 @@ class WhileCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { ...@@ -74,10 +37,10 @@ class WhileCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
virtual ~WhileCompute() = default; virtual ~WhileCompute() = default;
private: private:
std::shared_ptr<StepExecutor> executor_; std::unique_ptr<RuntimeProgram> program_;
}; };
} // namespace arm } // namespace host
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -43,13 +43,17 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -43,13 +43,17 @@ class SubgraphEngine : public subgraph::Engine {
public: public:
SubgraphEngine(KernelContext* ctx, SubgraphEngine(KernelContext* ctx,
int block_idx, int block_idx,
cpp::BlockDesc* block_desc, const std::shared_ptr<const cpp::ProgramDesc>& program_desc,
Scope* exec_scope,
const std::vector<std::string>& input_names, const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names, const std::vector<std::string>& output_names,
Scope* scope,
paddle::lite_api::PrecisionType type) paddle::lite_api::PrecisionType type)
: subgraph::Engine( : subgraph::Engine(ctx,
ctx, block_idx, block_desc, input_names, output_names, scope), block_idx,
program_desc,
exec_scope,
input_names,
output_names),
fp_type_(type) { fp_type_(type) {
VLOG(4) << "[MLU] PADDLE_LITE_MLU_SAVE_OFFLINE_MODEL is " VLOG(4) << "[MLU] PADDLE_LITE_MLU_SAVE_OFFLINE_MODEL is "
<< GetBoolFromEnv("PADDLE_LITE_MLU_SAVE_OFFLINE_MODEL"); << GetBoolFromEnv("PADDLE_LITE_MLU_SAVE_OFFLINE_MODEL");
...@@ -103,7 +107,7 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -103,7 +107,7 @@ class SubgraphEngine : public subgraph::Engine {
protected: protected:
bool BuildDeviceProgram() override { bool BuildDeviceProgram() override {
if (origin_program_.empty()) { if (!origin_program_) {
BuildOriginProgram(); BuildOriginProgram();
} }
if (!error_compile_batch_size_changeable_ && if (!error_compile_batch_size_changeable_ &&
...@@ -128,13 +132,15 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -128,13 +132,15 @@ class SubgraphEngine : public subgraph::Engine {
origin_itensors_.clear(); origin_itensors_.clear();
origin_otensors_.clear(); origin_otensors_.clear();
auto data_order = block_desc_->GetOp<cpp::OpDesc>(0)->Type() == "layout" auto* sub_block_desc =
program_desc_->GetBlock()<cpp::BlockDesc>(block_idx_);
auto data_order = sub_block_desc->GetOp<cpp::OpDesc>(0)->Type() == "layout"
? CNML_NCHW ? CNML_NCHW
: CNML_NHWC; : CNML_NHWC;
// Convert all of input data vars and added into the MLU IR graph // Convert all of input data vars and added into the MLU IR graph
status |= subgraph::REBUILD_WHEN_SHAPE_CHANGED; status |= subgraph::REBUILD_WHEN_SHAPE_CHANGED;
for (auto& input_name : input_names_) { for (auto& input_name : input_names_) {
auto input_tensor = scope_->FindMutableTensor(input_name); auto input_tensor = exec_scope_->FindMutableTensor(input_name);
auto data_type = input_tensor->precision(); auto data_type = input_tensor->precision();
cnmlDataType_t fp_type = PrecisionToDatatype(data_type); cnmlDataType_t fp_type = PrecisionToDatatype(data_type);
origin_itensors_.push_back(input_tensor); origin_itensors_.push_back(input_tensor);
...@@ -161,7 +167,8 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -161,7 +167,8 @@ class SubgraphEngine : public subgraph::Engine {
LOG(INFO) << "START TO CONVERT "; LOG(INFO) << "START TO CONVERT ";
// Convert all of ops and its weights and added into the MLU IR graph // Convert all of ops and its weights and added into the MLU IR graph
const auto& bridges = subgraph::Registry::Instance(); const auto& bridges = subgraph::Registry::Instance();
for (auto& inst : origin_program_) { const auto& insts = origin_program_->instructions(kRootBlockIdx);
for (auto& inst : insts) {
auto op = inst.op(); auto op = inst.op();
CHECK(op); CHECK(op);
std::string op_type = op->op_info()->Type(); std::string op_type = op->op_info()->Type();
...@@ -200,7 +207,7 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -200,7 +207,7 @@ class SubgraphEngine : public subgraph::Engine {
for (auto& output_name : output_names_) { for (auto& output_name : output_names_) {
if (graph->HasNode(output_name)) { if (graph->HasNode(output_name)) {
graph->AddOutput(graph->GetNode(output_name)); graph->AddOutput(graph->GetNode(output_name));
auto output_tensor = scope_->FindMutableTensor(output_name); auto output_tensor = exec_scope_->FindMutableTensor(output_name);
origin_otensors_.push_back(output_tensor); origin_otensors_.push_back(output_tensor);
VLOG(4) << "subgraph output tensor " << output_name << std::endl; VLOG(4) << "subgraph output tensor " << output_name << std::endl;
...@@ -257,7 +264,7 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -257,7 +264,7 @@ class SubgraphEngine : public subgraph::Engine {
for (const auto& input_name : input_names_) { for (const auto& input_name : input_names_) {
tmp = input_name; tmp = input_name;
name += TrimStrings(tmp) + delimiter + input_shape_str; name += TrimStrings(tmp) + delimiter + input_shape_str;
auto input_tensor = scope_->FindMutableTensor(input_name); auto input_tensor = exec_scope_->FindMutableTensor(input_name);
for (const auto& iterm : input_tensor->dims().Vectorize()) { for (const auto& iterm : input_tensor->dims().Vectorize()) {
name += std::to_string(iterm) + delimiter_num; name += std::to_string(iterm) + delimiter_num;
} }
...@@ -266,7 +273,7 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -266,7 +273,7 @@ class SubgraphEngine : public subgraph::Engine {
for (const auto& output_name : output_names_) { for (const auto& output_name : output_names_) {
tmp = output_name; tmp = output_name;
name += TrimStrings(tmp) + delimiter + output_shape_str; name += TrimStrings(tmp) + delimiter + output_shape_str;
auto output_tensor = scope_->FindMutableTensor(output_name); auto output_tensor = exec_scope_->FindMutableTensor(output_name);
for (const auto& iterm : output_tensor->dims().Vectorize()) { for (const auto& iterm : output_tensor->dims().Vectorize()) {
name += std::to_string(iterm) + delimiter_num; name += std::to_string(iterm) + delimiter_num;
} }
...@@ -284,7 +291,8 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -284,7 +291,8 @@ class SubgraphEngine : public subgraph::Engine {
origin_otensors_[i]->Resize(iter->second[i]); origin_otensors_[i]->Resize(iter->second[i]);
} }
} else { } else {
for (auto& inst : origin_program_) { const auto& insts = origin_program_->instructions(kRootBlockIdx);
for (auto& inst : insts) {
auto op = inst.op(); auto op = inst.op();
CHECK(op); CHECK(op);
op->CheckShape(); op->CheckShape();
...@@ -475,11 +483,11 @@ class SubgraphCompute ...@@ -475,11 +483,11 @@ class SubgraphCompute
auto& param = this->template Param<param_t>(); auto& param = this->template Param<param_t>();
// LOG(INFO) << "SUBGRAP Prepare RUN index " << param.sub_block_idx; // LOG(INFO) << "SUBGRAP Prepare RUN index " << param.sub_block_idx;
engine_.reset(new SubgraphEngine<Precision>(this->ctx_.get(), engine_.reset(new SubgraphEngine<Precision>(this->ctx_.get(),
param.sub_block_idx, param.block_idx,
param.sub_block_desc, param.program_desc,
param.exec_scope,
param.input_data_names, param.input_data_names,
param.output_data_names, param.output_data_names,
param.scope,
this->precision())); this->precision()));
CHECK(engine_); CHECK(engine_);
} }
......
...@@ -25,11 +25,14 @@ namespace subgraph { ...@@ -25,11 +25,14 @@ namespace subgraph {
Engine::Engine(KernelContext *ctx, Engine::Engine(KernelContext *ctx,
int block_idx, int block_idx,
cpp::BlockDesc *block_desc, const std::shared_ptr<const cpp::ProgramDesc> &program_desc,
Scope *exec_scope,
const std::vector<std::string> &input_names, const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names, const std::vector<std::string> &output_names)
lite::Scope *scope) : ctx_(ctx),
: ctx_(ctx), block_idx_(block_idx), block_desc_(block_desc), scope_(scope) { block_idx_(block_idx),
program_desc_(program_desc),
exec_scope_(exec_scope) {
input_names_ = input_names; input_names_ = input_names;
output_names_ = output_names; output_names_ = output_names;
// Sort the name of input and output tensors, it's convenient for us to get // Sort the name of input and output tensors, it's convenient for us to get
...@@ -55,12 +58,12 @@ bool Engine::PrepareWorkspaceForOriginProgram() { ...@@ -55,12 +58,12 @@ bool Engine::PrepareWorkspaceForOriginProgram() {
origin_idims_.resize(input_names_.size()); origin_idims_.resize(input_names_.size());
origin_itensors_.resize(input_names_.size()); origin_itensors_.resize(input_names_.size());
for (int i = 0; i < input_names_.size(); i++) { for (int i = 0; i < input_names_.size(); i++) {
origin_itensors_[i] = scope_->FindMutableTensor(input_names_[i]); origin_itensors_[i] = exec_scope_->FindMutableTensor(input_names_[i]);
CHECK(origin_itensors_[i]); CHECK(origin_itensors_[i]);
} }
origin_otensors_.resize(output_names_.size()); origin_otensors_.resize(output_names_.size());
for (int i = 0; i < output_names_.size(); i++) { for (int i = 0; i < output_names_.size(); i++) {
origin_otensors_[i] = scope_->FindMutableTensor(output_names_[i]); origin_otensors_[i] = exec_scope_->FindMutableTensor(output_names_[i]);
CHECK(origin_otensors_[i]); CHECK(origin_otensors_[i]);
} }
return true; return true;
...@@ -69,70 +72,20 @@ bool Engine::PrepareWorkspaceForOriginProgram() { ...@@ -69,70 +72,20 @@ bool Engine::PrepareWorkspaceForOriginProgram() {
bool Engine::BuildOriginProgram() { bool Engine::BuildOriginProgram() {
// TODO(hong19860320) The block_desc need to be divided into subgraphs during // TODO(hong19860320) The block_desc need to be divided into subgraphs during
// the exection time. But only see them as a subgraph now. // the exection time. But only see them as a subgraph now.
origin_program_.clear(); if (!origin_program_) {
for (size_t op_idx = 0; op_idx < block_desc_->OpsSize(); op_idx++) { origin_program_.reset(
auto op_desc = block_desc_->GetOp<cpp::OpDesc>(op_idx); new RuntimeProgram(program_desc_, exec_scope_, block_idx_));
CHECK(op_desc);
std::string op_type = op_desc->Type();
// Create op and pick up the best kernel
auto op = LiteOpRegistry::Global().Create(op_desc->Type());
CHECK(op) << "no Op found for " << op_type;
op->Attach(*op_desc, scope_);
std::unique_ptr<KernelBase> picked_kernel;
if (op_desc->HasAttr(kKernelTypeAttr)) {
// Create op and pick up the best kernel according to the
// kKernelTypeAttr attribute
auto kernel_type = op_desc->GetAttr<std::string>(kKernelTypeAttr);
std::string alias;
Place place;
KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place);
VLOG(3) << "Found the attr '" << kKernelTypeAttr << "': " << kernel_type
<< " for " << op_type;
auto kernels = op->CreateKernels({place});
CHECK_GT(kernels.size(), 0u) << "No kernels found for " << op_type;
auto it = std::find_if(
kernels.begin(), kernels.end(), [&](std::unique_ptr<KernelBase> &it) {
return it->alias() == alias;
});
CHECK(it != kernels.end());
picked_kernel = std::move(*it);
} else {
// TODO(hong19860320) add kernel picking according to the type of input
// and output tensors
VLOG(3) << "The attr '" << kKernelTypeAttr
<< "' not found, pick the first kernel for " << op_type;
std::vector<std::unique_ptr<KernelBase>> kernels;
#if defined(LITE_WITH_ARM)
kernels = op->CreateKernels({Place{TARGET(kARM)}, Place{TARGET(kHost)}});
#elif defined(LITE_WITH_X86)
kernels = op->CreateKernels({Place{TARGET(kX86)}, Place{TARGET(kHost)}});
#endif
if (kernels.size() > 0) {
picked_kernel = std::move(kernels.front());
} else {
LOG(WARNING) << "No kernels found for " << op_type;
}
}
if (picked_kernel != nullptr) {
picked_kernel->SetContext(
ContextScheduler::Global().NewContext(picked_kernel->target()));
}
origin_program_.emplace_back(std::move(op), std::move(picked_kernel));
} }
CHECK(!origin_program_.empty()) << "no instructions";
return true; return true;
} }
bool Engine::LaunchOriginProgram() { bool Engine::LaunchOriginProgram() {
if (origin_program_.empty()) { if (!origin_program_) {
BuildOriginProgram(); BuildOriginProgram();
} }
if (!origin_program_.empty()) { if (origin_program_) {
for (auto &inst : origin_program_) { VLOG(3) << "Roll back to run the origin program.";
auto op_type = inst.op()->op_info()->Type(); origin_program_->Run();
if (op_type == "feed" || op_type == "fetch") continue;
inst.Run();
}
return true; return true;
} }
return false; return false;
......
...@@ -30,10 +30,10 @@ class Engine { ...@@ -30,10 +30,10 @@ class Engine {
public: public:
Engine(KernelContext *ctx, Engine(KernelContext *ctx,
int block_idx, int block_idx,
cpp::BlockDesc *block_desc, const std::shared_ptr<const cpp::ProgramDesc> &program_desc,
Scope *exec_scope,
const std::vector<std::string> &input_names, const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names, const std::vector<std::string> &output_names);
lite::Scope *scope);
virtual ~Engine() = default; virtual ~Engine() = default;
virtual bool Run(); virtual bool Run();
...@@ -54,15 +54,15 @@ class Engine { ...@@ -54,15 +54,15 @@ class Engine {
KernelContext *ctx_{nullptr}; KernelContext *ctx_{nullptr};
int block_idx_{-1}; int block_idx_{-1};
cpp::BlockDesc *block_desc_{nullptr}; const std::shared_ptr<const cpp::ProgramDesc> program_desc_{nullptr};
std::vector<std::string> input_names_; std::vector<std::string> input_names_;
std::vector<std::string> output_names_; std::vector<std::string> output_names_;
Scope *scope_{nullptr}; Scope *exec_scope_{nullptr};
bool is_first_epoch_{true}; bool is_first_epoch_{true};
std::vector<std::vector<int64_t>> origin_idims_; std::vector<std::vector<int64_t>> origin_idims_;
std::vector<Tensor *> origin_itensors_; std::vector<Tensor *> origin_itensors_;
std::vector<Tensor *> origin_otensors_; std::vector<Tensor *> origin_otensors_;
std::vector<Instruction> origin_program_; std::unique_ptr<RuntimeProgram> origin_program_{nullptr};
}; };
} // namespace subgraph } // namespace subgraph
......
...@@ -55,7 +55,8 @@ std::string DeviceProgram::GenerateModelName( ...@@ -55,7 +55,8 @@ std::string DeviceProgram::GenerateModelName(
} }
// Deserialize the generated model, the precisions and dimensions of the origin // Deserialize the generated model, the precisions and dimensions of the origin
// output tensors of the subgraph op into files // output tensors of the subgraph op from the cached configuration file and HiAI
// om file
bool DeviceProgram::LoadFromCacheFile( bool DeviceProgram::LoadFromCacheFile(
const std::vector<std::string>& input_names, const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names, const std::vector<std::string>& output_names,
...@@ -71,7 +72,7 @@ bool DeviceProgram::LoadFromCacheFile( ...@@ -71,7 +72,7 @@ bool DeviceProgram::LoadFromCacheFile(
VLOG(3) << "[NPU] Load model from " << model_path; VLOG(3) << "[NPU] Load model from " << model_path;
std::vector<char> model_buffer; std::vector<char> model_buffer;
if (!ReadFile(model_path, &model_buffer)) { if (!ReadFile(model_path, &model_buffer)) {
LOG(WARNING) << "[NPU] read from " << model_path << " failed!"; LOG(WARNING) << "[NPU] Open " << model_path << " for reading failed!";
return false; return false;
} }
bool model_comp = false; bool model_comp = false;
...@@ -98,9 +99,9 @@ bool DeviceProgram::LoadFromCacheFile( ...@@ -98,9 +99,9 @@ bool DeviceProgram::LoadFromCacheFile(
LOG(WARNING) << "[NPU] read from " << config_path << " failed!"; LOG(WARNING) << "[NPU] read from " << config_path << " failed!";
return false; return false;
} }
std::string config_str(config_buffer.begin(), config_buffer.end()); std::string str(config_buffer.begin(), config_buffer.end());
// Parse the precision and shapes of the output tensors // Parse the precision and shapes of the output tensors
auto output_options = Split<std::string>(config_str, ";"); auto output_options = Split<std::string>(str, ";");
CHECK_EQ(output_options.size(), output_names.size()); CHECK_EQ(output_options.size(), output_names.size());
origin_otypes_.resize(output_names.size()); origin_otypes_.resize(output_names.size());
origin_odims_.resize(output_names.size()); origin_odims_.resize(output_names.size());
...@@ -114,7 +115,7 @@ bool DeviceProgram::LoadFromCacheFile( ...@@ -114,7 +115,7 @@ bool DeviceProgram::LoadFromCacheFile(
} }
bool DeviceProgram::BuildGraphAndCacheToFile( bool DeviceProgram::BuildGraphAndCacheToFile(
const std::vector<Instruction>& origin_program, RuntimeProgram* origin_program,
const std::vector<std::string>& input_names, const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names, const std::vector<std::string>& output_names,
const std::vector<std::vector<int64_t>>& origin_idims, const std::vector<std::vector<int64_t>>& origin_idims,
...@@ -127,10 +128,13 @@ bool DeviceProgram::BuildGraphAndCacheToFile( ...@@ -127,10 +128,13 @@ bool DeviceProgram::BuildGraphAndCacheToFile(
// Convert all of ops and their input vars and weights to HiAI IR nodes, // Convert all of ops and their input vars and weights to HiAI IR nodes,
// then added them into the HiAI IR graph // then added them into the HiAI IR graph
int status = 0; int status = 0;
CHECK(!origin_program.empty()) << "no instructions";
subgraph::npu::Graph graph; subgraph::npu::Graph graph;
const auto& bridges = subgraph::Registry::Instance(); const auto& bridges = subgraph::Registry::Instance();
for (auto& inst : origin_program) { CHECK(origin_program) << "[NPU] The origin program is not initialized!";
CHECK_GT(origin_program->instructions(kRootBlockIdx).size(), 0)
<< "[NPU] No instructions found in the origin program!";
const auto& insts = origin_program->instructions(kRootBlockIdx);
for (auto& inst : insts) {
auto op = const_cast<OpLite*>(inst.op()); auto op = const_cast<OpLite*>(inst.op());
CHECK(op); CHECK(op);
op->CheckShape(); op->CheckShape();
...@@ -149,7 +153,8 @@ bool DeviceProgram::BuildGraphAndCacheToFile( ...@@ -149,7 +153,8 @@ bool DeviceProgram::BuildGraphAndCacheToFile(
// Collect the input and output nodes of the HiAI IR graph // Collect the input and output nodes of the HiAI IR graph
std::vector<ge::Operator> device_inodes; std::vector<ge::Operator> device_inodes;
for (size_t i = 0; i < input_names.size(); i++) { for (size_t i = 0; i < input_names.size(); i++) {
CHECK(graph.Has(input_names[i]) && graph.Get(input_names[i])->is_data()); CHECK(graph.Has(input_names[i]));
CHECK(graph.Get(input_names[i])->is_data());
device_inodes.push_back(*graph.Get(input_names[i])->data()); device_inodes.push_back(*graph.Get(input_names[i])->data());
} }
std::vector<ge::Operator> device_onodes; std::vector<ge::Operator> device_onodes;
...@@ -173,6 +178,9 @@ bool DeviceProgram::BuildGraphAndCacheToFile( ...@@ -173,6 +178,9 @@ bool DeviceProgram::BuildGraphAndCacheToFile(
LOG(WARNING) << "[NPU] Load model failed!"; LOG(WARNING) << "[NPU] Load model failed!";
return false; return false;
} }
// Do not check model compatibility because it assume that the cached om model
// is always compatible with the current device
// Update the precison and dimensions of the origin output tensors
// Update the precison and dimensions of the origin output tensors // Update the precison and dimensions of the origin output tensors
CHECK_EQ(origin_otensors.size(), output_names.size()); CHECK_EQ(origin_otensors.size(), output_names.size());
origin_otypes_.resize(output_names.size()); origin_otypes_.resize(output_names.size());
...@@ -247,7 +255,7 @@ bool DeviceProgram::ShareBufferWithOriginTensors( ...@@ -247,7 +255,7 @@ bool DeviceProgram::ShareBufferWithOriginTensors(
device_idims_[i].GetHeight() * device_idims_[i].GetWidth()); device_idims_[i].GetHeight() * device_idims_[i].GetWidth());
VLOG(3) << "[NPU] Init the input tensors for the device program and share " VLOG(3) << "[NPU] Init the input tensors for the device program and share "
"their buffers with the origin input tensors"; "their buffers with the origin input tensors";
// reinit device tensor will free shared buffer, so copy data to a tmp // Reinit device tensor will free shared buffer, so copy data to a tmp
// tensor // tensor
Tensor tmp; Tensor tmp;
tmp.CopyDataFrom(*(*origin_itensors)[i]); tmp.CopyDataFrom(*(*origin_itensors)[i]);
...@@ -337,8 +345,9 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -337,8 +345,9 @@ bool SubgraphEngine::BuildDeviceProgram() {
if (!device_programs_.count(origin_idims_)) { if (!device_programs_.count(origin_idims_)) {
auto device_program = std::make_shared<DeviceProgram>(); auto device_program = std::make_shared<DeviceProgram>();
// Obtain the model cache dir from the NPU Context of the subgraph op // Obtain the model cache dir from the NPU Context of the subgraph op
auto model_cache_dir = ctx_->As<NPUContext>().SubgraphModelCacheDir(); auto model_cache_dir =
VLOG(3) << "[NPU] Getting subgraph model_cache_dir is: " << model_cache_dir; ctx_->As<NPUContext>().SubgraphModelCacheDir(exec_scope_);
VLOG(3) << "[NPU] Getting subgraph_model_cache_dir: " << model_cache_dir;
// Check and load if the cached model and configuration file exists // Check and load if the cached model and configuration file exists
if (model_cache_dir.empty() || if (model_cache_dir.empty() ||
!device_program->LoadFromCacheFile( !device_program->LoadFromCacheFile(
...@@ -346,11 +355,13 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -346,11 +355,13 @@ bool SubgraphEngine::BuildDeviceProgram() {
// Build the model online, including converting the paddle ops to the HiAI // Build the model online, including converting the paddle ops to the HiAI
// IR nodes, building the HiAI IR graph to the om model, then load it as a // IR nodes, building the HiAI IR graph to the om model, then load it as a
// new HiAI model manager client for inference. // new HiAI model manager client for inference.
if (origin_program_.empty()) { if (!origin_program_) {
BuildOriginProgram(); BuildOriginProgram();
} }
CHECK(!origin_program_.empty()) << "no instructions"; CHECK(origin_program_) << "[NPU] The origin program is not initialized!";
if (!device_program->BuildGraphAndCacheToFile(origin_program_, CHECK_GT(origin_program_->instructions().size(), 0)
<< "[NPU] No instructions found in the origin program!";
if (!device_program->BuildGraphAndCacheToFile(origin_program_.get(),
input_names_, input_names_,
output_names_, output_names_,
origin_idims_, origin_idims_,
...@@ -391,11 +402,11 @@ bool SubgraphEngine::LaunchDeviceProgram() { ...@@ -391,11 +402,11 @@ bool SubgraphEngine::LaunchDeviceProgram() {
void SubgraphCompute::PrepareForRun() { void SubgraphCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
engine_.reset(new SubgraphEngine(ctx_.get(), engine_.reset(new SubgraphEngine(ctx_.get(),
param.sub_block_idx, param.block_idx,
param.sub_block_desc, param.program_desc,
param.exec_scope,
param.input_data_names, param.input_data_names,
param.output_data_names, param.output_data_names));
param.scope));
CHECK(engine_); CHECK(engine_);
} }
......
...@@ -41,7 +41,7 @@ class DeviceProgram { ...@@ -41,7 +41,7 @@ class DeviceProgram {
const std::vector<std::vector<int64_t>>& origin_idims, const std::vector<std::vector<int64_t>>& origin_idims,
const std::string& model_cache_dir); const std::string& model_cache_dir);
bool BuildGraphAndCacheToFile( bool BuildGraphAndCacheToFile(
const std::vector<Instruction>& origin_program, RuntimeProgram* origin_program,
const std::vector<std::string>& input_names, const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names, const std::vector<std::string>& output_names,
const std::vector<std::vector<int64_t>>& origin_idims, const std::vector<std::vector<int64_t>>& origin_idims,
...@@ -71,12 +71,16 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -71,12 +71,16 @@ class SubgraphEngine : public subgraph::Engine {
public: public:
SubgraphEngine(KernelContext* ctx, SubgraphEngine(KernelContext* ctx,
int block_idx, int block_idx,
cpp::BlockDesc* block_desc, const std::shared_ptr<const cpp::ProgramDesc>& program_desc,
Scope* exec_scope,
const std::vector<std::string>& input_names, const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names, const std::vector<std::string>& output_names)
Scope* scope) : subgraph::Engine(ctx,
: subgraph::Engine( block_idx,
ctx, block_idx, block_desc, input_names, output_names, scope) {} program_desc,
exec_scope,
input_names,
output_names) {}
protected: protected:
bool PrepareWorkspaceForDeviceProgram() override; bool PrepareWorkspaceForDeviceProgram() override;
......
...@@ -152,7 +152,7 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -152,7 +152,7 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL),
cl::NDRange local_work_size_ = cl::NDRange{ cl::NDRange local_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)}; static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
bool use_lws_{true}; bool use_lws_{true};
bool use_tune_{false}; bool use_tune_{true};
}; };
} // namespace opencl } // namespace opencl
......
...@@ -155,6 +155,7 @@ TEST(nearest_interp_image2d, compute) { ...@@ -155,6 +155,7 @@ TEST(nearest_interp_image2d, compute) {
auto *x_data = x.mutable_data<float, cl::Buffer>(TARGET(kOpenCL)); auto *x_data = x.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data = y.mutable_data<float, cl::Buffer>(TARGET(kOpenCL)); auto *y_data = y.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data_ref = y_ref.mutable_data<float>(TARGET(kARM)); auto *y_data_ref = y_ref.mutable_data<float>(TARGET(kARM));
memset(reinterpret_cast<char *>(y_data_ref), 0, y_ref.numel());
auto *mapped_x = static_cast<float *>(TargetWrapperCL::Map( auto *mapped_x = static_cast<float *>(TargetWrapperCL::Map(
x_data, 0, sizeof(float) * x_dim.production())); x_data, 0, sizeof(float) * x_dim.production()));
auto *mapped_y = static_cast<float *>(TargetWrapperCL::Map( auto *mapped_y = static_cast<float *>(TargetWrapperCL::Map(
......
...@@ -28,26 +28,6 @@ namespace lite { ...@@ -28,26 +28,6 @@ namespace lite {
namespace kernels { namespace kernels {
namespace rknpu { namespace rknpu {
bool SubgraphEngine::PrepareWorkspaceForDeviceProgram() {
// Obtain the origin input tensors, and create the origin output
// tensors(Don't try to access them before launch the device program or the
// origin program)
PrepareWorkspaceForOriginProgram();
// Create the device input and output tensors, but don't initialize them
// with the dimensions
device_itensors_.resize(input_names_.size());
for (int i = 0; i < input_names_.size(); i++) {
device_itensors_[i].reset(new hiai::AiTensor);
CHECK(device_itensors_[i]);
}
device_otensors_.resize(output_names_.size());
for (int i = 0; i < output_names_.size(); i++) {
device_otensors_[i].reset(new hiai::AiTensor);
CHECK(device_otensors_[i]);
}
return true;
}
bool SubgraphEngine::BuildDeviceProgram() { bool SubgraphEngine::BuildDeviceProgram() {
LOG(INFO) << "[RKNPU]:BuildDeviceProgram"; LOG(INFO) << "[RKNPU]:BuildDeviceProgram";
int status = 0; int status = 0;
...@@ -55,10 +35,11 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -55,10 +35,11 @@ bool SubgraphEngine::BuildDeviceProgram() {
// RKNPU IR graph // RKNPU IR graph
subgraph::rknpu::Graph graph; subgraph::rknpu::Graph graph;
const auto& bridges = subgraph::Registry::Instance(); const auto& bridges = subgraph::Registry::Instance();
if (origin_program_.empty()) { if (!origin_program_) {
BuildOriginProgram(); BuildOriginProgram();
} }
for (auto& inst : origin_program_) { const auto& insts = origin_program_->instructions(kRootBlockIdx);
for (auto& inst : insts) {
auto op = const_cast<OpLite*>(inst.op()); auto op = const_cast<OpLite*>(inst.op());
CHECK(op); CHECK(op);
op->CheckShape(); op->CheckShape();
...@@ -76,92 +57,26 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -76,92 +57,26 @@ bool SubgraphEngine::BuildDeviceProgram() {
} }
// Collect the valid input and output nodes in the RKNPU IR graph and update // Collect the valid input and output nodes in the RKNPU IR graph and update
// the input and output names // the input and output names
device_inames_.clear(); device_itensors_.clear();
device_onames_.clear(); device_otensors_.clear();
for (auto& input_name : input_names_) {
LOG(INFO) << "[RKNPU] Input node " << input_name;
if (graph.Has(input_name)) {
LOG(INFO) << input_name << " Precision "
<< PrecisionToStr(graph.Get(input_name)->precision());
device_itensors_.push_back(graph.Get(input_name)->data());
device_inames_.push_back(input_name);
} else {
LOG(WARNING) << "[RKNPU] Input node " << input_name
<< " is ignored because it does not exist.";
}
}
for (auto& output_name : output_names_) {
LOG(INFO) << "[RKNPU] Output node " << output_name;
if (graph.Has(output_name)) {
auto tensor = scope_->FindMutableTensor(output_name);
LOG(INFO) << output_name << " Precision "
<< PrecisionToStr(tensor->precision());
device_otensors_.push_back(graph.Get(output_name)->data());
device_onames_.push_back(output_name);
} else {
LOG(WARNING) << "[RKNPU] Output node " << output_name
<< " is ignored because it does not exist.";
}
}
CHECK(!device_inames_.empty())
<< "[RKNPU] No input nodes found for building NPU model";
CHECK(!device_onames_.empty())
<< "[RKNPU] No output nodes found for building NPU model";
device_program_ = lite::rknpu::Device::Global().Build(
model_name_, graph.GetHandle(), device_itensors_, device_otensors_);
if (device_program_ == nullptr) {
LOG(WARNING) << "[RKNPU] Build model failed!";
return false;
}
// input
origin_idims_.resize(input_names_.size());
origin_itensors_.resize(input_names_.size());
for (size_t i = 0; i < input_names_.size(); i++) { for (size_t i = 0; i < input_names_.size(); i++) {
origin_itensors_[i] = scope_->FindMutableTensor(input_names_[i]); CHECK(graph.Has(input_names_[i])) << "[RKNPU] Failed to find input node "
CHECK(origin_itensors_[i]); << input_names_[i];
origin_idims_[i] = origin_itensors_[i]->dims(); auto node = graph.Get(input_names_[i]);
}
// output
origin_odims_.resize(output_names_.size());
origin_otensors_.resize(output_names_.size());
for (size_t i = 0; i < output_names_.size(); i++) {
origin_otensors_[i] = scope_->FindMutableTensor(output_names_[i]);
CHECK(origin_otensors_[i]);
origin_odims_[i] = origin_otensors_[i]->dims();
auto output_dims = origin_otensors_[i]->dims();
}
origin_idims_.resize(device_inames_.size());
origin_itensors_.resize(device_inames_.size());
device_itensors_.resize(device_inames_.size());
origin_odims_.resize(device_onames_.size());
origin_otensors_.resize(device_onames_.size());
device_otensors_.resize(device_onames_.size());
for (int i = 0; i < device_inames_.size(); i++) {
auto node = graph.Get(device_inames_[i]);
auto precision = node->precision(); auto precision = node->precision();
auto layout = node->layout(); auto layout = node->layout();
origin_itensors_[i] = scope_->FindMutableTensor(device_inames_[i]); LOG(INFO) << "[RKNPU] Inputs[" << i << "] name: " << input_names_[i]
CHECK(origin_itensors_[i]);
origin_idims_[i] = origin_itensors_[i]->dims();
LOG(INFO) << "[RKNPU] Inputs[" << i << "] name: " << device_inames_[i]
<< " precision: " << PrecisionToStr(precision) << " precision: " << PrecisionToStr(precision)
<< " layout: " << DataLayoutToStr(layout); << " layout: " << DataLayoutToStr(layout);
device_itensors_.push_back(node->data());
} }
for (int i = 0; i < device_onames_.size(); i++) { for (size_t i = 0; i < output_names_.size(); i++) {
auto node = graph.Get(device_onames_[i]); CHECK(graph.Has(output_names_[i])) << "[RKNPU] Failed to find output node "
<< output_names_[i];
auto node = graph.Get(output_names_[i]);
auto precision = node->precision(); auto precision = node->precision();
auto layout = node->layout(); auto layout = node->layout();
origin_otensors_[i] = scope_->FindMutableTensor(device_onames_[i]); LOG(INFO) << "[RKNPU] Outputs[" << i << "] name: " << output_names_[i]
CHECK(origin_otensors_[i]);
origin_odims_[i] = origin_otensors_[i]->dims();
LOG(INFO) << "[RKNPU] Outputs[" << i << "] name: " << device_onames_[i]
<< " precision: " << PrecisionToStr(precision) << " precision: " << PrecisionToStr(precision)
<< " layout: " << DataLayoutToStr(layout); << " layout: " << DataLayoutToStr(layout);
// Prepare the device output tensors // Prepare the device output tensors
...@@ -182,11 +97,19 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -182,11 +97,19 @@ bool SubgraphEngine::BuildDeviceProgram() {
origin_otensors_[i]->mutable_data<int64_t>(); origin_otensors_[i]->mutable_data<int64_t>();
break; break;
default: default:
LOG(FATAL) << "[RKNPU] " << device_onames_[i] LOG(FATAL) << "[RKNPU] " << output_names_[i]
<< " can't mutable data with precision type " << " can't mutable data with precision type "
<< PrecisionToStr(precision); << PrecisionToStr(precision);
break; break;
} }
device_otensors_.push_back(node->data());
}
// Create the RKNPU model and set the input and output nodes
device_program_ = lite::rknpu::Device::Global().Build(
model_name_, graph.GetHandle(), device_itensors_, device_otensors_);
if (device_program_ == nullptr) {
LOG(WARNING) << "[RKNPU] Build model failed!";
return false;
} }
return true; return true;
} }
...@@ -196,8 +119,8 @@ bool SubgraphEngine::LaunchDeviceProgram() { ...@@ -196,8 +119,8 @@ bool SubgraphEngine::LaunchDeviceProgram() {
std::vector<rk::nn::InputInfo> inputs; std::vector<rk::nn::InputInfo> inputs;
std::vector<rk::nn::OutputInfo> outputs; std::vector<rk::nn::OutputInfo> outputs;
inputs.resize(device_itensors_.size()); inputs.resize(origin_itensors_.size());
for (size_t i = 0; i < device_itensors_.size(); i++) { for (size_t i = 0; i < origin_itensors_.size(); i++) {
inputs[i].index = i; inputs[i].index = i;
inputs[i].buf = const_cast<void*>(origin_itensors_[i]->raw_data()); inputs[i].buf = const_cast<void*>(origin_itensors_[i]->raw_data());
inputs[i].size = origin_itensors_[i]->memory_size(); inputs[i].size = origin_itensors_[i]->memory_size();
...@@ -207,8 +130,8 @@ bool SubgraphEngine::LaunchDeviceProgram() { ...@@ -207,8 +130,8 @@ bool SubgraphEngine::LaunchDeviceProgram() {
inputs[i].layout = rk::nn::DataLayoutType::NCHW; inputs[i].layout = rk::nn::DataLayoutType::NCHW;
} }
outputs.resize(device_otensors_.size()); outputs.resize(origin_otensors_.size());
for (size_t i = 0; i < device_otensors_.size(); i++) { for (size_t i = 0; i < origin_otensors_.size(); i++) {
outputs[i].index = i; outputs[i].index = i;
outputs[i].buf = const_cast<void*>(origin_otensors_[i]->raw_data()); outputs[i].buf = const_cast<void*>(origin_otensors_[i]->raw_data());
outputs[i].size = origin_otensors_[i]->memory_size(); outputs[i].size = origin_otensors_[i]->memory_size();
...@@ -225,11 +148,11 @@ void SubgraphCompute::PrepareForRun() { ...@@ -225,11 +148,11 @@ void SubgraphCompute::PrepareForRun() {
LOG(INFO) << "[RKNPU]:PrepareForRun"; LOG(INFO) << "[RKNPU]:PrepareForRun";
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
engine_.reset(new SubgraphEngine(ctx_.get(), engine_.reset(new SubgraphEngine(ctx_.get(),
param.sub_block_idx, param.block_idx,
param.sub_block_desc, param.program_desc,
param.exec_scope,
param.input_data_names, param.input_data_names,
param.output_data_names, param.output_data_names));
param.scope));
CHECK(engine_); CHECK(engine_);
} }
......
...@@ -34,15 +34,18 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -34,15 +34,18 @@ class SubgraphEngine : public subgraph::Engine {
public: public:
SubgraphEngine(KernelContext *ctx, SubgraphEngine(KernelContext *ctx,
int block_idx, int block_idx,
cpp::BlockDesc *block_desc, const std::shared_ptr<const cpp::ProgramDesc> &program_desc,
Scope *exec_scope,
const std::vector<std::string> &input_names, const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names, const std::vector<std::string> &output_names)
Scope *scope) : subgraph::Engine(ctx,
: subgraph::Engine( block_idx,
ctx, block_idx, block_desc, input_names, output_names, scope) {} program_desc,
exec_scope,
input_names,
output_names) {}
protected: protected:
bool PrepareWorkspaceForDeviceProgram() override;
bool BuildDeviceProgram() override; bool BuildDeviceProgram() override;
bool LaunchDeviceProgram() override; bool LaunchDeviceProgram() override;
......
...@@ -31,11 +31,14 @@ void XPUEmbeddingWithEltwiseAddCompute::PrepareForRun() { ...@@ -31,11 +31,14 @@ void XPUEmbeddingWithEltwiseAddCompute::PrepareForRun() {
CHECK_EQ(table_dims.size(), 2); /* shape like [table_len, embed_dim] */ CHECK_EQ(table_dims.size(), 2); /* shape like [table_len, embed_dim] */
table_lens_cpu_.push_back(table_dims[0]); table_lens_cpu_.push_back(table_dims[0]);
} }
void* lens_ptr = nullptr;
size_t lens_size = table_lens_cpu_.size() * sizeof(int); size_t lens_size = table_lens_cpu_.size() * sizeof(int);
xpu_malloc(&lens_ptr, lens_size); table_lens_guard_ =
xpu_memcpy(lens_ptr, &table_lens_cpu_[0], lens_size, XPU_HOST_TO_DEVICE); TargetWrapperXPU::MallocScratchPad(lens_size, false /* use_l3 */);
table_lens_guard_.reset(lens_ptr); XPU_CALL(xpu_memcpy(table_lens_guard_->addr_,
&table_lens_cpu_[0],
lens_size,
XPU_HOST_TO_DEVICE));
} }
void XPUEmbeddingWithEltwiseAddCompute::Run() { void XPUEmbeddingWithEltwiseAddCompute::Run() {
...@@ -55,16 +58,16 @@ void XPUEmbeddingWithEltwiseAddCompute::Run() { ...@@ -55,16 +58,16 @@ void XPUEmbeddingWithEltwiseAddCompute::Run() {
int embed_dim = table_dims[1]; int embed_dim = table_dims[1];
int emb_layer_num = param.Ids.size(); int emb_layer_num = param.Ids.size();
int r = xdnn::embedding_with_ewadd<float, int64_t, false, false>( int r = xdnn::embedding_with_ewadd<float, int64_t, false, false>(
ctx.GetRawContext(), /* context */ ctx.GetRawContext(), /* context */
embed_dim, /* embed_dim */ embed_dim, /* embed_dim */
idx_len, /* idx_len */ idx_len, /* idx_len */
emb_layer_num, /* emb_layer_num */ emb_layer_num, /* emb_layer_num */
param.padding_idx, /* padding_idx */ param.padding_idx, /* padding_idx */
&arg_tables_[0], /* tables */ &arg_tables_[0], /* tables */
&arg_ids_[0], /* indices */ &arg_ids_[0], /* indices */
static_cast<int*>(table_lens_guard_.get()), /* table_lens */ static_cast<int*>(table_lens_guard_->addr_), /* table_lens */
nullptr, /* scale_after_emb */ nullptr, /* scale_after_emb */
nullptr, /* scale_after_ewadd */ nullptr, /* scale_after_ewadd */
param.Out->mutable_data<float>(TARGET(kXPU)) /* top */); param.Out->mutable_data<float>(TARGET(kXPU)) /* top */);
CHECK_EQ(r, 0); CHECK_EQ(r, 0);
} }
......
...@@ -14,10 +14,9 @@ ...@@ -14,10 +14,9 @@
#pragma once #pragma once
#include <memory>
#include <vector> #include <vector>
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
#include "lite/kernels/xpu/utils.h" // XPUFreeDeleter
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -36,7 +35,7 @@ class XPUEmbeddingWithEltwiseAddCompute ...@@ -36,7 +35,7 @@ class XPUEmbeddingWithEltwiseAddCompute
private: private:
std::vector<const int64_t*> arg_ids_; std::vector<const int64_t*> arg_ids_;
std::vector<const float*> arg_tables_; std::vector<const float*> arg_tables_;
std::unique_ptr<void, XPUFreeDeleter> table_lens_guard_; XPUScratchPadGuard table_lens_guard_;
std::vector<int> table_lens_cpu_; std::vector<int> table_lens_cpu_;
}; };
......
...@@ -27,8 +27,8 @@ namespace { ...@@ -27,8 +27,8 @@ namespace {
void FillMax(float max, float* xpu_ptr) { void FillMax(float max, float* xpu_ptr) {
float maxs[4] = {max, 0.0f, 0.0f, 0.0f}; float maxs[4] = {max, 0.0f, 0.0f, 0.0f};
xpu_memcpy( XPU_CALL(xpu_memcpy(
xpu_ptr, maxs, 4 * sizeof(float), XPUMemcpyKind::XPU_HOST_TO_DEVICE); xpu_ptr, maxs, 4 * sizeof(float), XPUMemcpyKind::XPU_HOST_TO_DEVICE));
} }
void GrnnLayout(int batch, void GrnnLayout(int batch,
...@@ -156,8 +156,8 @@ class MMDNNIdInfo { ...@@ -156,8 +156,8 @@ class MMDNNIdInfo {
idx_sorted.data(), idx_sorted.data(),
idx_sorted.size() * sizeof(int)); idx_sorted.size() * sizeof(int));
offset += idx_sorted.size() * sizeof(int); offset += idx_sorted.size() * sizeof(int);
xpu_memcpy( XPU_CALL(xpu_memcpy(
l3_buffer_, cpu_buffer_, offset, XPUMemcpyKind::XPU_HOST_TO_DEVICE); l3_buffer_, cpu_buffer_, offset, XPUMemcpyKind::XPU_HOST_TO_DEVICE));
} }
}; };
...@@ -221,29 +221,32 @@ class MMDNNFcOp { ...@@ -221,29 +221,32 @@ class MMDNNFcOp {
int m, int m,
float* out, float* out,
const float* in_max_by_caller = nullptr) { const float* in_max_by_caller = nullptr) {
int r = 0;
if (in_max_by_caller == nullptr) { if (in_max_by_caller == nullptr) {
xdnn::findmax<float>(ctx, in, m * k_, in_max_); r = xdnn::findmax<float>(ctx, in, m * k_, in_max_);
CHECK_EQ(r, 0);
in_max_by_caller = in_max_; in_max_by_caller = in_max_;
} }
xdnn::gemm_int16_maxptr<float, int16_t, float>(ctx, r = xdnn::gemm_int16_maxptr<float, int16_t, float>(ctx,
false, false,
true, true,
m, m,
n_, n_,
k_, k_,
1.0f, 1.0f,
in, in,
k_, k_,
weight_, weight_,
k_, k_,
0.0f, 0.0f,
out, out,
n_, n_,
bias_, bias_,
act_type_, act_type_,
in_max_by_caller, in_max_by_caller,
weight_max_, weight_max_,
out_max); out_max);
CHECK_EQ(r, 0);
} }
}; };
...@@ -331,44 +334,49 @@ class MMDNNGrnnOp { ...@@ -331,44 +334,49 @@ class MMDNNGrnnOp {
gru_out = l3_buffer + 4 * slot_size; gru_out = l3_buffer + 4 * slot_size;
} }
xdnn::search_seq2batch(ctx, int r = 0;
batch, r = xdnn::search_seq2batch(ctx,
max_width, batch,
cap_e_, max_width,
sentense.idx_sorted_32, cap_e_,
sentense.lod_32, sentense.idx_sorted_32,
sentense.new_offset_32, sentense.lod_32,
in, sentense.new_offset_32,
seq2batch_out); in,
seq2batch_out);
xdnn::findmax<float>(ctx, in, cap_l * cap_e_, input_max_); CHECK_EQ(r, 0);
r = xdnn::findmax<float>(ctx, in, cap_l * cap_e_, input_max_);
CHECK_EQ(r, 0);
fc_e2h0_.Infer(ctx, seq2batch_out, cap_l, fc_e2h_out, input_max_); fc_e2h0_.Infer(ctx, seq2batch_out, cap_l, fc_e2h_out, input_max_);
fc_e2h1_.Infer( fc_e2h1_.Infer(
ctx, seq2batch_out, cap_l, fc_e2h_out + cap_l * cap_h_, input_max_); ctx, seq2batch_out, cap_l, fc_e2h_out + cap_l * cap_h_, input_max_);
fc_e2h2_.Infer( fc_e2h2_.Infer(
ctx, seq2batch_out, cap_l, fc_e2h_out + cap_l * cap_h_ * 2, input_max_); ctx, seq2batch_out, cap_l, fc_e2h_out + cap_l * cap_h_ * 2, input_max_);
xdnn::search_grnn<float, int16_t>(ctx, r = xdnn::search_grnn<float, int16_t>(ctx,
cap_l, cap_l,
cap_h_, cap_h_,
cap_e_, cap_e_,
max_width, max_width,
sentense.new_offset_32, sentense.new_offset_32,
fc_e2h_out, fc_e2h_out,
dense_h2h_, dense_h2h_,
gru_out, gru_out,
dense_h2h_max_[0], dense_h2h_max_[0],
dense_h2h_max_[1], dense_h2h_max_[1],
dense_h2h_max_[2]); dense_h2h_max_[2]);
CHECK_EQ(r, 0);
xdnn::search_batch2seq(ctx,
batch, r = xdnn::search_batch2seq(ctx,
max_width, batch,
cap_h_, max_width,
sentense.idx_sorted_32, cap_h_,
sentense.lod_32, sentense.idx_sorted_32,
sentense.new_offset_32, sentense.lod_32,
gru_out, sentense.new_offset_32,
out); gru_out,
out);
CHECK_EQ(r, 0);
} }
}; };
...@@ -435,38 +443,43 @@ class MMDNNAttentionOp { ...@@ -435,38 +443,43 @@ class MMDNNAttentionOp {
} }
seqfc_.Infer(ctx, input, cap_l, seqfc_out); seqfc_.Infer(ctx, input, cap_l, seqfc_out);
xdnn::search_noaligned_mat_mul(ctx, int r = 0;
0, r = xdnn::search_noaligned_mat_mul(ctx,
1, 0,
batch, 1,
lod_32, batch,
max_width, lod_32,
dim_, max_width,
alpha0_, dim_,
input, alpha0_,
seqfc_out, input,
batchgemm0_out); seqfc_out,
xdnn::search_seq_softmax( batchgemm0_out);
CHECK_EQ(r, 0);
r = xdnn::search_seq_softmax(
ctx, batchgemm0_out, seq_softmax_out, lod_32, batch, max_width); ctx, batchgemm0_out, seq_softmax_out, lod_32, batch, max_width);
xdnn::search_noaligned_mat_mul(ctx, CHECK_EQ(r, 0);
0, r = xdnn::search_noaligned_mat_mul(ctx,
0, 0,
batch, 0,
lod_32, batch,
max_width, lod_32,
dim_, max_width,
alpha1_, dim_,
seq_softmax_out, alpha1_,
input, seq_softmax_out,
batchgemm1_out); input,
xdnn::sequence_pooling_forward(ctx, batchgemm1_out);
xdnn::Pooling_t::MAX_WITHOUT_INDEX, CHECK_EQ(r, 0);
batch, r = xdnn::sequence_pooling_forward(ctx,
lod_32, xdnn::Pooling_t::MAX_WITHOUT_INDEX,
dim_, batch,
batchgemm1_out, lod_32,
nullptr, dim_,
pool_out); batchgemm1_out,
nullptr,
pool_out);
CHECK_EQ(r, 0);
} }
}; };
...@@ -510,12 +523,13 @@ class MMDNNMatchConvTopk { ...@@ -510,12 +523,13 @@ class MMDNNMatchConvTopk {
float conv_w_max, float conv_w_max,
int dim_t, int dim_t,
int dim_in, int dim_in,
int out_channel,
int upper_bound_batch, int upper_bound_batch,
int upper_bound_seqlen, int upper_bound_seqlen,
const std::vector<int>& topks) { const std::vector<int>& topks) {
dim_t_ = dim_t; dim_t_ = dim_t;
dim_in_ = dim_in; dim_in_ = dim_in;
out_channel_ = 5; // TODO(miaotianxiang): out_channel_ = out_channel;
topks_ = topks; topks_ = topks;
xw_fc_.Init(input_w, xw_fc_.Init(input_w,
...@@ -553,10 +567,10 @@ class MMDNNMatchConvTopk { ...@@ -553,10 +567,10 @@ class MMDNNMatchConvTopk {
topks_xpu_guard_ = topks_xpu_guard_ =
TargetWrapperXPU::MallocScratchPad(topks_.size() * sizeof(int), false); TargetWrapperXPU::MallocScratchPad(topks_.size() * sizeof(int), false);
topks_xpu_ = reinterpret_cast<int*>(topks_xpu_guard_->addr_); topks_xpu_ = reinterpret_cast<int*>(topks_xpu_guard_->addr_);
xpu_memcpy(topks_xpu_, XPU_CALL(xpu_memcpy(topks_xpu_,
topks_.data(), topks_.data(),
topks_.size() * sizeof(int), topks_.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
useless_topk_pos_guard_ = useless_topk_pos_guard_ =
TargetWrapperXPU::MallocScratchPad(4 * sizeof(int), false); TargetWrapperXPU::MallocScratchPad(4 * sizeof(int), false);
useless_topk_pos_ = reinterpret_cast<int*>(useless_topk_pos_guard_->addr_); useless_topk_pos_ = reinterpret_cast<int*>(useless_topk_pos_guard_->addr_);
...@@ -576,18 +590,18 @@ class MMDNNMatchConvTopk { ...@@ -576,18 +590,18 @@ class MMDNNMatchConvTopk {
for (auto e : left_lod) { for (auto e : left_lod) {
left_lod_32_cpu.push_back(e); left_lod_32_cpu.push_back(e);
} }
xpu_memcpy(left_lod_32_, XPU_CALL(xpu_memcpy(left_lod_32_,
left_lod_32_cpu.data(), left_lod_32_cpu.data(),
left_lod_32_cpu.size() * sizeof(int), left_lod_32_cpu.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
std::vector<int> right_lod_32_cpu; std::vector<int> right_lod_32_cpu;
for (auto e : right_lod) { for (auto e : right_lod) {
right_lod_32_cpu.push_back(e); right_lod_32_cpu.push_back(e);
} }
xpu_memcpy(right_lod_32_, XPU_CALL(xpu_memcpy(right_lod_32_,
right_lod_32_cpu.data(), right_lod_32_cpu.data(),
right_lod_32_cpu.size() * sizeof(int), right_lod_32_cpu.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
std::vector<int> lod_match = {0}; std::vector<int> lod_match = {0};
std::vector<int> lod_conv = {0}; std::vector<int> lod_conv = {0};
...@@ -611,18 +625,18 @@ class MMDNNMatchConvTopk { ...@@ -611,18 +625,18 @@ class MMDNNMatchConvTopk {
left_seqlen_sum += len_x; left_seqlen_sum += len_x;
right_seqlen_sum += len_y; right_seqlen_sum += len_y;
} }
xpu_memcpy(match_lod_32_, XPU_CALL(xpu_memcpy(match_lod_32_,
lod_match.data(), lod_match.data(),
lod_match.size() * sizeof(int), lod_match.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(conv_lod_32_, XPU_CALL(xpu_memcpy(conv_lod_32_,
lod_conv.data(), lod_conv.data(),
lod_conv.size() * sizeof(int), lod_conv.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(topk_offset_32_, XPU_CALL(xpu_memcpy(topk_offset_32_,
lod_topk.data(), lod_topk.data(),
lod_topk.size() * sizeof(int), lod_topk.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
float* xwy_out = hbm_buffer_; float* xwy_out = hbm_buffer_;
float* conv_out = hbm_buffer_ + x_mul_y_sum * dim_t_; float* conv_out = hbm_buffer_ + x_mul_y_sum * dim_t_;
...@@ -640,19 +654,21 @@ class MMDNNMatchConvTopk { ...@@ -640,19 +654,21 @@ class MMDNNMatchConvTopk {
int max_width = std::max(left_seqlen_max, right_seqlen_max); int max_width = std::max(left_seqlen_max, right_seqlen_max);
xw_fc_.Infer(ctx, left->data<float>(), left_seqlen_sum, xw_out); xw_fc_.Infer(ctx, left->data<float>(), left_seqlen_sum, xw_out);
xdnn::match_matrix_tensor(ctx, int r = 0;
batch, r = xdnn::match_matrix_tensor(ctx,
xw_out, batch,
right->data<float>(), xw_out,
left_lod_32_, right->data<float>(),
right_lod_32_, left_lod_32_,
dim_t_, right_lod_32_,
dim_in_, dim_t_,
xwy_out, dim_in_,
xw_fc_.out_max, xwy_out,
xdnn::Activation_t::RELU, xw_fc_.out_max,
max_width); xdnn::Activation_t::RELU,
xdnn::search_varconv<float, int16_t>( max_width);
CHECK_EQ(r, 0);
r = xdnn::search_varconv<float, int16_t>(
ctx, ctx,
batch, batch,
dim_t_, dim_t_,
...@@ -668,24 +684,27 @@ class MMDNNMatchConvTopk { ...@@ -668,24 +684,27 @@ class MMDNNMatchConvTopk {
conv_out, conv_out,
conv_weight_max_, conv_weight_max_,
xdnn::Activation_t::RELU); // TODO(miaotianxiang): xdnn::Activation_t::RELU); // TODO(miaotianxiang):
xdnn::sequence_concat(ctx, CHECK_EQ(r, 0);
xwy_out, r = xdnn::sequence_concat(ctx,
match_lod_32_, xwy_out,
conv_out, match_lod_32_,
conv_lod_32_, conv_out,
seq_concat_out, conv_lod_32_,
batch); seq_concat_out,
xdnn::sequence_topk_avg_pooling(ctx, batch);
seq_concat_out, CHECK_EQ(r, 0);
seq_avg_topk_out, r = xdnn::sequence_topk_avg_pooling(ctx,
useless_topk_pos_, seq_concat_out,
batch, seq_avg_topk_out,
dim_t_ + out_channel_, useless_topk_pos_,
topk_offset_32_, batch,
left_lod_32_, dim_t_ + out_channel_,
right_lod_32_, topk_offset_32_,
topks_xpu_, left_lod_32_,
topks_.size()); right_lod_32_,
topks_xpu_,
topks_.size());
CHECK_EQ(r, 0);
} }
}; };
...@@ -802,34 +821,38 @@ class MMDNNBidEmbGrnnAtt { ...@@ -802,34 +821,38 @@ class MMDNNBidEmbGrnnAtt {
pool_rv = grnn_rv_pool_out->mutable_data<float>(TARGET(kXPU)); pool_rv = grnn_rv_pool_out->mutable_data<float>(TARGET(kXPU));
att_out = att_pool_out->mutable_data<float>(TARGET(kXPU)); att_out = att_pool_out->mutable_data<float>(TARGET(kXPU));
xdnn::search_bid_emb_ew(ctx, int r = 0;
batch, r = xdnn::search_bid_emb_ew(ctx,
sentense.lod_64, batch,
sentense.id0_64, sentense.lod_64,
sentense.id1_64, sentense.id0_64,
table_, sentense.id1_64,
table_len_, table_,
emb_dim_, table_len_,
emb_fw, emb_dim_,
emb_rv, emb_fw,
table_len_ - 2, emb_rv,
1); table_len_ - 2,
1);
CHECK_EQ(r, 0);
bi_rv_.Infer(ctx, bi_rv_.Infer(ctx,
sentense, sentense,
emb_rv, emb_rv,
grnn_rv, grnn_rv,
l3_buffer + 2 * slot_len, l3_buffer + 2 * slot_len,
l3_size - 2 * slot_len * sizeof(float)); l3_size - 2 * slot_len * sizeof(float));
xdnn::sequence_reverse( r = xdnn::sequence_reverse(
ctx, batch, sentense.lod_32, cap_h_, grnn_rv, grnn_rv_rv); ctx, batch, sentense.lod_32, cap_h_, grnn_rv, grnn_rv_rv);
xdnn::sequence_pooling_forward(ctx, CHECK_EQ(r, 0);
xdnn::Pooling_t::LAST, r = xdnn::sequence_pooling_forward(ctx,
batch, xdnn::Pooling_t::LAST,
sentense.lod_32, batch,
cap_h_, sentense.lod_32,
grnn_rv, cap_h_,
nullptr, grnn_rv,
pool_rv); nullptr,
pool_rv);
CHECK_EQ(r, 0);
bi_fw_.Infer(ctx, bi_fw_.Infer(ctx,
sentense, sentense,
...@@ -837,19 +860,23 @@ class MMDNNBidEmbGrnnAtt { ...@@ -837,19 +860,23 @@ class MMDNNBidEmbGrnnAtt {
grnn_fw, grnn_fw,
l3_buffer + 2 * slot_len, l3_buffer + 2 * slot_len,
l3_size - 2 * slot_len * sizeof(float)); l3_size - 2 * slot_len * sizeof(float));
xdnn::sequence_pooling_forward(ctx, r = xdnn::sequence_pooling_forward(ctx,
xdnn::Pooling_t::LAST, xdnn::Pooling_t::LAST,
batch, batch,
sentense.lod_32, sentense.lod_32,
cap_h_, cap_h_,
grnn_fw, grnn_fw,
nullptr, nullptr,
pool_fw); pool_fw);
CHECK_EQ(r, 0);
const int concat_widths[] = {cap_h_, cap_h_, cap_h_}; const int concat_widths[] = {cap_h_, cap_h_, cap_h_};
const float* concat_ptrs[] = {emb_fw, grnn_fw, grnn_rv_rv}; const float* concat_ptrs[] = {emb_fw, grnn_fw, grnn_rv_rv};
xdnn::concat<float>( r = xdnn::concat<float>(
ctx, cap_l, concat_widths + 1, 2, concat_ptrs + 1, concat_2in); ctx, cap_l, concat_widths + 1, 2, concat_ptrs + 1, concat_2in);
xdnn::concat<float>(ctx, cap_l, concat_widths, 3, concat_ptrs, concat_3in); CHECK_EQ(r, 0);
r = xdnn::concat<float>(
ctx, cap_l, concat_widths, 3, concat_ptrs, concat_3in);
CHECK_EQ(r, 0);
att_.Infer(ctx, att_.Infer(ctx,
sentense, sentense,
concat_2in, concat_2in,
...@@ -899,16 +926,18 @@ class MMDNNEmbAtt { ...@@ -899,16 +926,18 @@ class MMDNNEmbAtt {
int cap_l = sentense.lod.back(); int cap_l = sentense.lod.back();
const float* emb_tables[] = {table_, table_}; const float* emb_tables[] = {table_, table_};
const int64_t* emb_indices[] = {sentense.id0_64, sentense.id1_64}; const int64_t* emb_indices[] = {sentense.id0_64, sentense.id1_64};
xdnn::embedding_with_ewadd<float, int64_t, false, false>(ctx, int r =
emb_dim_, xdnn::embedding_with_ewadd<float, int64_t, false, false>(ctx,
cap_l, emb_dim_,
2, cap_l,
table_len_ - 2, 2,
emb_tables, table_len_ - 2,
emb_indices, emb_tables,
nullptr, emb_indices,
nullptr, nullptr,
emb_fw); nullptr,
emb_fw);
CHECK_EQ(r, 0);
att_.Infer(ctx, sentense, emb_fw, att_out, l3_buffer, l3_size); att_.Infer(ctx, sentense, emb_fw, att_out, l3_buffer, l3_size);
} }
}; };
...@@ -990,7 +1019,7 @@ class MMDNNMergeAll { ...@@ -990,7 +1019,7 @@ class MMDNNMergeAll {
fc2_.Init( fc2_.Init(
fc2_w, fc2_w_max, fc2_b, fc2_n_, fc2_k_, xdnn::Activation_t::LINEAR); fc2_w, fc2_w_max, fc2_b, fc2_n_, fc2_k_, xdnn::Activation_t::LINEAR);
int hbm_total_len = max_cap_l * cap_h_ * 4 + int hbm_total_len = max_cap_l * cap_e_ * 2 + max_cap_l * cap_h_ * 2 +
upper_bound_batch * (2 * cap_h_ + fc0_k_ + fc0_n_ + upper_bound_batch * (2 * cap_h_ + fc0_k_ + fc0_n_ +
fc1_k_ + fc1_n_ + fc2_n_); fc1_k_ + fc1_n_ + fc2_n_);
hbm_buffer_guard_ = TargetWrapperXPU::MallocScratchPad( hbm_buffer_guard_ = TargetWrapperXPU::MallocScratchPad(
...@@ -1000,7 +1029,7 @@ class MMDNNMergeAll { ...@@ -1000,7 +1029,7 @@ class MMDNNMergeAll {
void Infer(xdnn::Context* ctx, void Infer(xdnn::Context* ctx,
const MMDNNIdInfo& sentense, const MMDNNIdInfo& sentense,
const std::vector<lite::Tensor*> concat_2in1_x, const std::vector<lite::Tensor*> concat_topk_x,
const std::vector<lite::Tensor*> concat_7in1_x, const std::vector<lite::Tensor*> concat_7in1_x,
lite::Tensor* out, lite::Tensor* out,
float* l3_buffer = nullptr, float* l3_buffer = nullptr,
...@@ -1010,13 +1039,13 @@ class MMDNNMergeAll { ...@@ -1010,13 +1039,13 @@ class MMDNNMergeAll {
float* topk_concat_out_fw = hbm_buffer_; float* topk_concat_out_fw = hbm_buffer_;
int hbm_total_len = int hbm_total_len =
cap_l * cap_h_ * 4 + cap_l * cap_e_ * 2 + cap_l * cap_h_ * 2 +
batch * (2 * cap_h_ + fc0_k_ + fc0_n_ + fc1_k_ + fc1_n_ + fc2_n_); batch * (2 * cap_h_ + fc0_k_ + fc0_n_ + fc1_k_ + fc1_n_ + fc2_n_);
if (l3_size > 0 && l3_size >= hbm_total_len * sizeof(float)) { if (l3_size > 0 && l3_size >= hbm_total_len * sizeof(float)) {
topk_concat_out_fw = l3_buffer; topk_concat_out_fw = l3_buffer;
} }
float* topk_concat_out_rv = topk_concat_out_fw + cap_l * cap_h_; float* topk_concat_out_rv = topk_concat_out_fw + cap_l * cap_e_;
float* grnn_fw = topk_concat_out_rv + cap_l * cap_h_; float* grnn_fw = topk_concat_out_rv + cap_l * cap_e_;
float* grnn_rv = grnn_fw + cap_l * cap_h_; float* grnn_rv = grnn_fw + cap_l * cap_h_;
float* pool_fw = grnn_rv + cap_l * cap_h_; float* pool_fw = grnn_rv + cap_l * cap_h_;
float* pool_rv = pool_fw + batch * cap_h_; float* pool_rv = pool_fw + batch * cap_h_;
...@@ -1027,18 +1056,27 @@ class MMDNNMergeAll { ...@@ -1027,18 +1056,27 @@ class MMDNNMergeAll {
// float* fc2_out = fc1_out + batch * fc1_n_; // float* fc2_out = fc1_out + batch * fc1_n_;
float* fc2_out = out->mutable_data<float>(TARGET(kXPU)); float* fc2_out = out->mutable_data<float>(TARGET(kXPU));
const int concat_widths[] = {static_cast<int>(concat_2in1_x[0]->dims()[1]), std::vector<int> concat_widths;
static_cast<int>(concat_2in1_x[1]->dims()[1])}; std::vector<const float*> concat_ptrs;
const float* concat_ptrs[] = {concat_2in1_x[0]->data<float>(), for (const auto* t : concat_topk_x) {
concat_2in1_x[1]->data<float>()}; concat_widths.push_back(static_cast<int>(t->dims()[1]));
xdnn::concat<float>( concat_ptrs.push_back(t->data<float>());
ctx, cap_l, concat_widths, 2, concat_ptrs, topk_concat_out_fw); }
xdnn::sequence_reverse(ctx, int r = 0;
batch, r = xdnn::concat<float>(ctx,
sentense.lod_32, cap_l,
cap_e_, concat_widths.data(),
topk_concat_out_fw, concat_widths.size(),
topk_concat_out_rv); concat_ptrs.data(),
topk_concat_out_fw);
CHECK_EQ(r, 0);
r = xdnn::sequence_reverse(ctx,
batch,
sentense.lod_32,
cap_e_,
topk_concat_out_fw,
topk_concat_out_rv);
CHECK_EQ(r, 0);
coverage_fw_.Infer(ctx, coverage_fw_.Infer(ctx,
sentense, sentense,
topk_concat_out_fw, topk_concat_out_fw,
...@@ -1051,22 +1089,24 @@ class MMDNNMergeAll { ...@@ -1051,22 +1089,24 @@ class MMDNNMergeAll {
grnn_rv, grnn_rv,
l3_buffer + hbm_total_len, l3_buffer + hbm_total_len,
l3_size - hbm_total_len * sizeof(float)); l3_size - hbm_total_len * sizeof(float));
xdnn::sequence_pooling_forward(ctx, r = xdnn::sequence_pooling_forward(ctx,
xdnn::Pooling_t::LAST, xdnn::Pooling_t::LAST,
batch, batch,
sentense.lod_32, sentense.lod_32,
cap_h_, cap_h_,
grnn_fw, grnn_fw,
nullptr, nullptr,
pool_fw); pool_fw);
xdnn::sequence_pooling_forward(ctx, CHECK_EQ(r, 0);
xdnn::Pooling_t::LAST, r = xdnn::sequence_pooling_forward(ctx,
batch, xdnn::Pooling_t::LAST,
sentense.lod_32, batch,
cap_h_, sentense.lod_32,
grnn_rv, cap_h_,
nullptr, grnn_rv,
pool_rv); nullptr,
pool_rv);
CHECK_EQ(r, 0);
const int concat_widths_fc0[] = { const int concat_widths_fc0[] = {
static_cast<int>(concat_7in1_x[0]->dims()[1]), static_cast<int>(concat_7in1_x[0]->dims()[1]),
...@@ -1089,11 +1129,13 @@ class MMDNNMergeAll { ...@@ -1089,11 +1129,13 @@ class MMDNNMergeAll {
const int concat_widths_fc1[] = {cap_h_, cap_h_, fc0_n_}; const int concat_widths_fc1[] = {cap_h_, cap_h_, fc0_n_};
const float* concat_ptrs_fc1[] = {pool_fw, pool_rv, fc0_out}; const float* concat_ptrs_fc1[] = {pool_fw, pool_rv, fc0_out};
xdnn::concat<float>( r = xdnn::concat<float>(
ctx, batch, concat_widths_fc0, 7, concat_ptrs_fc0, fc0_in); ctx, batch, concat_widths_fc0, 7, concat_ptrs_fc0, fc0_in);
CHECK_EQ(r, 0);
fc0_.Infer(ctx, fc0_in, batch, fc0_out); fc0_.Infer(ctx, fc0_in, batch, fc0_out);
xdnn::concat<float>( r = xdnn::concat<float>(
ctx, batch, concat_widths_fc1, 3, concat_ptrs_fc1, fc1_in); ctx, batch, concat_widths_fc1, 3, concat_ptrs_fc1, fc1_in);
CHECK_EQ(r, 0);
fc1_.Infer(ctx, fc1_in, batch, fc1_out); fc1_.Infer(ctx, fc1_in, batch, fc1_out);
fc2_.Infer(ctx, fc1_out, batch, fc2_out); fc2_.Infer(ctx, fc1_out, batch, fc2_out);
} }
...@@ -1111,14 +1153,12 @@ class XPUMmdnnBidEmbGrnnAttCompute ...@@ -1111,14 +1153,12 @@ class XPUMmdnnBidEmbGrnnAttCompute
private: private:
MMDNNIdInfo id_; MMDNNIdInfo id_;
MMDNNBidEmbGrnnAtt compound_; MMDNNBidEmbGrnnAtt compound_;
int upper_bound_batch_ = 40;
int upper_bound_seqlen_ = 512;
}; };
void XPUMmdnnBidEmbGrnnAttCompute::PrepareForRun() { void XPUMmdnnBidEmbGrnnAttCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
id_.Init(upper_bound_batch_, upper_bound_seqlen_); id_.Init(XPU_MAX_LOD_SIZE, XPU_MAX_LOD_SEQ_LEN);
compound_.Init(param.emb_tbl, compound_.Init(param.emb_tbl,
param.grnn_fw_wh, param.grnn_fw_wh,
param.grnn_fw_wh_maxs, param.grnn_fw_wh_maxs,
...@@ -1131,8 +1171,8 @@ void XPUMmdnnBidEmbGrnnAttCompute::PrepareForRun() { ...@@ -1131,8 +1171,8 @@ void XPUMmdnnBidEmbGrnnAttCompute::PrepareForRun() {
param.att_fc_w, param.att_fc_w,
param.att_fc_w_max, param.att_fc_w_max,
param.att_fc_b, param.att_fc_b,
upper_bound_batch_, XPU_MAX_LOD_SIZE,
upper_bound_seqlen_); XPU_MAX_LOD_SEQ_LEN);
} }
void XPUMmdnnBidEmbGrnnAttCompute::Run() { void XPUMmdnnBidEmbGrnnAttCompute::Run() {
...@@ -1157,6 +1197,76 @@ void XPUMmdnnBidEmbGrnnAttCompute::Run() { ...@@ -1157,6 +1197,76 @@ void XPUMmdnnBidEmbGrnnAttCompute::Run() {
xpu_ctx->workspace_l3_size - xpu_ctx->used_l3_size); xpu_ctx->workspace_l3_size - xpu_ctx->used_l3_size);
} }
class XPUMmdnnBidEmbGrnnAttCompute2
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::XPUMmdnnBidEmbGrnnAttParam2;
void PrepareForRun() override;
void Run() override;
private:
MMDNNIdInfo id_;
MMDNNBidEmbGrnnAtt compound_;
};
void XPUMmdnnBidEmbGrnnAttCompute2::PrepareForRun() {
auto& param = this->Param<param_t>();
id_.Init(XPU_MAX_LOD_SIZE, XPU_MAX_LOD_SEQ_LEN);
compound_.Init(param.emb_tbl,
param.grnn_fw_wh,
param.grnn_fw_wh_maxs,
param.grnn_fw_wi,
param.grnn_fw_wi_maxs,
param.grnn_rv_wh,
param.grnn_rv_wh_maxs,
param.grnn_rv_wi,
param.grnn_rv_wi_maxs,
param.att_fc_w,
param.att_fc_w_max,
param.att_fc_b,
XPU_MAX_LOD_SIZE,
XPU_MAX_LOD_SEQ_LEN);
}
void XPUMmdnnBidEmbGrnnAttCompute2::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto* xpu_ctx = ctx.GetRawContext();
int batch = param.id0->lod()[0].size() - 1;
id_.Update(param.id0, param.id1);
compound_.Infer(ctx.GetRawContext(),
batch,
id_,
param.grnn_fw_pool_out,
param.grnn_rv_pool_out,
param.att_pool_out,
param.concat_3in1_out,
param.emb_fw_out,
reinterpret_cast<float*>(
reinterpret_cast<char*>(xpu_ctx->workspace_l3_ptr) +
xpu_ctx->used_l3_size),
xpu_ctx->workspace_l3_size - xpu_ctx->used_l3_size);
int num = param.id0->numel();
int embed_dim = param.emb_tbl->dims()[1];
// TODO(miaotianxiang):
int r = xdnn::embedding<float, int64_t>(
ctx.GetRawContext(), /* context */
num, /* num */
param.id0->data<int64_t>(), /* indices */
embed_dim, /* embed_dim */
param.emb_tbl->data<float>(), /* table */
param.emb0_out->mutable_data<float>(TARGET(kXPU)), /* top */
128000 /* padding_idx */);
CHECK_EQ(r, 0);
}
class XPUMmdnnBidEmbAttCompute class XPUMmdnnBidEmbAttCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> { : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public: public:
...@@ -1169,20 +1279,18 @@ class XPUMmdnnBidEmbAttCompute ...@@ -1169,20 +1279,18 @@ class XPUMmdnnBidEmbAttCompute
private: private:
MMDNNIdInfo id_; MMDNNIdInfo id_;
MMDNNEmbAtt compound_; MMDNNEmbAtt compound_;
int upper_bound_batch_ = 40;
int upper_bound_seqlen_ = 512;
}; };
void XPUMmdnnBidEmbAttCompute::PrepareForRun() { void XPUMmdnnBidEmbAttCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
id_.Init(upper_bound_batch_, upper_bound_seqlen_); id_.Init(XPU_MAX_LOD_SIZE, XPU_MAX_LOD_SEQ_LEN);
compound_.Init(param.emb_tbl, compound_.Init(param.emb_tbl,
param.att_fc_w, param.att_fc_w,
param.att_fc_w_max, param.att_fc_w_max,
param.att_fc_b, param.att_fc_b,
upper_bound_batch_, XPU_MAX_LOD_SIZE,
upper_bound_seqlen_); XPU_MAX_LOD_SEQ_LEN);
} }
void XPUMmdnnBidEmbAttCompute::Run() { void XPUMmdnnBidEmbAttCompute::Run() {
...@@ -1215,8 +1323,6 @@ class XPUMmdnnMatchConvTopkCompute ...@@ -1215,8 +1323,6 @@ class XPUMmdnnMatchConvTopkCompute
private: private:
MMDNNMatchConvTopk compound_; MMDNNMatchConvTopk compound_;
int upper_bound_batch_ = 40;
int upper_bound_seqlen_ = 512;
}; };
void XPUMmdnnMatchConvTopkCompute::PrepareForRun() { void XPUMmdnnMatchConvTopkCompute::PrepareForRun() {
...@@ -1228,8 +1334,9 @@ void XPUMmdnnMatchConvTopkCompute::PrepareForRun() { ...@@ -1228,8 +1334,9 @@ void XPUMmdnnMatchConvTopkCompute::PrepareForRun() {
param.conv_w_max, param.conv_w_max,
param.dim_t, param.dim_t,
param.input_w->dims()[0], param.input_w->dims()[0],
upper_bound_batch_, param.output_channel,
upper_bound_seqlen_, XPU_MAX_LOD_SIZE,
XPU_MAX_LOD_SEQ_LEN,
param.topks); param.topks);
} }
...@@ -1261,14 +1368,12 @@ class XPUMmdnnMergeAllCompute ...@@ -1261,14 +1368,12 @@ class XPUMmdnnMergeAllCompute
private: private:
MMDNNIdInfo id_; MMDNNIdInfo id_;
MMDNNMergeAll compound_; MMDNNMergeAll compound_;
int upper_bound_batch_ = 40;
int upper_bound_seqlen_ = 512;
}; };
void XPUMmdnnMergeAllCompute::PrepareForRun() { void XPUMmdnnMergeAllCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
id_.Init(upper_bound_batch_, upper_bound_seqlen_); id_.Init(XPU_MAX_LOD_SIZE, XPU_MAX_LOD_SEQ_LEN);
compound_.Init(param.grnn_fw_wh, compound_.Init(param.grnn_fw_wh,
param.grnn_fw_wh_maxs, param.grnn_fw_wh_maxs,
param.grnn_fw_wi, param.grnn_fw_wi,
...@@ -1286,8 +1391,8 @@ void XPUMmdnnMergeAllCompute::PrepareForRun() { ...@@ -1286,8 +1391,8 @@ void XPUMmdnnMergeAllCompute::PrepareForRun() {
param.fc2_w, param.fc2_w,
param.fc2_w_max, param.fc2_w_max,
param.fc2_b, param.fc2_b,
upper_bound_batch_, XPU_MAX_LOD_SIZE,
upper_bound_seqlen_); XPU_MAX_LOD_SEQ_LEN);
} }
void XPUMmdnnMergeAllCompute::Run() { void XPUMmdnnMergeAllCompute::Run() {
...@@ -1296,10 +1401,10 @@ void XPUMmdnnMergeAllCompute::Run() { ...@@ -1296,10 +1401,10 @@ void XPUMmdnnMergeAllCompute::Run() {
auto* xpu_ctx = ctx.GetRawContext(); auto* xpu_ctx = ctx.GetRawContext();
id_.Update(param.concat_2in1_x[0], param.concat_2in1_x[1]); id_.Update(param.concat_topk_x[0], param.concat_topk_x[1]);
compound_.Infer(ctx.GetRawContext(), compound_.Infer(ctx.GetRawContext(),
id_, id_,
param.concat_2in1_x, param.concat_topk_x,
param.concat_7in1_x, param.concat_7in1_x,
param.out, param.out,
reinterpret_cast<float*>( reinterpret_cast<float*>(
...@@ -1335,6 +1440,29 @@ REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_grnn_att, ...@@ -1335,6 +1440,29 @@ REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_grnn_att,
.BindOutput("emb_fw_out", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindOutput("emb_fw_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_grnn_att2,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUMmdnnBidEmbGrnnAttCompute2,
def)
.BindInput("id0", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindInput("id1", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindInput("emb_tbl", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_fw_wh", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_fw_wi", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_rv_wh", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_rv_wi", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("att_fc_w", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("att_fc_b", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("emb0_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("grnn_fw_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("grnn_rv_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("att_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("concat_3in1_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("emb_fw_out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_att, REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_att,
kXPU, kXPU,
kFloat, kFloat,
...@@ -1371,7 +1499,7 @@ REGISTER_LITE_KERNEL(__xpu__mmdnn_merge_all, ...@@ -1371,7 +1499,7 @@ REGISTER_LITE_KERNEL(__xpu__mmdnn_merge_all,
paddle::lite::kernels::xpu::XPUMmdnnMergeAllCompute, paddle::lite::kernels::xpu::XPUMmdnnMergeAllCompute,
def) def)
.BindInput("concat_7in1_x", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindInput("concat_7in1_x", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("concat_2in1_x", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindInput("concat_topk_x", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_fw_wh", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindInput("grnn_fw_wh", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_fw_wi", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindInput("grnn_fw_wi", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("grnn_rv_wh", {LiteType::GetTensorTy(TARGET(kXPU))}) .BindInput("grnn_rv_wh", {LiteType::GetTensorTy(TARGET(kXPU))})
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <vector> #include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h" #include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <vector> #include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h" #include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
......
...@@ -22,16 +22,19 @@ namespace kernels { ...@@ -22,16 +22,19 @@ namespace kernels {
namespace xpu { namespace xpu {
void XPUMmdnnSearchAttentionCompute::PrepareForRun() { void XPUMmdnnSearchAttentionCompute::PrepareForRun() {
offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
pad_begin_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
w_max_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(8 * sizeof(float)); pad_begin_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
w_max_xpu_guard_ =
TargetWrapperXPU::MallocScratchPad(8 * sizeof(float), false /* use_l3 */);
buffer_at_l3_guard_ = TargetWrapperXPU::MallocScratchPad( buffer_at_l3_guard_ = TargetWrapperXPU::MallocScratchPad(
5 * L3_SLOT_SIZE * sizeof(float), false /* use_l3 */); 5 * L3_SLOT_SIZE * sizeof(float), false /* use_l3 */);
buffer_at_gm_guard_ = TargetWrapperXPU::MallocScratchPad( buffer_at_gm_guard_ = TargetWrapperXPU::MallocScratchPad(
5 * GM_SLOT_SIZE * sizeof(float), false /* use_l3 */); 5 * GM_SLOT_SIZE * sizeof(float), false /* use_l3 */);
offset_cpu.reset(new int[64]); offset_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
pad_begin_cpu.reset(new int[64]); pad_begin_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
} }
void XPUMmdnnSearchAttentionCompute::Run() { void XPUMmdnnSearchAttentionCompute::Run() {
...@@ -72,18 +75,18 @@ void XPUMmdnnSearchAttentionCompute::Run() { ...@@ -72,18 +75,18 @@ void XPUMmdnnSearchAttentionCompute::Run() {
} }
offset_cpu[batch] = offset[batch]; offset_cpu[batch] = offset[batch];
xpu_memcpy(offset_xpu_guard_->addr_, XPU_CALL(xpu_memcpy(offset_xpu_guard_->addr_,
offset_cpu.get(), offset_cpu.get(),
offset.size() * sizeof(int), offset.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(pad_begin_xpu_guard_->addr_, XPU_CALL(xpu_memcpy(pad_begin_xpu_guard_->addr_,
pad_begin_cpu.get(), pad_begin_cpu.get(),
batch * sizeof(int), batch * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(w_max_xpu_guard_->addr_, XPU_CALL(xpu_memcpy(w_max_xpu_guard_->addr_,
maxs_cpu, maxs_cpu,
8 * sizeof(float), 8 * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int* offset_xpu = reinterpret_cast<int*>(offset_xpu_guard_->addr_); int* offset_xpu = reinterpret_cast<int*>(offset_xpu_guard_->addr_);
int* pad_begin_xpu = reinterpret_cast<int*>(pad_begin_xpu_guard_->addr_); int* pad_begin_xpu = reinterpret_cast<int*>(pad_begin_xpu_guard_->addr_);
...@@ -115,90 +118,99 @@ void XPUMmdnnSearchAttentionCompute::Run() { ...@@ -115,90 +118,99 @@ void XPUMmdnnSearchAttentionCompute::Run() {
} }
const auto* bottom_data = X->data<float>(); const auto* bottom_data = X->data<float>();
xdnn::search_sequence_pad_depad(ctx.GetRawContext(), int r = 0;
const_cast<float*>(bottom_data), r = xdnn::search_sequence_pad_depad(ctx.GetRawContext(),
group_padding_output, const_cast<float*>(bottom_data),
offset_xpu, group_padding_output,
max_seq, offset_xpu,
batch, max_seq,
dim1, batch,
0); // is_depad = 0 dim1,
0); // is_depad = 0
CHECK_EQ(r, 0);
// do-findmax // do-findmax
xdnn::findmax<float>(ctx.GetRawContext(), r = xdnn::findmax<float>(ctx.GetRawContext(),
group_padding_output, group_padding_output,
batch * max_seq * dim1, batch * max_seq * dim1,
maxs_xpu); maxs_xpu);
xdnn::gemm_int16_maxptr<float, int16_t, float>( CHECK_EQ(r, 0);
ctx.GetRawContext(), r = xdnn::gemm_int16_maxptr<float, int16_t, float>(
false, ctx.GetRawContext(), /* ctx */
true, // trans_a, trans_b false, /* trans_a */
batch * max_seq, true, /* trans_b */
dim1, batch * max_seq, /* m */
dim1, // m, n, k dim1, /* n */
1.0f, dim1, /* k */
group_padding_output, 1.0f, /* alpha */
dim1, // alpha, data_a, lda group_padding_output, /* data_a */
w_data, dim1, /* lda */
dim1, w_data, /* data_b */
0.0f, // data_b, ldb, beta dim1, /* ldb */
seq_fc_output, 0.0f, /* beta */
dim1, seq_fc_output, /* data_c */
b_data, // data_c, ldc, bias dim1, /* ldc */
xdnn::Activation_t::LINEAR, b_data, /* bias */
maxs_xpu, xdnn::Activation_t::LINEAR, /* act */
maxs_xpu + 4, maxs_xpu, /* max_a */
nullptr); // max_a, max_b, max_c maxs_xpu + 4, /* max_b */
xdnn::search_aligned_mat_mul(ctx.GetRawContext(), nullptr /* max_c */);
0, CHECK_EQ(r, 0);
1, r = xdnn::search_aligned_mat_mul(ctx.GetRawContext(),
batch, 0,
max_seq, 1,
max_seq, batch,
dim1, max_seq,
alpha0, max_seq,
group_padding_output, dim1,
dim1, alpha0,
seq_fc_output, group_padding_output,
dim1, dim1,
batchgemm0_output, seq_fc_output,
max_seq); dim1,
xdnn::search_pad_mask(ctx.GetRawContext(), batchgemm0_output,
batchgemm0_output, max_seq);
attention_output, CHECK_EQ(r, 0);
pad_begin_xpu, r = xdnn::search_pad_mask(ctx.GetRawContext(),
batch, batchgemm0_output,
max_seq, attention_output,
max_seq, pad_begin_xpu,
batch, batch,
mask); max_seq,
xdnn::softmax2d_forward(ctx.GetRawContext(), max_seq,
attention_output, batch,
seq_softmax_output, mask);
batch * max_seq, CHECK_EQ(r, 0);
max_seq, r = xdnn::softmax2d_forward(ctx.GetRawContext(),
true); attention_output,
xdnn::search_aligned_mat_mul(ctx.GetRawContext(), seq_softmax_output,
0, batch * max_seq,
0, max_seq,
batch, true);
max_seq, CHECK_EQ(r, 0);
dim1, r = xdnn::search_aligned_mat_mul(ctx.GetRawContext(),
max_seq, 0,
alpha1, 0,
seq_softmax_output, batch,
max_seq, max_seq,
group_padding_output, dim1,
dim1, max_seq,
batchgemm1_output, alpha1,
dim1); seq_softmax_output,
xdnn::search_sequence_pad_depad(ctx.GetRawContext(), max_seq,
top_data, group_padding_output,
batchgemm1_output, dim1,
offset_xpu, batchgemm1_output,
max_seq, dim1);
batch, CHECK_EQ(r, 0);
dim1, r = xdnn::search_sequence_pad_depad(ctx.GetRawContext(),
1); // is_depad = 1 top_data,
batchgemm1_output,
offset_xpu,
max_seq,
batch,
dim1,
1); // is_depad = 1
CHECK_EQ(r, 0);
} }
} // namespace xpu } // namespace xpu
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -29,12 +29,13 @@ void LookupTableCompute::Run() { ...@@ -29,12 +29,13 @@ void LookupTableCompute::Run() {
int embed_dim = param.W->dims()[1]; int embed_dim = param.W->dims()[1];
int r = xdnn::embedding<float, int64_t>( int r = xdnn::embedding<float, int64_t>(
ctx.GetRawContext(), /* context */ ctx.GetRawContext(), /* context */
num, /* num */ num, /* num */
param.Ids->data<int64_t>(), /* indices */ param.Ids->data<int64_t>(), /* indices */
embed_dim, /* embed_dim */ embed_dim, /* embed_dim */
param.W->data<float>(), /* table */ param.W->data<float>(), /* table */
param.Out->mutable_data<float>(TARGET(kXPU)) /* top */); param.Out->mutable_data<float>(TARGET(kXPU)), /* top */
param.padding_idx /* padding_idx */);
CHECK_EQ(r, 0); CHECK_EQ(r, 0);
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -23,12 +23,15 @@ namespace kernels { ...@@ -23,12 +23,15 @@ namespace kernels {
namespace xpu { namespace xpu {
void MatchMatrixTensorCompute::PrepareForRun() { void MatchMatrixTensorCompute::PrepareForRun() {
wx_max_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); wx_max_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
offset_l_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
offset_r_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); offset_l_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
offset_l_cpu.reset(new int[64]); offset_r_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
offset_r_cpu.reset(new int[64]); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
offset_l_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
offset_r_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
} }
void MatchMatrixTensorCompute::Run() { void MatchMatrixTensorCompute::Run() {
...@@ -76,25 +79,25 @@ void MatchMatrixTensorCompute::Run() { ...@@ -76,25 +79,25 @@ void MatchMatrixTensorCompute::Run() {
int* offset_r_xpu = reinterpret_cast<int*>(offset_r_xpu_guard_->addr_); int* offset_r_xpu = reinterpret_cast<int*>(offset_r_xpu_guard_->addr_);
int r = xdnn::gemm_int16_tmp_api<float, int16_t, float>( int r = xdnn::gemm_int16_tmp_api<float, int16_t, float>(
ctx.GetRawContext(), /* ctx */ ctx.GetRawContext(), /* ctx */
false, false, /* trans_a */
false, /* trans_a, trans_b */ false, /* trans_b */
x->dims()[0], x->dims()[0], /* m */
dim_t * dim_in, dim_t * dim_in, /* n */
dim_in, /* m, n, k */ dim_in, /* k */
1.0f, 1.0f, /* alpha */
bottom_l_data, bottom_l_data, /* data_a */
dim_in, /* alpha, data_a, lda */ dim_in, /* lda */
w_data, w_data, /* data_b */
dim_t * dim_in, dim_t * dim_in, /* ldb */
0.0f, /* data_b, ldb, beta */ 0.0f, /* beta */
bottom_l_trans_data, bottom_l_trans_data, /* data_c */
dim_t * dim_in, /* data_c, ldc */ dim_t * dim_in, /* ldc */
nullptr, /* bias */ nullptr, /* bias */
xdnn::Activation_t::LINEAR, xdnn::Activation_t::LINEAR, /* act */
0.0f, 0.0f, /* max_a */
w_max, w_max, /* max_b */
wx_max /* max_a, max_b, max_c */); wx_max /* max_c */);
CHECK_EQ(r, 0); CHECK_EQ(r, 0);
int max_width = 0; int max_width = 0;
...@@ -110,14 +113,14 @@ void MatchMatrixTensorCompute::Run() { ...@@ -110,14 +113,14 @@ void MatchMatrixTensorCompute::Run() {
max_width = offset_r_cpu[i] - offset_r_cpu[i - 1]; max_width = offset_r_cpu[i] - offset_r_cpu[i - 1];
} }
} }
xpu_memcpy(offset_l_xpu, XPU_CALL(xpu_memcpy(offset_l_xpu,
offset_l_cpu.get(), offset_l_cpu.get(),
offset_l.size() * sizeof(int), offset_l.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(offset_r_xpu, XPU_CALL(xpu_memcpy(offset_r_xpu,
offset_r_cpu.get(), offset_r_cpu.get(),
offset_r.size() * sizeof(int), offset_r.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
r = xdnn::match_matrix_tensor(ctx.GetRawContext(), r = xdnn::match_matrix_tensor(ctx.GetRawContext(),
batch_size, batch_size,
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -23,7 +23,8 @@ namespace kernels { ...@@ -23,7 +23,8 @@ namespace kernels {
namespace xpu { namespace xpu {
void SearchFcCompute::PrepareForRun() { void SearchFcCompute::PrepareForRun() {
maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(float)); maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
XPU_MAX_LOD_SIZE * sizeof(float), false /* use_l3 */);
} }
void SearchFcCompute::Run() { void SearchFcCompute::Run() {
...@@ -59,34 +60,34 @@ void SearchFcCompute::Run() { ...@@ -59,34 +60,34 @@ void SearchFcCompute::Run() {
float* maxs_xpu = reinterpret_cast<float*>(maxs_xpu_guard_->addr_); float* maxs_xpu = reinterpret_cast<float*>(maxs_xpu_guard_->addr_);
float maxs_cpu[8] = {0.0f, 0.0f, 0.0f, 0.0f, w_max, 0.0f, 0.0f, 0.0f}; float maxs_cpu[8] = {0.0f, 0.0f, 0.0f, 0.0f, w_max, 0.0f, 0.0f, 0.0f};
xpu_memcpy(maxs_xpu, XPU_CALL(xpu_memcpy(maxs_xpu,
&maxs_cpu[0], &maxs_cpu[0],
8 * sizeof(float), 8 * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int r = xdnn::findmax<float>( int r = xdnn::findmax<float>(
ctx.GetRawContext(), bottom_data, batch * _in, maxs_xpu); ctx.GetRawContext(), bottom_data, batch * _in, maxs_xpu);
CHECK_EQ(r, 0); CHECK_EQ(r, 0);
r = xdnn::gemm_int16_maxptr<float, int16_t, float>( r = xdnn::gemm_int16_maxptr<float, int16_t, float>(
ctx.GetRawContext(), /* ctx */ ctx.GetRawContext(), /* ctx */
false, false, /* trans_a */
true, /*trans_a, trans_b*/ true, /* trans_b */
batch, batch, /* m */
_out, _out, /* n */
_in, /*m, n, k*/ _in, /* k */
1.0f, 1.0f, /* alpha */
bottom_data, bottom_data, /* data_a */
_in, /*alpha, data_a, lda*/ _in, /* lda */
weights, weights, /* data_b */
_in, _in, /* ldb */
0.0f, /*data_b, ldb, beta*/ 0.0f, /* beta */
top_data, top_data, /* data_c */
_out, _out, /* ldc */
bias_data, /* data_c, ldc, bias*/ bias_data, /* bias */
act, act, /* act */
maxs_xpu, maxs_xpu, /* max_a */
maxs_xpu + 4, maxs_xpu + 4, /* max_b */
nullptr /*act, max_a, max_b, max_c*/); nullptr /* max_c */);
CHECK_EQ(r, 0); CHECK_EQ(r, 0);
} }
......
...@@ -24,13 +24,16 @@ namespace kernels { ...@@ -24,13 +24,16 @@ namespace kernels {
namespace xpu { namespace xpu {
void SearchGrnnCompute::PrepareForRun() { void SearchGrnnCompute::PrepareForRun() {
offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
new_offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(256 * sizeof(int)); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(16 * sizeof(float)); new_offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
XPU_MAX_LOD_SEQ_LEN * sizeof(int), false /* use_l3 */);
maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(16 * sizeof(float),
false /* use_l3 */);
idx_sorted_by_width_data_cpu.reset(new int[64]); idx_sorted_by_width_data_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
offset_cpu.reset(new int[64]); offset_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
new_offset_cpu.reset(new int[256]); new_offset_cpu.reset(new int[XPU_MAX_LOD_SEQ_LEN]);
} }
void SearchGrnnCompute::prepare_layout(const operators::SearchGrnnParam& param, void SearchGrnnCompute::prepare_layout(const operators::SearchGrnnParam& param,
...@@ -96,10 +99,10 @@ void SearchGrnnCompute::prepare_layout(const operators::SearchGrnnParam& param, ...@@ -96,10 +99,10 @@ void SearchGrnnCompute::prepare_layout(const operators::SearchGrnnParam& param,
layout_input->Resize({dim0, dim1}); layout_input->Resize({dim0, dim1});
} }
xpu_memcpy(idx_sorted_by_width->mutable_data<int>(TARGET(kXPU)), XPU_CALL(xpu_memcpy(idx_sorted_by_width->mutable_data<int>(TARGET(kXPU)),
idx_sorted_by_width_data_cpu.get(), idx_sorted_by_width_data_cpu.get(),
idx_sorted_by_width->numel() * sizeof(int), idx_sorted_by_width->numel() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
} }
void SearchGrnnCompute::Run() { void SearchGrnnCompute::Run() {
...@@ -156,14 +159,14 @@ void SearchGrnnCompute::Run() { ...@@ -156,14 +159,14 @@ void SearchGrnnCompute::Run() {
for (size_t i = 0; i < new_offset.size(); ++i) { for (size_t i = 0; i < new_offset.size(); ++i) {
new_offset_cpu[i] = new_offset[i]; new_offset_cpu[i] = new_offset[i];
} }
xpu_memcpy(offset_xpu, XPU_CALL(xpu_memcpy(offset_xpu,
offset_cpu.get(), offset_cpu.get(),
offset.size() * sizeof(int), offset.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(new_offset_xpu, XPU_CALL(xpu_memcpy(new_offset_xpu,
new_offset_cpu.get(), new_offset_cpu.get(),
new_offset.size() * sizeof(int), new_offset.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int r = xdnn::search_seq2batch(ctx.GetRawContext(), int r = xdnn::search_seq2batch(ctx.GetRawContext(),
batch, batch,
...@@ -200,10 +203,10 @@ void SearchGrnnCompute::Run() { ...@@ -200,10 +203,10 @@ void SearchGrnnCompute::Run() {
0.0f, 0.0f,
0.0f, 0.0f,
0.0f}; 0.0f};
xpu_memcpy(maxs_xpu, XPU_CALL(xpu_memcpy(maxs_xpu,
maxs_cpu, maxs_cpu,
16 * sizeof(float), 16 * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
r = xdnn::findmax<float>( r = xdnn::findmax<float>(
ctx.GetRawContext(), new_emb, cap_l * cap_e, maxs_xpu); ctx.GetRawContext(), new_emb, cap_l * cap_e, maxs_xpu);
CHECK_EQ(r, 0); CHECK_EQ(r, 0);
......
...@@ -37,44 +37,54 @@ void SequenceArithmeticCompute::Run() { ...@@ -37,44 +37,54 @@ void SequenceArithmeticCompute::Run() {
const auto* bottom_data1 = bottom1->data<float>(); const auto* bottom_data1 = bottom1->data<float>();
auto* top_data = top->mutable_data<float>(TARGET(kXPU)); auto* top_data = top->mutable_data<float>(TARGET(kXPU));
int r = 0;
switch (op_type) { switch (op_type) {
case 1: // addition: top[0] = bottom[0] + bottom[1] case 1: // addition: top[0] = bottom[0] + bottom[1]
if (len1 > len2) { if (len1 > len2) {
xdnn::elementwise_add( r = xdnn::elementwise_add(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2); ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2);
xdnn::memcpy_device(ctx.GetRawContext(), CHECK_EQ(r, 0);
&top_data[len2], r = xdnn::memcpy_device(ctx.GetRawContext(),
&bottom_data0[len2], &top_data[len2],
(len1 - len2) * sizeof(float)); &bottom_data0[len2],
(len1 - len2) * sizeof(float));
CHECK_EQ(r, 0);
} else { } else {
xdnn::elementwise_add( r = xdnn::elementwise_add(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1); ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1);
CHECK_EQ(r, 0);
} }
break; break;
case 2: // substraction: top[0] = bottom[0] - bottom[1] case 2: // substraction: top[0] = bottom[0] - bottom[1]
if (len1 > len2) { if (len1 > len2) {
xdnn::elementwise_sub( r = xdnn::elementwise_sub(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2); ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2);
xdnn::memcpy_device(ctx.GetRawContext(), CHECK_EQ(r, 0);
&top_data[len2], r = xdnn::memcpy_device(ctx.GetRawContext(),
&bottom_data0[len2], &top_data[len2],
(len1 - len2) * sizeof(float)); &bottom_data0[len2],
(len1 - len2) * sizeof(float));
CHECK_EQ(r, 0);
} else { } else {
xdnn::elementwise_sub( r = xdnn::elementwise_sub(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1); ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1);
CHECK_EQ(r, 0);
} }
break; break;
case 3: // multiplication: top[0] = bottom[0] * bottom[1] case 3: // multiplication: top[0] = bottom[0] * bottom[1]
if (len1 > len2) { if (len1 > len2) {
xdnn::elementwise_mul( r = xdnn::elementwise_mul(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2); ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2);
xdnn::memcpy_device(ctx.GetRawContext(), CHECK_EQ(r, 0);
&top_data[len2], r = xdnn::memcpy_device(ctx.GetRawContext(),
&bottom_data0[len2], &top_data[len2],
(len1 - len2) * sizeof(float)); &bottom_data0[len2],
(len1 - len2) * sizeof(float));
CHECK_EQ(r, 0);
} else { } else {
xdnn::elementwise_mul( r = xdnn::elementwise_mul(
ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1); ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1);
CHECK_EQ(r, 0);
} }
break; break;
default: default:
......
...@@ -23,11 +23,13 @@ namespace kernels { ...@@ -23,11 +23,13 @@ namespace kernels {
namespace xpu { namespace xpu {
void SequenceConcatCompute::PrepareForRun() { void SequenceConcatCompute::PrepareForRun() {
lod0_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); lod0_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
lod1_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
lod1_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
lod0_cpu.reset(new int[64]); lod0_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
lod1_cpu.reset(new int[64]); lod1_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
} }
template <typename T> template <typename T>
...@@ -106,14 +108,14 @@ void SequenceConcatCompute::Run() { ...@@ -106,14 +108,14 @@ void SequenceConcatCompute::Run() {
for (int i = 0; i < lod1.size(); ++i) { for (int i = 0; i < lod1.size(); ++i) {
lod1_cpu[i] = lod1[i]; lod1_cpu[i] = lod1[i];
} }
xpu_memcpy(lod0_xpu, XPU_CALL(xpu_memcpy(lod0_xpu,
lod0_cpu.get(), lod0_cpu.get(),
lod0.size() * sizeof(int), lod0.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(lod1_xpu, XPU_CALL(xpu_memcpy(lod1_xpu,
lod1_cpu.get(), lod1_cpu.get(),
lod1.size() * sizeof(int), lod1.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int r = xdnn::sequence_concat(ctx.GetRawContext(), int r = xdnn::sequence_concat(ctx.GetRawContext(),
xs[0]->data<float>(), xs[0]->data<float>(),
......
...@@ -23,8 +23,9 @@ namespace kernels { ...@@ -23,8 +23,9 @@ namespace kernels {
namespace xpu { namespace xpu {
void XPUSequencePoolCompute::PrepareForRun() { void XPUSequencePoolCompute::PrepareForRun() {
lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
lod_cpu.reset(new int[64]); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
lod_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
} }
void XPUSequencePoolCompute::Run() { void XPUSequencePoolCompute::Run() {
...@@ -55,10 +56,10 @@ void XPUSequencePoolCompute::Run() { ...@@ -55,10 +56,10 @@ void XPUSequencePoolCompute::Run() {
lod_cpu[i] = in_lod[i]; lod_cpu[i] = in_lod[i];
} }
int* lod_xpu = reinterpret_cast<int*>(lod_xpu_guard_->addr_); int* lod_xpu = reinterpret_cast<int*>(lod_xpu_guard_->addr_);
xpu_memcpy(lod_xpu, XPU_CALL(xpu_memcpy(lod_xpu,
lod_cpu.get(), lod_cpu.get(),
in_lod.size() * sizeof(int), in_lod.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int r = int r =
xdnn::sequence_pooling_forward(ctx.GetRawContext(), xdnn::sequence_pooling_forward(ctx.GetRawContext(),
......
...@@ -23,8 +23,9 @@ namespace xpu { ...@@ -23,8 +23,9 @@ namespace xpu {
template <typename T, PrecisionType PType> template <typename T, PrecisionType PType>
void SequenceReverseCompute<T, PType>::PrepareForRun() { void SequenceReverseCompute<T, PType>::PrepareForRun() {
lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
lod_cpu.reset(new int[64]); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
lod_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
} }
template <typename T, PrecisionType PType> template <typename T, PrecisionType PType>
...@@ -58,10 +59,10 @@ void SequenceReverseCompute<T, PType>::Run() { ...@@ -58,10 +59,10 @@ void SequenceReverseCompute<T, PType>::Run() {
lod_cpu[i] = lod[i]; lod_cpu[i] = lod[i];
} }
int* lod_xpu = reinterpret_cast<int*>(lod_xpu_guard_->addr_); int* lod_xpu = reinterpret_cast<int*>(lod_xpu_guard_->addr_);
xpu_memcpy(lod_xpu, XPU_CALL(xpu_memcpy(lod_xpu,
lod_cpu.get(), lod_cpu.get(),
lod.size() * sizeof(int), lod.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int r = xdnn::sequence_reverse(ctx.GetRawContext(), int r = xdnn::sequence_reverse(ctx.GetRawContext(),
batch_size, batch_size,
......
...@@ -23,10 +23,11 @@ namespace kernels { ...@@ -23,10 +23,11 @@ namespace kernels {
namespace xpu { namespace xpu {
void SequenceTopkAvgPoolingCompute::PrepareForRun() { void SequenceTopkAvgPoolingCompute::PrepareForRun() {
lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(256 * sizeof(int)); lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
in_lod_cpu.reset(new int[64]); 4 * XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
row_lod_cpu.reset(new int[64]); in_lod_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
col_lod_cpu.reset(new int[64]); row_lod_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
col_lod_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
} }
void SequenceTopkAvgPoolingCompute::Run() { void SequenceTopkAvgPoolingCompute::Run() {
...@@ -81,22 +82,22 @@ void SequenceTopkAvgPoolingCompute::Run() { ...@@ -81,22 +82,22 @@ void SequenceTopkAvgPoolingCompute::Run() {
for (int i = 0; i < col_lod.size(); ++i) { for (int i = 0; i < col_lod.size(); ++i) {
col_lod_cpu[i] = col_lod[i]; col_lod_cpu[i] = col_lod[i];
} }
xpu_memcpy(in_lod_xpu, XPU_CALL(xpu_memcpy(in_lod_xpu,
in_lod_cpu.get(), in_lod_cpu.get(),
in_lod.size() * sizeof(int), in_lod.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(row_lod_xpu, XPU_CALL(xpu_memcpy(row_lod_xpu,
row_lod_cpu.get(), row_lod_cpu.get(),
row_lod.size() * sizeof(int), row_lod.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(col_lod_xpu, XPU_CALL(xpu_memcpy(col_lod_xpu,
col_lod_cpu.get(), col_lod_cpu.get(),
col_lod.size() * sizeof(int), col_lod.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(topks_xpu, XPU_CALL(xpu_memcpy(topks_xpu,
topks.data(), topks.data(),
topks.size() * sizeof(int), topks.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int r = xdnn::sequence_topk_avg_pooling(ctx.GetRawContext(), int r = xdnn::sequence_topk_avg_pooling(ctx.GetRawContext(),
in_data, in_data,
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
namespace paddle { namespace paddle {
......
...@@ -25,9 +25,8 @@ void StackCompute::PrepareForRun() { ...@@ -25,9 +25,8 @@ void StackCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
int n = param.X.size(); int n = param.X.size();
void* x_ptr = nullptr; x_ptr_guard_ = TargetWrapperXPU::MallocScratchPad(
xpu_malloc(&x_ptr, n * 8 /* sizeof(__global__ float*) */); n * 8 /* sizeof(__global__ float*) */, false /* use_l3 */);
x_ptr_guard_.reset(x_ptr);
x_ptr_cpu_.reserve(n); x_ptr_cpu_.reserve(n);
} }
...@@ -47,14 +46,15 @@ void StackCompute::Run() { ...@@ -47,14 +46,15 @@ void StackCompute::Run() {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
x_ptr_cpu_[i] = param.X[i]->data<float>(); x_ptr_cpu_[i] = param.X[i]->data<float>();
} }
xpu_memcpy(x_ptr_guard_.get(), &x_ptr_cpu_[0], n * 8, XPU_HOST_TO_DEVICE); XPU_CALL(xpu_memcpy(
x_ptr_guard_->addr_, &x_ptr_cpu_[0], n * 8, XPU_HOST_TO_DEVICE));
int r = xdnn::stack_forward( int r = xdnn::stack_forward(
ctx.GetRawContext(), /* context */ ctx.GetRawContext(), /* context */
height, /* height */ height, /* height */
width, /* width */ width, /* width */
n, /* n */ n, /* n */
x_ptr_guard_.get(), /* x_ptr */ x_ptr_guard_->addr_, /* x_ptr */
param.Out->mutable_data<float>(TARGET(kXPU)) /* out */); param.Out->mutable_data<float>(TARGET(kXPU)) /* out */);
CHECK_EQ(r, 0); CHECK_EQ(r, 0);
} }
......
...@@ -14,10 +14,9 @@ ...@@ -14,10 +14,9 @@
#pragma once #pragma once
#include <memory>
#include <vector> #include <vector>
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
#include "lite/kernels/xpu/utils.h" // XPUFreeDeleter
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -35,7 +34,7 @@ class StackCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> { ...@@ -35,7 +34,7 @@ class StackCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
virtual ~StackCompute() = default; virtual ~StackCompute() = default;
private: private:
std::unique_ptr<void, XPUFreeDeleter> x_ptr_guard_; XPUScratchPadGuard x_ptr_guard_;
std::vector<const float*> x_ptr_cpu_; std::vector<const float*> x_ptr_cpu_;
}; };
......
...@@ -53,10 +53,11 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -53,10 +53,11 @@ bool SubgraphEngine::BuildDeviceProgram() {
// IR graph // IR graph
subgraph::xpu::Graph graph; subgraph::xpu::Graph graph;
const auto& bridges = subgraph::Registry::Instance(); const auto& bridges = subgraph::Registry::Instance();
if (origin_program_.empty()) { if (!origin_program_) {
BuildOriginProgram(); BuildOriginProgram();
} }
for (auto& inst : origin_program_) { const auto& insts = origin_program_->instructions(kRootBlockIdx);
for (auto& inst : insts) {
auto op = const_cast<OpLite*>(inst.op()); auto op = const_cast<OpLite*>(inst.op());
CHECK(op); CHECK(op);
op->CheckShape(); op->CheckShape();
...@@ -123,7 +124,7 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -123,7 +124,7 @@ bool SubgraphEngine::BuildDeviceProgram() {
auto node = graph.Get(device_inames_[i]); auto node = graph.Get(device_inames_[i]);
auto precision = node->precision(); auto precision = node->precision();
auto layout = node->layout(); auto layout = node->layout();
origin_itensors_[i] = scope_->FindMutableTensor(device_inames_[i]); origin_itensors_[i] = exec_scope_->FindMutableTensor(device_inames_[i]);
CHECK(origin_itensors_[i]); CHECK(origin_itensors_[i]);
origin_idims_[i] = origin_itensors_[i]->dims(); origin_idims_[i] = origin_itensors_[i]->dims();
VLOG(3) << "[XPU] Inputs[" << i << "] name: " << device_inames_[i] VLOG(3) << "[XPU] Inputs[" << i << "] name: " << device_inames_[i]
...@@ -147,7 +148,7 @@ bool SubgraphEngine::BuildDeviceProgram() { ...@@ -147,7 +148,7 @@ bool SubgraphEngine::BuildDeviceProgram() {
auto node = graph.Get(device_onames_[i]); auto node = graph.Get(device_onames_[i]);
auto precision = node->precision(); auto precision = node->precision();
auto layout = node->layout(); auto layout = node->layout();
origin_otensors_[i] = scope_->FindMutableTensor(device_onames_[i]); origin_otensors_[i] = exec_scope_->FindMutableTensor(device_onames_[i]);
CHECK(origin_otensors_[i]); CHECK(origin_otensors_[i]);
origin_odims_[i] = origin_otensors_[i]->dims(); origin_odims_[i] = origin_otensors_[i]->dims();
VLOG(3) << "[XPU] Outputs[" << i << "] name: " << device_onames_[i] VLOG(3) << "[XPU] Outputs[" << i << "] name: " << device_onames_[i]
...@@ -220,11 +221,11 @@ bool SubgraphEngine::LaunchDeviceProgram() { ...@@ -220,11 +221,11 @@ bool SubgraphEngine::LaunchDeviceProgram() {
void SubgraphCompute::PrepareForRun() { void SubgraphCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
engine_.reset(new SubgraphEngine(ctx_.get(), engine_.reset(new SubgraphEngine(ctx_.get(),
param.sub_block_idx, param.block_idx,
param.sub_block_desc, param.program_desc,
param.exec_scope,
param.input_data_names, param.input_data_names,
param.output_data_names, param.output_data_names));
param.scope));
CHECK(engine_); CHECK(engine_);
} }
......
...@@ -31,12 +31,16 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -31,12 +31,16 @@ class SubgraphEngine : public subgraph::Engine {
public: public:
SubgraphEngine(KernelContext *ctx, SubgraphEngine(KernelContext *ctx,
int block_idx, int block_idx,
cpp::BlockDesc *block_desc, const std::shared_ptr<const cpp::ProgramDesc> &program_desc,
Scope *exec_scope,
const std::vector<std::string> &input_names, const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names, const std::vector<std::string> &output_names)
Scope *scope) : subgraph::Engine(ctx,
: subgraph::Engine( block_idx,
ctx, block_idx, block_desc, input_names, output_names, scope) {} program_desc,
exec_scope,
input_names,
output_names) {}
protected: protected:
bool PrepareWorkspaceForDeviceProgram() override; bool PrepareWorkspaceForDeviceProgram() override;
......
...@@ -23,10 +23,12 @@ namespace kernels { ...@@ -23,10 +23,12 @@ namespace kernels {
namespace xpu { namespace xpu {
void VarConv2DCompute::PrepareForRun() { void VarConv2DCompute::PrepareForRun() {
offset_x_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); offset_x_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
offset_y_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
offset_x_cpu.reset(new int[64]); offset_y_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(
offset_y_cpu.reset(new int[64]); XPU_MAX_LOD_SIZE * sizeof(int), false /* use_l3 */);
offset_x_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
offset_y_cpu.reset(new int[XPU_MAX_LOD_SIZE]);
} }
void VarConv2DCompute::Run() { void VarConv2DCompute::Run() {
...@@ -94,14 +96,14 @@ void VarConv2DCompute::Run() { ...@@ -94,14 +96,14 @@ void VarConv2DCompute::Run() {
offset_x_cpu[i] = offset_x[i]; offset_x_cpu[i] = offset_x[i];
offset_y_cpu[i] = offset_y[i]; offset_y_cpu[i] = offset_y[i];
} }
xpu_memcpy(offset_x_xpu, XPU_CALL(xpu_memcpy(offset_x_xpu,
offset_x_cpu.get(), offset_x_cpu.get(),
(batch + 1) * sizeof(int), (batch + 1) * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
xpu_memcpy(offset_y_xpu, XPU_CALL(xpu_memcpy(offset_y_xpu,
offset_y_cpu.get(), offset_y_cpu.get(),
(batch + 1) * sizeof(int), (batch + 1) * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE); XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int r = xdnn::search_varconv<float, int16_t>(ctx.GetRawContext(), int r = xdnn::search_varconv<float, int16_t>(ctx.GetRawContext(),
batch, batch,
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "lite/model_parser/base/block_desc.h" #include "lite/model_parser/base/block_desc.h"
#include "lite/model_parser/base/op_desc.h" #include "lite/model_parser/base/op_desc.h"
#include "lite/model_parser/base/program_desc.h" #include "lite/model_parser/base/program_desc.h"
#include "lite/model_parser/base/proto_desc.h"
#include "lite/model_parser/base/traits.h" #include "lite/model_parser/base/traits.h"
#include "lite/model_parser/base/var_desc.h" #include "lite/model_parser/base/var_desc.h"
#include "lite/utils/all.h" #include "lite/utils/all.h"
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace lite {
// The Index of first Block in Program. also called root block.
constexpr int kRootBlockIdx = 0;
// The Parent Index of root Block, this block does not exist.
constexpr int kNoneBlockIdx = -1;
} // namespace lite
} // namespace paddle
...@@ -57,21 +57,35 @@ class VectorView { ...@@ -57,21 +57,35 @@ class VectorView {
public: public:
typedef vector_view::VectorTraits<T, U> Traits; typedef vector_view::VectorTraits<T, U> Traits;
explicit VectorView(typename Traits::vector_type const* cvec) { explicit VectorView(typename Traits::vector_type const* cvec) {
CHECK(cvec);
cvec_ = cvec; cvec_ = cvec;
} }
typename Traits::subscript_return_type operator[](size_t i) const { typename Traits::subscript_return_type operator[](size_t i) const {
return cvec_->operator[](i); return cvec_->operator[](i);
} }
typename Traits::const_iterator begin() const { return cvec_->begin(); } typename Traits::const_iterator begin() const {
typename Traits::const_iterator end() const { return cvec_->end(); } if (!cvec_) {
size_t size() const { return cvec_->size(); } return typename Traits::const_iterator();
}
return cvec_->begin();
}
typename Traits::const_iterator end() const {
if (!cvec_) {
return typename Traits::const_iterator();
}
return cvec_->end();
}
size_t size() const {
if (!cvec_) {
return 0;
}
return cvec_->size();
}
operator std::vector<T>() const { operator std::vector<T>() const {
VLOG(5) << "Copying elements out of VectorView will damage performance."; VLOG(5) << "Copying elements out of VectorView will damage performance.";
std::vector<T> tmp; std::vector<T> tmp;
tmp.reserve(cvec_->size()); tmp.reserve(size());
for (auto val : *cvec_) { for (size_t i = 0; i < size(); ++i) {
tmp.push_back(val); tmp.push_back(cvec_->operator[](i));
} }
return tmp; return tmp;
} }
......
...@@ -234,7 +234,7 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) { ...@@ -234,7 +234,7 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) {
template <> \ template <> \
void TransformBlockDescCppToAny<NT::T>(const cpp::T &cpp_desc, \ void TransformBlockDescCppToAny<NT::T>(const cpp::T &cpp_desc, \
NT::T *any_desc) { \ NT::T *any_desc) { \
auto desc = cpp_desc; \ const cpp::T &desc = cpp_desc; \
any_desc->SetIdx(desc.Idx()); \ any_desc->SetIdx(desc.Idx()); \
any_desc->SetParentIdx(desc.ParentIdx()); \ any_desc->SetParentIdx(desc.ParentIdx()); \
any_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \ any_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \
......
...@@ -15,20 +15,21 @@ ...@@ -15,20 +15,21 @@
#include "lite/model_parser/flatbuffers/io.h" #include "lite/model_parser/flatbuffers/io.h"
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector>
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace fbs { namespace fbs {
void LoadModel(const std::string& path, ProgramDesc* prog) { void LoadModel(const std::string& path, ProgramDesc* prog) {
CHECK(prog);
FILE* file = fopen(path.c_str(), "rb"); FILE* file = fopen(path.c_str(), "rb");
fseek(file, 0, SEEK_END); fseek(file, 0, SEEK_END);
int64_t size = ftell(file); int64_t length = ftell(file);
rewind(file); rewind(file);
char* data = new char[size]; std::vector<char> buf(length);
size = fread(data, 1, size, file); CHECK(fread(buf.data(), 1, length, file));
fclose(file); fclose(file);
std::unique_ptr<char[]> buf(data);
prog->Init(std::move(buf)); prog->Init(std::move(buf));
} }
......
...@@ -62,7 +62,7 @@ class OpDesc : public OpDescAPI { ...@@ -62,7 +62,7 @@ class OpDesc : public OpDescAPI {
std::vector<std::string> Output(const std::string& param) const override { std::vector<std::string> Output(const std::string& param) const override {
const auto& var = desc_->outputs()->LookupByKey(param.c_str()); const auto& var = desc_->outputs()->LookupByKey(param.c_str());
std::vector<std::string> args_vec; std::vector<std::string> args_vec;
if (var->arguments()) { if (var && var->arguments()) {
args_vec.reserve(var->arguments()->size()); args_vec.reserve(var->arguments()->size());
for (const auto& out : *var->arguments()) { for (const auto& out : *var->arguments()) {
args_vec.push_back(out->str()); args_vec.push_back(out->str());
...@@ -169,8 +169,7 @@ class OpDesc : public OpDescAPI { ...@@ -169,8 +169,7 @@ class OpDesc : public OpDescAPI {
} }
bool HasOutput(const std::string& param) const { bool HasOutput(const std::string& param) const {
NotImplemented(); return !Output(param).empty();
return false;
} }
const std::map<std::string, Any>& attrs() const { const std::map<std::string, Any>& attrs() const {
......
...@@ -29,16 +29,25 @@ namespace fbs { ...@@ -29,16 +29,25 @@ namespace fbs {
class ProgramDesc : public ProgramDescAPI { class ProgramDesc : public ProgramDescAPI {
public: public:
ProgramDesc() = default; ProgramDesc() = default;
explicit ProgramDesc(std::unique_ptr<const char[]> buf) { explicit ProgramDesc(const std::vector<char>& buf) { Init(buf); }
Init(std::move(buf)); explicit ProgramDesc(std::vector<char>&& buf) {
Init(std::forward<std::vector<char>>(buf));
} }
size_t BlocksSize() const override { return desc_->blocks()->size(); } void Init(const std::vector<char>& buf) {
CHECK(buf.data());
buf_ = buf;
InitProgramDesc();
}
void Init(std::unique_ptr<const char[]> buf) { void Init(std::vector<char>&& buf) {
CHECK(buf.get() != nullptr); CHECK(buf.data());
buf_ = std::move(buf); buf_ = std::move(buf);
desc_ = proto::GetProgramDesc(buf_.get()); InitProgramDesc();
}
void InitProgramDesc() {
desc_ = proto::GetProgramDesc(buf_.data());
blocks_.reserve(BlocksSize()); blocks_.reserve(BlocksSize());
for (size_t idx = 0; idx < BlocksSize(); ++idx) { for (size_t idx = 0; idx < BlocksSize(); ++idx) {
blocks_.push_back(BlockDesc(desc_->blocks()->Get(idx))); blocks_.push_back(BlockDesc(desc_->blocks()->Get(idx)));
...@@ -46,12 +55,12 @@ class ProgramDesc : public ProgramDescAPI { ...@@ -46,12 +55,12 @@ class ProgramDesc : public ProgramDescAPI {
} }
void CopyFrom(const ProgramDesc& other) { void CopyFrom(const ProgramDesc& other) {
size_t length = strlen(static_cast<const char*>(other.raw_buf())); buf_ = other.buf();
std::unique_ptr<char[]> buf(new char[length]); Init(buf_);
memcpy(buf.get(), other.raw_buf(), length);
Init(std::move(buf));
} }
size_t BlocksSize() const override { return desc_->blocks()->size(); }
template <typename T> template <typename T>
T const* GetBlock(int32_t idx) const; T const* GetBlock(int32_t idx) const;
...@@ -72,11 +81,11 @@ class ProgramDesc : public ProgramDescAPI { ...@@ -72,11 +81,11 @@ class ProgramDesc : public ProgramDescAPI {
proto::ProgramDesc const* raw_desc() const { return desc_; } proto::ProgramDesc const* raw_desc() const { return desc_; }
const void* raw_buf() const { return buf_.get(); } const std::vector<char>& buf() const { return buf_; }
private: private:
proto::ProgramDesc const* desc_; proto::ProgramDesc const* desc_;
std::unique_ptr<const char[]> buf_; std::vector<char> buf_;
std::vector<BlockDesc> blocks_; std::vector<BlockDesc> blocks_;
private: private:
......
...@@ -51,6 +51,7 @@ struct FBSStrIterator { ...@@ -51,6 +51,7 @@ struct FBSStrIterator {
flatbuffers::Offset<flatbuffers::String>>::return_type> flatbuffers::Offset<flatbuffers::String>>::return_type>
VI; VI;
FBSStrIterator() = default;
explicit FBSStrIterator(const VI& iter) { iter_ = iter; } explicit FBSStrIterator(const VI& iter) { iter_ = iter; }
const VI& raw_iter() const { return iter_; } const VI& raw_iter() const { return iter_; }
...@@ -104,20 +105,21 @@ class VectorView<std::string, Flatbuffers> { ...@@ -104,20 +105,21 @@ class VectorView<std::string, Flatbuffers> {
explicit VectorView(typename Traits::vector_type const* cvec) { explicit VectorView(typename Traits::vector_type const* cvec) {
cvec_ = cvec; cvec_ = cvec;
} }
std::string operator[](size_t i) const { std::string operator[](size_t i) const { return cvec_->operator[](i)->str(); }
CHECK(cvec_);
return cvec_->operator[](i)->str();
}
vector_view::FBSStrIterator begin() const { vector_view::FBSStrIterator begin() const {
CHECK(cvec_); if (!cvec_) {
return vector_view::FBSStrIterator();
}
return vector_view::FBSStrIterator(cvec_->begin()); return vector_view::FBSStrIterator(cvec_->begin());
} }
vector_view::FBSStrIterator end() const { vector_view::FBSStrIterator end() const {
CHECK(cvec_); if (!cvec_) {
return vector_view::FBSStrIterator();
}
return vector_view::FBSStrIterator(cvec_->end()); return vector_view::FBSStrIterator(cvec_->end());
} }
size_t size() const { size_t size() const {
if (cvec_ == nullptr) { if (!cvec_) {
return 0; return 0;
} }
return cvec_->size(); return cvec_->size();
...@@ -126,10 +128,8 @@ class VectorView<std::string, Flatbuffers> { ...@@ -126,10 +128,8 @@ class VectorView<std::string, Flatbuffers> {
VLOG(5) << "Copying elements out of VectorView will damage performance."; VLOG(5) << "Copying elements out of VectorView will damage performance.";
std::vector<std::string> tmp; std::vector<std::string> tmp;
tmp.reserve(size()); tmp.reserve(size());
if (cvec_ != nullptr) { for (size_t i = 0; i < size(); ++i) {
for (auto val : *cvec_) { tmp.push_back(cvec_->operator[](i)->str());
tmp.push_back(val->str());
}
} }
return tmp; return tmp;
} }
......
...@@ -3,4 +3,4 @@ lite_cc_library(cpp_var_desc SRCS var_desc.cc) ...@@ -3,4 +3,4 @@ lite_cc_library(cpp_var_desc SRCS var_desc.cc)
lite_cc_library(cpp_block_desc SRCS block_desc.cc) lite_cc_library(cpp_block_desc SRCS block_desc.cc)
lite_cc_library(cpp_program_desc SRCS program_desc.cc) lite_cc_library(cpp_program_desc SRCS program_desc.cc)
set(cpp_wrapper cpp_op_desc cpp_var_desc cpp_block_desc cpp_program_desc PARENT_SCOPE) set(cpp_wrapper cpp_program_desc cpp_block_desc cpp_var_desc cpp_op_desc PARENT_SCOPE)
...@@ -294,9 +294,9 @@ const proto::VarType::TensorDesc &VarDesc::tensor_desc() const { ...@@ -294,9 +294,9 @@ const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
case proto::VarType::LOD_TENSOR_ARRAY: case proto::VarType::LOD_TENSOR_ARRAY:
return desc_->type().tensor_array().tensor(); return desc_->type().tensor_array().tensor();
default: default:
LOG(FATAL) LOG(WARNING) << "Getting 'tensor_desc' is not supported by the type("
<< "Getting 'tensor_desc' is not supported by the type of var %s." << static_cast<int>(desc_->type().type()) << ") of var "
<< this->Name(); << this->Name();
} }
return framework::proto::VarDesc().type().lod_tensor().tensor(); return framework::proto::VarDesc().type().lod_tensor().tensor();
} }
...@@ -312,10 +312,9 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const { ...@@ -312,10 +312,9 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
} }
return res; return res;
default: default:
LOG(FATAL) LOG(WARNING) << "Getting 'tensor_descs' is not supported by the type("
<< "Getting 'tensor_descs' is not supported by the type of var " << static_cast<int>(desc_->type().type()) << ") of var "
"%s." << this->Name();
<< this->Name();
} }
return std::vector<proto::VarType::TensorDesc>(); return std::vector<proto::VarType::TensorDesc>();
} }
......
...@@ -115,6 +115,7 @@ add_operator(ctc_align_op_lite extra SRCS ctc_align_op.cc DEPS ${op_DEPS}) ...@@ -115,6 +115,7 @@ add_operator(ctc_align_op_lite extra SRCS ctc_align_op.cc DEPS ${op_DEPS})
add_operator(max_pool_with_index_op extra SRCS max_pool_with_index_op.cc DEPS ${op_DEPS}) add_operator(max_pool_with_index_op extra SRCS max_pool_with_index_op.cc DEPS ${op_DEPS})
add_operator(pixel_shuffle_op extra SRCS pixel_shuffle_op.cc DEPS ${op_DEPS}) add_operator(pixel_shuffle_op extra SRCS pixel_shuffle_op.cc DEPS ${op_DEPS})
add_operator(clip_op extra SRCS clip_op.cc DEPS ${op_DEPS}) add_operator(clip_op extra SRCS clip_op.cc DEPS ${op_DEPS})
add_operator(print_op extra SRCS print_op.cc DEPS ${op_DEPS})
# for OCR specific # for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
...@@ -88,6 +88,78 @@ bool XPUMmdnnBidEmbGrnnAttOp::AttachImpl(const cpp::OpDesc& op_desc, ...@@ -88,6 +88,78 @@ bool XPUMmdnnBidEmbGrnnAttOp::AttachImpl(const cpp::OpDesc& op_desc,
return true; return true;
} }
bool XPUMmdnnBidEmbGrnnAttOp2::CheckShape() const { return true; }
bool XPUMmdnnBidEmbGrnnAttOp2::InferShapeImpl() const {
auto& id_dims = param_.id0->dims();
auto& id_lod = param_.id0->lod()[0];
auto& emb_tbl_dims = param_.emb_tbl->dims();
auto& grnn_wh_dims = param_.grnn_rv_wh->dims();
param_.emb0_out->Resize({id_dims[0], emb_tbl_dims[1]});
param_.emb0_out->set_lod({id_lod});
param_.grnn_fw_pool_out->Resize(
{(int64_t)id_lod.size() - 1, grnn_wh_dims[2]});
param_.grnn_rv_pool_out->Resize(
{(int64_t)id_lod.size() - 1, grnn_wh_dims[2]});
param_.att_pool_out->Resize(
{(int64_t)id_lod.size() - 1, 2 * grnn_wh_dims[2]});
param_.concat_3in1_out->Resize({id_dims[0], 3 * grnn_wh_dims[2]});
param_.concat_3in1_out->set_lod({id_lod});
param_.emb_fw_out->Resize({id_dims[0], emb_tbl_dims[1]});
param_.emb_fw_out->set_lod({id_lod});
return true;
}
bool XPUMmdnnBidEmbGrnnAttOp2::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
param_.id0 =
scope->FindVar(op_desc.Input("id0").front())->GetMutable<lite::Tensor>();
param_.id1 =
scope->FindVar(op_desc.Input("id1").front())->GetMutable<lite::Tensor>();
param_.emb_tbl = scope->FindVar(op_desc.Input("emb_tbl").front())
->GetMutable<lite::Tensor>();
param_.grnn_fw_wh = scope->FindVar(op_desc.Input("grnn_fw_wh").front())
->GetMutable<lite::Tensor>();
param_.grnn_fw_wi = scope->FindVar(op_desc.Input("grnn_fw_wi").front())
->GetMutable<lite::Tensor>();
param_.grnn_rv_wh = scope->FindVar(op_desc.Input("grnn_rv_wh").front())
->GetMutable<lite::Tensor>();
param_.grnn_rv_wi = scope->FindVar(op_desc.Input("grnn_rv_wi").front())
->GetMutable<lite::Tensor>();
param_.att_fc_w = scope->FindVar(op_desc.Input("att_fc_w").front())
->GetMutable<lite::Tensor>();
param_.att_fc_b = scope->FindVar(op_desc.Input("att_fc_b").front())
->GetMutable<lite::Tensor>();
param_.emb0_out = scope->FindVar(op_desc.Output("emb0_out").front())
->GetMutable<lite::Tensor>();
param_.grnn_fw_pool_out =
scope->FindVar(op_desc.Output("grnn_fw_pool_out").front())
->GetMutable<lite::Tensor>();
param_.grnn_rv_pool_out =
scope->FindVar(op_desc.Output("grnn_rv_pool_out").front())
->GetMutable<lite::Tensor>();
param_.att_pool_out = scope->FindVar(op_desc.Output("att_pool_out").front())
->GetMutable<lite::Tensor>();
param_.concat_3in1_out =
scope->FindVar(op_desc.Output("concat_3in1_out").front())
->GetMutable<lite::Tensor>();
param_.emb_fw_out = scope->FindVar(op_desc.Output("emb_fw_out").front())
->GetMutable<lite::Tensor>();
param_.grnn_fw_wh_maxs =
op_desc.GetAttr<std::vector<float>>("grnn_fw_wh_maxs");
param_.grnn_fw_wi_maxs =
op_desc.GetAttr<std::vector<float>>("grnn_fw_wi_maxs");
param_.grnn_rv_wh_maxs =
op_desc.GetAttr<std::vector<float>>("grnn_rv_wh_maxs");
param_.grnn_rv_wi_maxs =
op_desc.GetAttr<std::vector<float>>("grnn_rv_wi_maxs");
param_.att_fc_w_max = op_desc.GetAttr<float>("att_fc_w_max");
return true;
}
bool XPUMmdnnBidEmbAttOp::CheckShape() const { return true; } bool XPUMmdnnBidEmbAttOp::CheckShape() const { return true; }
bool XPUMmdnnBidEmbAttOp::InferShapeImpl() const { bool XPUMmdnnBidEmbAttOp::InferShapeImpl() const {
...@@ -157,6 +229,7 @@ bool XPUMmdnnMatchConvTopkOp::AttachImpl(const cpp::OpDesc& op_desc, ...@@ -157,6 +229,7 @@ bool XPUMmdnnMatchConvTopkOp::AttachImpl(const cpp::OpDesc& op_desc,
param_.input_w_max = op_desc.GetAttr<float>("input_w_max"); param_.input_w_max = op_desc.GetAttr<float>("input_w_max");
param_.conv_w_max = op_desc.GetAttr<float>("conv_w_max"); param_.conv_w_max = op_desc.GetAttr<float>("conv_w_max");
param_.topks = op_desc.GetAttr<std::vector<int>>("topks"); param_.topks = op_desc.GetAttr<std::vector<int>>("topks");
param_.output_channel = op_desc.GetAttr<int>("output_channel");
param_.channel_num = op_desc.GetAttr<int>("channel_num"); param_.channel_num = op_desc.GetAttr<int>("channel_num");
param_.dim_t = op_desc.GetAttr<int>("dim_t"); param_.dim_t = op_desc.GetAttr<int>("dim_t");
return true; return true;
...@@ -182,10 +255,10 @@ bool XPUMmdnnMergeAllOp::AttachImpl(const cpp::OpDesc& op_desc, ...@@ -182,10 +255,10 @@ bool XPUMmdnnMergeAllOp::AttachImpl(const cpp::OpDesc& op_desc,
auto t = scope->FindVar(name)->GetMutable<lite::Tensor>(); auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
param_.concat_7in1_x.push_back(t); param_.concat_7in1_x.push_back(t);
} }
param_.concat_2in1_x.clear(); param_.concat_topk_x.clear();
for (auto& name : op_desc.Input("concat_2in1_x")) { for (auto& name : op_desc.Input("concat_topk_x")) {
auto t = scope->FindVar(name)->GetMutable<lite::Tensor>(); auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
param_.concat_2in1_x.push_back(t); param_.concat_topk_x.push_back(t);
} }
param_.grnn_fw_wh = scope->FindVar(op_desc.Input("grnn_fw_wh").front()) param_.grnn_fw_wh = scope->FindVar(op_desc.Input("grnn_fw_wh").front())
->GetMutable<lite::Tensor>(); ->GetMutable<lite::Tensor>();
...@@ -231,6 +304,8 @@ bool XPUMmdnnMergeAllOp::AttachImpl(const cpp::OpDesc& op_desc, ...@@ -231,6 +304,8 @@ bool XPUMmdnnMergeAllOp::AttachImpl(const cpp::OpDesc& op_desc,
REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_grnn_att, REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_grnn_att,
paddle::lite::operators::XPUMmdnnBidEmbGrnnAttOp); paddle::lite::operators::XPUMmdnnBidEmbGrnnAttOp);
REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_grnn_att2,
paddle::lite::operators::XPUMmdnnBidEmbGrnnAttOp2);
REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_att, REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_att,
paddle::lite::operators::XPUMmdnnBidEmbAttOp); paddle::lite::operators::XPUMmdnnBidEmbAttOp);
REGISTER_LITE_OP(__xpu__mmdnn_match_conv_topk, REGISTER_LITE_OP(__xpu__mmdnn_match_conv_topk,
......
...@@ -41,6 +41,29 @@ class XPUMmdnnBidEmbGrnnAttOp : public OpLite { ...@@ -41,6 +41,29 @@ class XPUMmdnnBidEmbGrnnAttOp : public OpLite {
mutable XPUMmdnnBidEmbGrnnAttParam param_; mutable XPUMmdnnBidEmbGrnnAttParam param_;
}; };
class XPUMmdnnBidEmbGrnnAttOp2 : public OpLite {
public:
XPUMmdnnBidEmbGrnnAttOp2() {}
explicit XPUMmdnnBidEmbGrnnAttOp2(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "XPUMmdnnBidEmbGrnnAttOp2";
}
private:
mutable XPUMmdnnBidEmbGrnnAttParam2 param_;
};
class XPUMmdnnBidEmbAttOp : public OpLite { class XPUMmdnnBidEmbAttOp : public OpLite {
public: public:
XPUMmdnnBidEmbAttOp() {} XPUMmdnnBidEmbAttOp() {}
......
...@@ -21,15 +21,15 @@ namespace lite { ...@@ -21,15 +21,15 @@ namespace lite {
namespace operators { namespace operators {
bool AssignOpLite::CheckShape() const { bool AssignOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.X); CHECK_OR_FALSE(param_.X || param_.X_array);
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out || param_.Out_array);
return true; return true;
} }
bool AssignOpLite::InferShapeImpl() const { bool AssignOpLite::InferShapeImpl() const {
if (param_.X != nullptr) { if (param_.X) {
param_.Out->Resize(param_.X->dims()); param_.Out->Resize(param_.X->dims());
} else if (param_.X_array != nullptr) { } else if (param_.X_array) {
param_.Out_array->resize(param_.Out_array->size()); param_.Out_array->resize(param_.Out_array->size());
} else { } else {
LOG(FATAL) << "x or x_array must be set."; LOG(FATAL) << "x or x_array must be set.";
......
...@@ -20,35 +20,37 @@ namespace paddle { ...@@ -20,35 +20,37 @@ namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
bool ConditionalBlockOpLite::CheckShape() const { bool ConditionalBlockOp::CheckShape() const {
CHECK_OR_FALSE(param_.cond); CHECK_OR_FALSE(param_.cond);
CHECK_OR_FALSE(param_.sub_block); CHECK_OR_FALSE(param_.program_desc);
CHECK_OR_FALSE(param_.scope); CHECK_OR_FALSE(param_.exec_scope);
return true; return true;
} }
bool ConditionalBlockOpLite::InferShapeImpl() const { return true; } bool ConditionalBlockOp::InferShapeImpl() const { return true; }
bool ConditionalBlockOpLite::AttachImpl(const cpp::OpDesc &op_desc, bool ConditionalBlockOp::AttachImpl(const cpp::OpDesc& op_desc, Scope* scope) {
lite::Scope *scope) {
auto condition = op_desc.Input("Cond").front(); auto condition = op_desc.Input("Cond").front();
param_.cond = scope->FindVar(condition)->GetMutable<lite::Tensor>(); param_.cond = scope->FindVar(condition)->GetMutable<lite::Tensor>();
auto inputs = op_desc.Input("Input"); auto inputs = op_desc.Input("Input");
for (auto var : inputs) { for (const auto& input : inputs) {
param_.x.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>()); auto* var = scope->FindVar(input);
CHECK(var);
param_.inputs.push_back(var->GetMutable<lite::Tensor>());
} }
auto outs = op_desc.Output("Out"); auto outs = op_desc.Output("Out");
for (auto var : outs) { for (const auto& out : outs) {
param_.outs.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>()); auto* var = scope->FindVar(out);
CHECK(var);
param_.outs.push_back(var->GetMutable<lite::Tensor>());
} }
param_.is_scalar_condition = op_desc.GetAttr<bool>("is_scalar_condition"); param_.is_scalar_condition = op_desc.GetAttr<bool>("is_scalar_condition");
// obtain sub_block in core program.cc // obtain sub_block in core program.cc
param_.sub_block = sub_block_; CHECK(param_.program_desc);
param_.scope = scope; param_.block_idx = op_desc.GetAttr<int32_t>("sub_block");
CHECK_GE(param_.block_idx, 0);
param_.exec_scope = scope;
CHECK(param_.exec_scope);
return true; return true;
} }
...@@ -57,4 +59,4 @@ bool ConditionalBlockOpLite::AttachImpl(const cpp::OpDesc &op_desc, ...@@ -57,4 +59,4 @@ bool ConditionalBlockOpLite::AttachImpl(const cpp::OpDesc &op_desc,
} // namespace paddle } // namespace paddle
REGISTER_LITE_OP(conditional_block, REGISTER_LITE_OP(conditional_block,
paddle::lite::operators::ConditionalBlockOpLite); paddle::lite::operators::ConditionalBlockOp);
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/core/op_lite.h" #include "lite/core/op_lite.h"
...@@ -23,27 +24,30 @@ namespace paddle { ...@@ -23,27 +24,30 @@ namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
class ConditionalBlockOpLite : public OpLite { class ConditionalBlockOp : public OpLite {
public: public:
ConditionalBlockOpLite() {} ConditionalBlockOp() {}
explicit ConditionalBlockOpLite(const std::string &op_type) explicit ConditionalBlockOp(const std::string &op_type) : OpLite(op_type) {}
: OpLite(op_type) {}
bool CheckShape() const override; bool CheckShape() const override;
bool InferShapeImpl() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "conditional_block"; } std::string DebugString() const override { return "conditional_block"; }
void SetSubBlock(cpp::BlockDesc *desc) { sub_block_ = desc; } void SetProgramDesc(std::shared_ptr<const cpp::ProgramDesc> program_desc) {
param_.program_desc = program_desc;
}
std::shared_ptr<const cpp::ProgramDesc> GetProgramDesc() {
return param_.program_desc;
}
private: private:
mutable ConditionalBlockParam param_; mutable ConditionalBlockParam param_;
cpp::BlockDesc *sub_block_;
}; };
} // namespace operators } // namespace operators
......
...@@ -90,9 +90,9 @@ struct SubgraphParam : ParamBase { ...@@ -90,9 +90,9 @@ struct SubgraphParam : ParamBase {
std::vector<std::string> output_names{}; std::vector<std::string> output_names{};
std::vector<std::string> input_data_names{}; std::vector<std::string> input_data_names{};
std::vector<std::string> output_data_names{}; std::vector<std::string> output_data_names{};
int sub_block_idx{-1}; int block_idx{-1};
cpp::BlockDesc* sub_block_desc{nullptr}; std::shared_ptr<const cpp::ProgramDesc> program_desc{nullptr};
Scope* scope{nullptr}; Scope* exec_scope{nullptr};
}; };
/// -------------------------- NN operators ------------------------------------ /// -------------------------- NN operators ------------------------------------
...@@ -939,11 +939,10 @@ struct CompareParam : ParamBase { ...@@ -939,11 +939,10 @@ struct CompareParam : ParamBase {
}; };
struct WhileParam : ParamBase { struct WhileParam : ParamBase {
Scope* scope{};
Tensor* cond{}; Tensor* cond{};
cpp::BlockDesc* sub_block{}; int block_idx{-1};
std::vector<Tensor*> x{}; std::shared_ptr<const cpp::ProgramDesc> program_desc{nullptr};
std::vector<Tensor*> outs{}; Scope* exec_scope{nullptr};
}; };
struct TopkParam : ParamBase { struct TopkParam : ParamBase {
...@@ -1454,10 +1453,11 @@ struct MergeLodTensorParam : ParamBase { ...@@ -1454,10 +1453,11 @@ struct MergeLodTensorParam : ParamBase {
struct ConditionalBlockParam : ParamBase { struct ConditionalBlockParam : ParamBase {
const lite::Tensor* cond{}; const lite::Tensor* cond{};
std::vector<lite::Tensor*> x{}; std::vector<lite::Tensor*> inputs{};
std::vector<lite::Tensor*> outs{}; std::vector<lite::Tensor*> outs{};
cpp::BlockDesc* sub_block{}; int block_idx{-1};
Scope* scope{}; std::shared_ptr<const cpp::ProgramDesc> program_desc{nullptr};
Scope* exec_scope{nullptr};
bool is_scalar_condition{}; bool is_scalar_condition{};
}; };
...@@ -1627,11 +1627,36 @@ struct XPUMmdnnBidEmbGrnnAttParam : ParamBase { ...@@ -1627,11 +1627,36 @@ struct XPUMmdnnBidEmbGrnnAttParam : ParamBase {
std::vector<float> grnn_rv_wi_maxs; std::vector<float> grnn_rv_wi_maxs;
float att_fc_w_max{0.0f}; float att_fc_w_max{0.0f};
lite::Tensor* grnn_fw_pool_out{}; // 1 lite::Tensor* grnn_fw_pool_out{};
lite::Tensor* grnn_rv_pool_out{}; // 2 lite::Tensor* grnn_rv_pool_out{};
lite::Tensor* att_pool_out{}; // 3 lite::Tensor* att_pool_out{};
lite::Tensor* concat_3in1_out{}; // 4 lite::Tensor* concat_3in1_out{};
lite::Tensor* emb_fw_out{}; // 5 lite::Tensor* emb_fw_out{};
};
struct XPUMmdnnBidEmbGrnnAttParam2 : ParamBase {
lite::Tensor* id0{};
lite::Tensor* id1{};
lite::Tensor* emb_tbl{};
lite::Tensor* grnn_fw_wh{};
lite::Tensor* grnn_fw_wi{};
lite::Tensor* grnn_rv_wh{};
lite::Tensor* grnn_rv_wi{};
lite::Tensor* att_fc_w{};
lite::Tensor* att_fc_b{};
std::vector<float> grnn_fw_wh_maxs;
std::vector<float> grnn_fw_wi_maxs;
std::vector<float> grnn_rv_wh_maxs;
std::vector<float> grnn_rv_wi_maxs;
float att_fc_w_max{0.0f};
lite::Tensor* emb0_out{};
lite::Tensor* grnn_fw_pool_out{};
lite::Tensor* grnn_rv_pool_out{};
lite::Tensor* att_pool_out{};
lite::Tensor* concat_3in1_out{};
lite::Tensor* emb_fw_out{};
}; };
struct XPUMmdnnBidEmbAttParam : ParamBase { struct XPUMmdnnBidEmbAttParam : ParamBase {
...@@ -1643,8 +1668,8 @@ struct XPUMmdnnBidEmbAttParam : ParamBase { ...@@ -1643,8 +1668,8 @@ struct XPUMmdnnBidEmbAttParam : ParamBase {
float att_fc_w_max{0.0f}; float att_fc_w_max{0.0f};
lite::Tensor* att_pool_out{}; // 1 lite::Tensor* att_pool_out{};
lite::Tensor* emb_fw_out{}; // 2 lite::Tensor* emb_fw_out{};
}; };
struct XPUMmdnnMatchConvTopkParam : ParamBase { struct XPUMmdnnMatchConvTopkParam : ParamBase {
...@@ -1656,6 +1681,7 @@ struct XPUMmdnnMatchConvTopkParam : ParamBase { ...@@ -1656,6 +1681,7 @@ struct XPUMmdnnMatchConvTopkParam : ParamBase {
float input_w_max{0.0f}; float input_w_max{0.0f};
float conv_w_max{0.0f}; float conv_w_max{0.0f};
std::vector<int> topks; std::vector<int> topks;
int output_channel{0};
int channel_num{0}; int channel_num{0};
int dim_t{0}; int dim_t{0};
...@@ -1664,7 +1690,7 @@ struct XPUMmdnnMatchConvTopkParam : ParamBase { ...@@ -1664,7 +1690,7 @@ struct XPUMmdnnMatchConvTopkParam : ParamBase {
struct XPUMmdnnMergeAllParam : ParamBase { struct XPUMmdnnMergeAllParam : ParamBase {
std::vector<lite::Tensor*> concat_7in1_x; std::vector<lite::Tensor*> concat_7in1_x;
std::vector<lite::Tensor*> concat_2in1_x; std::vector<lite::Tensor*> concat_topk_x;
lite::Tensor* grnn_fw_wh{}; lite::Tensor* grnn_fw_wh{};
lite::Tensor* grnn_fw_wi{}; lite::Tensor* grnn_fw_wi{};
lite::Tensor* grnn_rv_wh{}; lite::Tensor* grnn_rv_wh{};
...@@ -1753,6 +1779,22 @@ struct ClipParam : ParamBase { ...@@ -1753,6 +1779,22 @@ struct ClipParam : ParamBase {
float max{}; float max{};
}; };
struct PrintParam : ParamBase {
const lite::Tensor* in{};
lite::Tensor* out{};
std::string name;
int first_n{-1};
std::string message;
int summarize{20};
bool print_tensor_name{true};
bool print_tensor_type{true};
bool print_tensor_shape{true};
bool print_tensor_lod{true};
bool print_tensor_layout{true};
std::string print_phase;
bool is_forward{true};
};
} // namespace operators } // namespace operators
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/print_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool PrintOp::CheckShape() const {
CHECK_OR_FALSE(param_.in);
CHECK_OR_FALSE(param_.out);
return true;
}
bool PrintOp::InferShapeImpl() const {
param_.out->set_lod(param_.in->lod());
param_.out->Resize(param_.in->dims());
return true;
}
bool PrintOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
AttachParam(&param_);
param_.name = op_desc.Input("In").front();
param_.in = scope->FindTensor(param_.name);
param_.out = scope->FindMutableTensor(op_desc.Output("Out").front());
param_.first_n = op_desc.GetAttr<int32_t>("first_n");
param_.message = op_desc.GetAttr<std::string>("message");
param_.summarize = op_desc.GetAttr<int32_t>("summarize");
param_.print_tensor_name = op_desc.GetAttr<bool>("print_tensor_name");
param_.print_tensor_type = op_desc.GetAttr<bool>("print_tensor_type");
param_.print_tensor_shape = op_desc.GetAttr<bool>("print_tensor_shape");
param_.print_tensor_lod = op_desc.GetAttr<bool>("print_tensor_lod");
param_.print_tensor_layout = op_desc.GetAttr<bool>("print_tensor_layout");
param_.print_phase = op_desc.GetAttr<std::string>("print_phase");
param_.is_forward = op_desc.GetAttr<bool>("is_forward");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(print, paddle::lite::operators::PrintOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class PrintOp : public OpLite {
public:
PrintOp() {}
explicit PrintOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "print"; }
private:
mutable PrintParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
...@@ -39,10 +39,11 @@ bool SubgraphOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { ...@@ -39,10 +39,11 @@ bool SubgraphOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
op_desc.GetAttr<std::vector<std::string>>("input_data_names"); op_desc.GetAttr<std::vector<std::string>>("input_data_names");
param_.output_data_names = param_.output_data_names =
op_desc.GetAttr<std::vector<std::string>>("output_data_names"); op_desc.GetAttr<std::vector<std::string>>("output_data_names");
CHECK(param_.sub_block_desc); CHECK(param_.program_desc);
param_.sub_block_idx = op_desc.GetAttr<int32_t>("sub_block"); param_.block_idx = op_desc.GetAttr<int32_t>("sub_block");
param_.scope = scope; CHECK_GE(param_.block_idx, 0);
CHECK(param_.scope); param_.exec_scope = scope;
CHECK(param_.exec_scope);
return true; return true;
} }
......
...@@ -13,14 +13,11 @@ ...@@ -13,14 +13,11 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h" #include "lite/core/op_lite.h"
#include "lite/core/scope.h" #include "lite/core/scope.h"
#include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h" #include "lite/utils/all.h"
namespace paddle { namespace paddle {
...@@ -37,14 +34,18 @@ class SubgraphOp : public OpLite { ...@@ -37,14 +34,18 @@ class SubgraphOp : public OpLite {
bool InferShapeImpl() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &op_desc, Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "subgraph"; } std::string DebugString() const override { return "subgraph"; }
void SetSubBlock(cpp::BlockDesc *desc) { param_.sub_block_desc = desc; } void SetProgramDesc(std::shared_ptr<const cpp::ProgramDesc> program_desc) {
cpp::BlockDesc *GetSubBlock() { return param_.sub_block_desc; } param_.program_desc = program_desc;
}
std::shared_ptr<const cpp::ProgramDesc> GetProgramDesc() {
return param_.program_desc;
}
private: private:
mutable SubgraphParam param_; mutable SubgraphParam param_;
......
...@@ -20,31 +20,23 @@ namespace paddle { ...@@ -20,31 +20,23 @@ namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
bool WhileOpLite::CheckShape() const { bool WhileOp::CheckShape() const {
CHECK_OR_FALSE(param_.sub_block);
CHECK_OR_FALSE(param_.scope);
CHECK_OR_FALSE(param_.cond); CHECK_OR_FALSE(param_.cond);
CHECK_OR_FALSE(param_.program_desc);
CHECK_OR_FALSE(param_.exec_scope);
return true; return true;
} }
bool WhileOpLite::InferShapeImpl() const { return true; } bool WhileOp::InferShapeImpl() const { return true; }
bool WhileOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
auto inputs = op_desc.Input("X");
auto outs = op_desc.Output("Out");
for (auto var : inputs) {
// param_.x.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
for (auto var : outs) {
// param_.outs.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
param_.sub_block = sub_block_;
bool WhileOp::AttachImpl(const cpp::OpDesc &op_desc, Scope *scope) {
auto condition = op_desc.Input("Condition"); auto condition = op_desc.Input("Condition");
param_.cond = scope->FindVar(condition[0])->GetMutable<lite::Tensor>(); param_.cond = scope->FindVar(condition[0])->GetMutable<lite::Tensor>();
param_.scope = scope; CHECK(param_.program_desc);
param_.block_idx = op_desc.GetAttr<int32_t>("sub_block");
CHECK_GE(param_.block_idx, 0);
param_.exec_scope = scope;
CHECK(param_.exec_scope);
return true; return true;
} }
...@@ -52,4 +44,4 @@ bool WhileOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { ...@@ -52,4 +44,4 @@ bool WhileOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_OP(while, paddle::lite::operators::WhileOpLite); REGISTER_LITE_OP(while, paddle::lite::operators::WhileOp);
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/core/op_lite.h" #include "lite/core/op_lite.h"
...@@ -23,24 +24,30 @@ namespace paddle { ...@@ -23,24 +24,30 @@ namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
class WhileOpLite : public OpLite { class WhileOp : public OpLite {
public: public:
WhileOpLite() {} WhileOp() {}
explicit WhileOpLite(const std::string &op_type) : OpLite(op_type) {} explicit WhileOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override; bool CheckShape() const override;
bool InferShapeImpl() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "while"; } std::string DebugString() const override { return "while"; }
void SetSubBlock(cpp::BlockDesc *desc) { sub_block_ = desc; }
void SetProgramDesc(std::shared_ptr<const cpp::ProgramDesc> program_desc) {
param_.program_desc = program_desc;
}
std::shared_ptr<const cpp::ProgramDesc> GetProgramDesc() {
return param_.program_desc;
}
private: private:
mutable WhileParam param_; mutable WhileParam param_;
cpp::BlockDesc *sub_block_;
}; };
} // namespace operators } // namespace operators
......
if(LITE_WITH_ARM)
lite_cc_test(test_transformer_with_mask_fp32_arm SRCS test_transformer_with_mask_fp32_arm.cc
DEPS ${lite_model_test_DEPS} paddle_api_full
ARM_DEPS ${arm_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/transformer_with_mask_fp32 SERIAL)
if(WITH_TESTING)
add_dependencies(test_transformer_with_mask_fp32_arm extern_lite_download_transformer_with_mask_fp32_tar_gz)
endif()
endif()
if(LITE_WITH_XPU) if(LITE_WITH_XPU)
lite_cc_test(test_resnet50_lite_xpu SRCS test_resnet50_lite_xpu.cc lite_cc_test(test_resnet50_lite_xpu SRCS test_resnet50_lite_xpu.cc
DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
......
...@@ -26,156 +26,171 @@ ...@@ -26,156 +26,171 @@
DEFINE_bool(perf, false, "perf?"); DEFINE_bool(perf, false, "perf?");
DEFINE_string(perf_input, "perf_input", "perf_input"); DEFINE_string(perf_input, "perf_input", "perf_input");
DEFINE_int32(perf_batch_size, 40, "perf_batch_size");
DEFINE_bool(use_xpu, true, "use_xpu?");
DEFINE_int32(perf_dev, 0, "perf_dev");
namespace paddle { namespace paddle {
namespace lite { namespace lite {
std::vector<int64_t> input0; class SampleReader {
std::vector<uint64_t> input0_lod = {0}; public:
std::vector<int64_t> input1; std::vector<std::vector<int64_t>> data;
std::vector<uint64_t> input1_lod = {0}; std::vector<std::vector<uint64_t>> lod;
std::vector<int64_t> input2;
std::vector<uint64_t> input2_lod = {0};
std::vector<int64_t> input3;
std::vector<uint64_t> input3_lod = {0};
std::vector<int64_t> input4;
std::vector<uint64_t> input4_lod = {0};
std::vector<int64_t> input5;
std::vector<uint64_t> input5_lod = {0};
void ParseInput() { void Read() {
std::string raw_input = std::string raw_input =
"0 1;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 " "0 1;125 584 142 2114 197;125 756226 756913 855693 760836;125 584 142 "
"760166;3719 428 52 18 1102 10327 252 20 153 2897 1146 70 156 6 145 " "2114 197 10 2899;125 756226 756913 855693 760836 10 750793;125 584 "
"10251 839 5 1779 1729 1779 1729 18 2707 6 2707 20 4742 4937 432 6 " "142 2114 197 10 2899 2 825 32 18499 125 584 295 2114 197 2114 2730 6 "
"3869;3719 760166 760166 18 1035176 1035176 764393 764393 1259006 767614 " "15 32 18499 125 584 142 295 2114 1423 21 2 334 863 5122 197 974 21 "
"767614 1020808 769579 793958 793958 1050488 911898 751332 751332 750336 " "295 619 25 2114 1755 2701 197 15 216 23 18499 125 584 142 599 3228 23 "
"750799 750336 751575 751575 751544 751735 751397 751365 751512 751512 " "2 5122 1917 804 5 2114 197 1236 3 2114 1403 15 3886 1080 23 1150 125 "
"753011 751562;3719 428 52 18 1102 10327 252 20 153 2897 1146 70 156 6 " "475 23 2998 23;125 756226 756913 855693 760836 10 750793 2 825 750355 "
"145 10251 839 2 1211 3 3719 720 1540 145 10251 839 9405 4315 5998 4 2 " "18499 881680 756226 295 765124 760836 2114 872813 754265 15 32 18499 "
"600 373 41 3719 428 52 44 10251 4302 1319 7 12 2 768 6 918 6 841 870 8 " "881680 756226 756913 761251 765124 752843 766823 2 334 759834 5122 "
"843 8 271;3719 760166 760166 18 1035176 1035176 764393 764393 1259006 " "774643 758458 21 295 755114 25 1148365 1755 2701 197 15 216 23 18499 "
"767614 767614 1020808 769579 793958 793958 1050488 911898 2 773899 " "881680 756226 756913 826848 3228 23 2 5122 831009 804 752371 2114 "
"773899 3719 1118420 1118420 1050488 1050488 911898 9405 4315 5998 4 2 " "760836 1236 3 2114 910393 15 3886 1080 23 877375 752137 761034 792123 "
"785435 785435 41 3719 760166 760166 44 10251 4302 1319 750118 750118 2 " "2998 23;1;1;\n"
"750465 750465 750274 750398 750233 751252 751252 753447 752830 753112;\n" "0 0;125 584 142 2114 197;125 756226 756913 855693 760836;121 28 1054 "
"0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 " "1459 125 72 32 2321 531 125 295 584 142 2114 197 14 477 30 121;121 28 "
"760166;2109 2467 1805 227 3719 428 52 18 1102 10327 252 20 6 242 78 6 " "764114 1459 753052 750694 750001 886192 750435 752179 295 584 756913 "
"532 78;2109 2467 1805 1245431 1245431 760166 760166 18 1035176 1035176 " "855693 760836 14 477 30 753504;121 28 1054 1459 125 72 32 2321 531 "
"764393 764393 752116 242 750370 750370 752081 751247;2109 2467 1805 227 " "125 295 584 142 2114 197 2 121 28 1054 1459 125 72 32 2321 531 125 "
"3719 428 52 18 1102 10327 252 20 2 145 242 1050 252 3582 2212;2109 2467 " "295 584 142 4 263 2114 197 43 95 863 2114 323 20 142 626 11 2 45 10 "
"1805 1245431 1245431 760166 760166 18 1035176 1035176 764393 764393 2 " "45 58 142 65 918 741 2114 197 764 3 5122 26 51 1266 2037 295 222 1121 "
"871717 871717 757921 757921 3582 2212;\n" "4491 3 545 4338 11 2 5122 26 495 3 142 3444 3249 2114 197 3 626 4 "
"0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 " "2794;121 28 764114 1459 753052 750694 750001 886192 750435 752179 295 "
"760166;145 10251 839 76 31 1337 823 7506 567 65 170 8 21293 3719 5 43 " "584 756913 855693 760836 2 121 28 764114 1459 753052 750694 750001 "
"394 743 42;1050488 1050488 911898 750016 750016 1337 823 7506 762617 " "886192 750435 752179 295 584 756913 4 750885 2114 760836 43 750030 "
"762617 866652 8 21293 3719 5 43 914758 914758 757202;145 10251 839 76 " "754302 2114 323 822131 142 626 769001 2 45 750128 750324 58 142 "
"31 1337 823 7506 567 65 170 8 21293 3719 2 17580 30 523324 3 10251 4104 " "1147454 918 910829 2114 760836 841946 767340 5122 779102 51 1266 2037 "
"281 3 8511 3719 2217 3 13 226 3083 4 11251 1606 357 9 2 145 10251 839 " "756461 222 752031 942669 1139389 780275 4338 830597 2 5122 779102 495 "
"76 31 1337 823 7506 567 65 170 2 7506 2445 8 145 10251 839 528 839 " "761418 142 3444 852932 2114 760836 3 760162 757966 751127;121 295 "
"19670 6538;1050488 1050488 911898 750016 750016 1337 823 7506 762617 " "5593 142 2114 197;121 295 5593 925208 2114 760836;\n"
"762617 866652 8 21293 3719 2 816626 816626 523324 3 1181698 1181698 " "0 0;125 584 142 2114 197;125 756226 756913 855693 760836;207 125 584 "
"751656 780821 1063148 3719 2217 3 752498 752498 831323 753602 11251 " "142 2114 1423 14 5283 1745 73;207 752276 756226 756913 855693 752843 "
"1606 357 9 2 1050488 1050488 911898 750016 750016 1337 823 7506 762617 " "14 5283 781651 786597;6109 18807 142 5 64 5283 1745 73 3690 1060 3626 "
"762617 866652 2 7506 753045 753045 756756 1050488 911898 528 839 19670 " "4 716 51 1030 2114 197 4 428 936 9066 10 10 10 2 207 125 584 142 2114 "
"6538;\n" "1423 2 15329 2114 197 5669 401 318 285 953 4 2114 197 2285 7 1783 11 "
"0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 " "2 5122 197 14017 584;6109 18807 142 5 755319 5283 781651 786597 3690 "
"760166;145 10251 839 99 4 1102 10327 2196 41 3719 428 52 44 99 4 2899 " "1060 3626 4 716 910478 1030 2114 760836 4 750323 936 9066 10 750002 "
"229 10 10 10;1050488 1050488 911898 807966 750273 1035176 1035176 " "750002 2 207 752276 756226 756913 855693 752843 2 15329 2114 760836 "
"1237875 41 3719 760166 760166 753645 753645 750273 2899 229 750001 " "5669 401 318 757541 750261 4 2114 760836 2285 7 757639 11 2 5122 "
"750001 750001;145 10251 839 99 4 1102 10327 2196 41 3719 428 52 44 99 4 " "774643 14017 584;125 584 142 1745 5122;125 756226 756913 1745 "
"2899 229 10 10 10 2 1177 8 145 10251 839 99 4 1102 10327 2196 41 3719 " "755836;\n"
"428 52 44 99 4 2 101 8 1922 17 2184 2 1154 1922 72 1198 1266 " "0 0;125 584 142 2114 197;125 756226 756913 855693 760836;149 396 778 "
"4516;1050488 1050488 911898 807966 750273 1035176 1035176 1237875 41 " "584 142 295 2114 1423 14 64 125 584 73 21 36670 5834 10 211 25;149 "
"3719 760166 760166 753645 753645 750273 2899 229 750001 750001 750001 2 " "751876 1048872 584 756913 761251 765124 752843 14 64 125 756226 73 "
"750257 750257 756756 1050488 911898 807966 750273 1035176 1035176 " "944567 36670 5834 10 750012 753240;101 10 2114 197 3 946 2 149 396 "
"1237875 41 3719 760166 760166 753645 753645 750273 2 764513 764513 " "778 584 142 295 2114 1423 2 2610 6 1444 111 2114 948 72 32 21 15 494 "
"851213 851213 854628 2 753018 753018 754317 753328 754085 754070;\n" "25 4 2114 197 5669 1145 2 148 295 149 396 778 584 142 295 21 22853 41 "
"0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 " "348 619 25 366 5305 2114 807 4 1115 381 1955 2114 11;101 751178 2114 "
"760166;73 5347 112 8 145 10251 839 262 169 22729 3719 6 743 6 339 1156 " "760836 3 946 2 149 751876 1048872 584 756913 761251 765124 752843 2 "
"78 136 399 693 128 571;776150 776150 112 756756 756756 1050488 911898 " "2610 753567 775165 750899 972788 948 750125 750001 751875 15 494 25 4 "
"791355 791355 22729 3719 6 758277 758277 750137 750234 750241 750178 " "2114 760836 5669 1145 2 148 808886 982157 751876 1048872 584 756913 "
"750055 750216 750212 750049;73 5347 112 8 145 10251 839 262 169 22729 " "761251 790772 22853 41 348 619 25 366 894206 2114 1008440 4 753953 "
"3719 2 588 415 549 415 115 23;776150 776150 112 756756 756756 1050488 " "381 851474 765868 11;149 396 778 584 142 295 2 149 396 354 778 584 "
"911898 791355 791355 22729 3719 2 750221 750221 750262 750277 750277 " "142 1333 2 584 778 295 5122 2 149 396 778 584 3609 2 149 396 64478 "
"750261;"; "816 14246 1423 2 149 396 584 32 127 19 3609 2 149 396 584 73 2 149 "
auto raw_lines = Split(raw_input, "\n"); "396 584 778 295 2285 142 4922 323 2 149 396 584 2114 2 149 396 253 "
for (auto& raw_line : raw_lines) { "584 2114 197;149 751876 1048872 584 756913 761251 2 149 751876 756286 "
auto inputx = Split(raw_line, ";"); "767182 584 756913 1333 2 584 778 897778 941364 2 149 751876 1048872 "
for (size_t i = 1; i < inputx.size(); ++i) { "584 1102835 2 149 751876 64478 816 14246 912094 2 149 751876 584 "
auto tokens = Split(inputx[i], " "); "773547 127 750771 791456 2 149 751876 584 73 2 149 751876 584 778 "
static std::vector<int64_t>* const input_array[] = { "897778 2285 751493 791984 323 2 149 751876 584 2114 2 149 751876 "
&input0, &input0, &input1, &input2, &input3, &input4, &input5}; "808443 835481 2114 760836;\n"
static std::vector<uint64_t>* const lod_array[] = {&input0_lod, "0 0;125 584 142 2114 197;125 756226 756913 855693 760836;125 584 545 "
&input0_lod, "149 14 125 584;125 756226 545 874302 14 125 756226;2204 25 30 1692 "
&input1_lod, "1770 6534 295 125 584 72 32 1346 4 2698 2114 197 11 2 4235 4301 240 "
&input2_lod, "295 125 584 72 32 21 6708 15 56974 494 25 1030 2114 197 110 804 495 "
&input3_lod, "611 2 221 759 341 6 5283 1745 73 71 2114 1423 71 125 584 545 149 149 "
&input4_lod, "2 505 345 58 125 584 65 3486 2114 295 4 45 786 196 6604 6086;2204 25 "
&input5_lod}; "30 797189 1770 1191824 295 752782 756226 751697 750001 1346 4 2698 "
for (auto token : tokens) { "2114 760836 765158 2 4235 4301 240 753859 752782 756226 751697 750001 "
input_array[i]->push_back((int64_t)atoi(token.c_str())); "751875 6708 15 56974 494 25 1030 2114 760836 777607 762850 966521 611 "
} "2 221 752565 750130 750084 910219 781651 786597 71 2114 752843 71 125 "
lod_array[i]->push_back((uint64_t)tokens.size() + "756226 545 874302 149 2 505 825657 782848 125 756226 65 3486 2114 "
(*lod_array[i])[lod_array[i]->size() - 1]); "760669 4 45 755747 758903 6604 6086;125 584 2114 2 125 584 2114 1423 "
} "2 125 584 2114 149 2 149 584 1745 5122 725 2 2114 125 584 2 125 584 "
} "2114 2 2621 584 2114 2 527 37 2754 130 170 1013 494 887 240 2 4521 "
return; "11111 586 2321 531 125 584 142 1360 816 2842 1423 2 125 584 2114;125 "
} "756226 2114 2 125 756226 2114 752843 2 125 756226 2114 783644 2 149 "
"760183 1745 755836 725 2 2114 125 756226 2 125 756226 2114 2 2621 "
"932600 2114 2 527 751304 869964 754462 170 1013 750719 778287 774620 "
"2 4521 11111 586 2321 750435 752179 756226 756913 1360 764399 2842 "
"1423 2 125 756226 2114;\n"
"0 0;125 584 142 2114 197;125 756226 756913 855693 760836;207 584 142 "
"2114 197 4 207 584 142 2114 197 674 14 240 4328 14 4328 767;207 "
"1237071 756913 855693 760836 4 207 1237071 756913 855693 760836 674 "
"14 240 755573 14 4328 795065;207 584 142 2114 197 2 325 71 71 207 584 "
"142 2114 197 2 876 125 140 2114 197 2 207 584 142 2114 197 674 1210 "
"239 4328 767 268 1349 485 28 4389 504 3 941 57 1419 1978 11;207 "
"1237071 756913 855693 760836 2 325 71 71 207 1237071 756913 855693 "
"760836 2 876 125 750977 1250790 760836 2 207 1237071 756913 855693 "
"760836 674 814792 755820 812174 795065 818859 817155 816597 761001 "
"774461 780904 820475 1109800 790141 790459 780324 770390;584 142 295 "
"2114 232 2 207 584 2114 197 2 584 142 295 2114 232 2 584 142 512 2114 "
"197;584 756913 761251 765124 1006359 2 207 1237071 2114 760836 2 584 "
"756913 761251 765124 1006359 2 584 756913 879930 2114 760836;";
class MmdnnReader { auto lines = Split(raw_input, "\n");
std::ifstream ifs; for (auto& line : lines) {
std::vector<std::string> StringSplit(const std::string& in, auto split1 = Split(line, ";");
const std::string& delim) { if (data.size() == 0) {
std::vector<std::string> ret; for (size_t i = 1; i < split1.size(); ++i) {
if (in == "") { data.push_back(std::vector<int64_t>());
return ret; lod.push_back({0});
} }
auto begpos = in.find_first_not_of(delim);
while (begpos != std::string::npos) {
auto endpos = in.find_first_of(delim, begpos);
if (endpos == std::string::npos) {
endpos = in.size();
} }
std::string ssubstr = in.substr(begpos, endpos - begpos);
ret.push_back(ssubstr); for (size_t i = 1; i < split1.size(); ++i) {
begpos = endpos + 1; auto split2 = Split(split1[i], " ");
if (endpos >= (in.size() - 1)) { if (split2.size() == 0) {
break; split2.push_back("1280000");
}
for (auto e : split2) {
data[i - 1].push_back(std::stoi(e.c_str(), nullptr, 0));
}
lod[i - 1].push_back(lod[i - 1].back() + split2.size());
} }
} }
return ret;
} }
};
class FileReader {
std::ifstream ifs;
public: public:
std::vector<int64_t> data[6]; std::vector<std::vector<int64_t>> data;
std::vector<uint64_t> lod[6]; std::vector<std::vector<uint64_t>> lod;
void Init(std::string file_name) { ifs.open(file_name); } void Init(std::string file_name) { ifs.open(file_name); }
int Read(int maxline) { int Read(int maxline) {
for (int i = 0; i < 6; i++) { data.clear();
data[i].clear(); lod.clear();
}
for (int i = 0; i < 6; i++) {
lod[i].clear();
lod[i].push_back(0);
}
std::string line; std::string line;
int cnt = 0; int cnt = 0;
while (cnt < maxline && getline(ifs, line)) { while (cnt < maxline && getline(ifs, line)) {
std::vector<std::string> split1 = StringSplit(line, ";"); std::vector<std::string> split1 = Split(line, ";");
for (int i = 1; i < 7; i++) { if (data.size() == 0) {
std::vector<std::string> split2 = StringSplit(split1[i], " "); for (size_t i = 1; i < split1.size(); ++i) {
data.push_back(std::vector<int64_t>());
lod.push_back({0});
}
}
for (size_t i = 1; i < split1.size(); i++) {
std::vector<std::string> split2 = Split(split1[i], " ");
if (split2.size() == 0) { if (split2.size() == 0) {
split2.push_back("1280000"); split2.push_back("1280000");
} }
for (size_t j = 0; j < split2.size(); j++) { for (size_t j = 0; j < split2.size(); j++) {
data[i - 1].push_back(std::stoi(split2[j].c_str(), nullptr, 0)); data[i - 1].push_back(std::stoi(split2[j].c_str(), nullptr, 0));
} }
// if (i % 2 == 1) {
// lod[i / 2].push_back(lod[i / 2].back() + split2.size());
//}
lod[i - 1].push_back(lod[i - 1].back() + split2.size()); lod[i - 1].push_back(lod[i - 1].back() + split2.size());
} }
cnt++; cnt++;
...@@ -186,36 +201,47 @@ class MmdnnReader { ...@@ -186,36 +201,47 @@ class MmdnnReader {
TEST(MMDNN, test_mmdnn_lite_xpu) { TEST(MMDNN, test_mmdnn_lite_xpu) {
lite_api::CxxConfig config; lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir); // config.set_model_dir(FLAGS_model_dir);
config.set_valid_places({lite_api::Place{TARGET(kXPU), PRECISION(kFloat)}, config.set_model_file(FLAGS_model_dir + "/__model__");
lite_api::Place{TARGET(kXPU), PRECISION(kInt64)}, config.set_param_file(FLAGS_model_dir + "/__param__");
lite_api::Place{TARGET(kX86), PRECISION(kFloat)}, config.set_xpu_dev_per_thread(FLAGS_perf_dev);
lite_api::Place{TARGET(kX86), PRECISION(kInt64)}, if (FLAGS_use_xpu) {
lite_api::Place{TARGET(kHost), PRECISION(kFloat)}}); config.set_valid_places(
{lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
lite_api::Place{TARGET(kXPU), PRECISION(kInt64)},
lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
lite_api::Place{TARGET(kX86), PRECISION(kInt64)},
lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
} else {
config.set_valid_places(
{lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
lite_api::Place{TARGET(kX86), PRECISION(kInt64)},
lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
}
config.set_xpu_workspace_l3_size_per_thread(); config.set_xpu_workspace_l3_size_per_thread();
auto predictor = lite_api::CreatePaddlePredictor(config); auto predictor = lite_api::CreatePaddlePredictor(config);
if (FLAGS_perf) { if (FLAGS_perf) {
MmdnnReader reader; FileReader file_reader;
reader.Init(FLAGS_perf_input); file_reader.Init(FLAGS_perf_input);
int UB_batch = 40; // upper bound of batch int UB_batch = FLAGS_perf_batch_size; // upper bound of batch
int iter = 0; int iter = 0;
double tsc_sum = 0; double tsc_sum = 0;
while (true) { while (true) {
int batch = reader.Read(UB_batch); int batch = file_reader.Read(UB_batch);
if (batch <= 0) { if (batch <= 0) {
break; break;
} }
++iter; ++iter;
for (int i = 0; i < 6; ++i) { for (size_t i = 0; i < file_reader.data.size(); ++i) {
auto input_x = predictor->GetInput(i); auto input_x = predictor->GetInput(i);
input_x->Resize({(int64_t)reader.data[i].size(), 1}); input_x->Resize({(int64_t)file_reader.data[i].size(), 1});
input_x->SetLoD({reader.lod[i]}); input_x->SetLoD({file_reader.lod[i]});
auto* data_x = input_x->mutable_data<int64_t>(); auto* data_x = input_x->mutable_data<int64_t>();
memcpy(data_x, memcpy(data_x,
reader.data[i].data(), file_reader.data[i].data(),
reader.data[i].size() * sizeof(int64_t)); file_reader.data[i].size() * sizeof(int64_t));
} }
auto start = GetCurrentUS(); auto start = GetCurrentUS();
...@@ -232,55 +258,17 @@ TEST(MMDNN, test_mmdnn_lite_xpu) { ...@@ -232,55 +258,17 @@ TEST(MMDNN, test_mmdnn_lite_xpu) {
return; return;
} }
ParseInput(); SampleReader sample_reader;
sample_reader.Read();
{ for (size_t i = 0; i < sample_reader.data.size(); ++i) {
std::vector<int64_t> input0_shape{(int64_t)input0.size(), 1}; auto input_x = predictor->GetInput(i);
auto input_tensor0 = predictor->GetInput(0); input_x->Resize({(int64_t)sample_reader.data[i].size(), 1});
input_tensor0->Resize(input0_shape); input_x->SetLoD({sample_reader.lod[i]});
input_tensor0->SetLoD({input0_lod}); auto* data_x = input_x->mutable_data<int64_t>();
auto* data0 = input_tensor0->mutable_data<int64_t>(); memcpy(data_x,
memcpy(data0, input0.data(), sizeof(int64_t) * input0.size()); sample_reader.data[i].data(),
} sample_reader.data[i].size() * sizeof(int64_t));
{
std::vector<int64_t> input1_shape{(int64_t)input1.size(), 1};
auto input_tensor1 = predictor->GetInput(1);
input_tensor1->Resize(input1_shape);
input_tensor1->SetLoD({input1_lod});
auto* data1 = input_tensor1->mutable_data<int64_t>();
memcpy(data1, input1.data(), sizeof(int64_t) * input1.size());
}
{
std::vector<int64_t> input2_shape{(int64_t)input2.size(), 1};
auto input_tensor2 = predictor->GetInput(2);
input_tensor2->Resize(input2_shape);
input_tensor2->SetLoD({input2_lod});
auto* data2 = input_tensor2->mutable_data<int64_t>();
memcpy(data2, input2.data(), sizeof(int64_t) * input2.size());
}
{
std::vector<int64_t> input3_shape{(int64_t)input3.size(), 1};
auto input_tensor3 = predictor->GetInput(3);
input_tensor3->Resize(input3_shape);
input_tensor3->SetLoD({input3_lod});
auto* data3 = input_tensor3->mutable_data<int64_t>();
memcpy(data3, input3.data(), sizeof(int64_t) * input3.size());
}
{
std::vector<int64_t> input4_shape{(int64_t)input4.size(), 1};
auto input_tensor4 = predictor->GetInput(4);
input_tensor4->Resize(input4_shape);
input_tensor4->SetLoD({input4_lod});
auto* data4 = input_tensor4->mutable_data<int64_t>();
memcpy(data4, input4.data(), sizeof(int64_t) * input4.size());
}
{
std::vector<int64_t> input5_shape{(int64_t)input5.size(), 1};
auto input_tensor5 = predictor->GetInput(5);
input_tensor5->Resize(input5_shape);
input_tensor5->SetLoD({input5_lod});
auto* data5 = input_tensor5->mutable_data<int64_t>();
memcpy(data5, input5.data(), sizeof(int64_t) * input5.size());
} }
for (int i = 0; i < FLAGS_warmup; ++i) { for (int i = 0; i < FLAGS_warmup; ++i) {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "lite/api/lite_api_test_helper.h"
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
template <typename T>
void SetTensorData(const std::vector<T> &data,
const std::vector<int64_t> &shape,
paddle::lite_api::Tensor *tensor,
const std::vector<std::vector<uint64_t>> &lod = {}) {
tensor->Resize(shape);
tensor->SetLoD(lod);
std::copy(data.begin(), data.end(), tensor->mutable_data<T>());
}
void PrepareInputData(
const std::shared_ptr<paddle::lite_api::PaddlePredictor> &predictor,
std::vector<int64_t> src_word_data,
int max_seq_len = 16, // padding
int max_out_len = 8,
int bos_idx = 0,
int eos_idx = 1,
int n_head = 8) {
// src_word
auto src_word = predictor->GetInput(0);
int seq_len = src_word_data.size();
for (int i = seq_len; i < max_seq_len; i++) {
src_word_data.push_back(eos_idx);
}
std::vector<int64_t> src_word_shape{
1, static_cast<int64_t>(src_word_data.size())};
SetTensorData<int64_t>(src_word_data, src_word_shape, src_word.get());
// src_pos
auto src_pos = predictor->GetInput(1);
std::vector<int64_t> src_pos_data(src_word_data.size());
std::iota(src_pos_data.begin(), src_pos_data.end(), 0);
std::vector<int64_t> src_pos_shape{1,
static_cast<int64_t>(src_pos_data.size())};
SetTensorData<int64_t>(src_pos_data, src_pos_shape, src_pos.get());
// src_slf_attn_bias
auto src_slf_attn_bias = predictor->GetInput(2);
std::vector<float> src_slf_attn_bias_data(1 * n_head * src_word_data.size() *
src_word_data.size());
int offset = 0;
for (int j = 0; j < 1 * n_head * src_word_data.size(); j++) {
for (int i = 0; i < seq_len; i++) {
src_slf_attn_bias_data[offset++] = 0.0f;
}
for (int i = seq_len; i < src_word_data.size(); i++) {
src_slf_attn_bias_data[offset++] = -1e9f;
}
}
std::vector<int64_t> src_slf_attn_bias_shape{
1,
n_head,
static_cast<int64_t>(src_word_data.size()),
static_cast<int64_t>(src_word_data.size())};
SetTensorData<float>(
src_slf_attn_bias_data, src_slf_attn_bias_shape, src_slf_attn_bias.get());
// trg_word
auto trg_word = predictor->GetInput(3);
std::vector<int64_t> trg_word_data(2, 0);
std::vector<int64_t> trg_word_shape{2, 1};
std::vector<uint64_t> lod_level_0{0, 2};
std::vector<uint64_t> lod_level_1{0, 1, 2};
std::vector<std::vector<uint64_t>> trg_word_lod(2);
trg_word_lod[0] = lod_level_0;
trg_word_lod[1] = lod_level_1;
SetTensorData<int64_t>(
trg_word_data, trg_word_shape, trg_word.get(), trg_word_lod);
// init_score
auto init_score = predictor->GetInput(4);
std::vector<float> init_score_data(2);
init_score_data[0] = 0;
init_score_data[1] = -1e9f;
std::vector<int64_t> init_score_shape{2, 1};
std::vector<std::vector<uint64_t>> init_score_lod(trg_word_lod);
SetTensorData<float>(
init_score_data, init_score_shape, init_score.get(), init_score_lod);
// init_idx
auto init_idx = predictor->GetInput(5);
std::vector<int32_t> init_idx_data(2, 0);
std::vector<int64_t> init_idx_shape{2};
SetTensorData<int32_t>(init_idx_data, init_idx_shape, init_idx.get());
// trg_slf_attn_bias
auto trg_slf_attn_bias = predictor->GetInput(6);
std::vector<float> trg_slf_attn_bias_data(max_out_len * n_head * 1 *
max_out_len);
offset = 0;
for (int k = 0; k < max_out_len; k++) {
for (int j = 0; j < n_head; j++) {
for (int i = 0; i < max_out_len; i++) {
trg_slf_attn_bias_data[offset++] = (i <= k) ? 0.0f : -1e9f;
}
}
}
std::vector<int64_t> trg_slf_attn_bias_shape{
max_out_len, n_head, 1, max_out_len};
SetTensorData<float>(
trg_slf_attn_bias_data, trg_slf_attn_bias_shape, trg_slf_attn_bias.get());
// trg_src_attn_bias
auto trg_src_attn_bias = predictor->GetInput(7);
std::vector<float> trg_src_attn_bias_data(1 * n_head * 1 *
src_word_data.size());
offset = 0;
for (int j = 0; j < 1 * n_head * 1; j++) {
for (int i = 0; i < seq_len; i++) {
trg_src_attn_bias_data[offset++] = 0.0f;
}
for (int i = seq_len; i < src_word_data.size(); i++) {
trg_src_attn_bias_data[offset++] = -1e9f;
}
}
std::vector<int64_t> trg_src_attn_bias_shape{
1, n_head, 1, static_cast<int64_t>(src_word_data.size())};
SetTensorData<float>(
trg_src_attn_bias_data, trg_src_attn_bias_shape, trg_src_attn_bias.get());
// kv_padding_selection
auto kv_padding_selection = predictor->GetInput(8);
std::vector<float> kv_padding_selection_data(max_out_len * n_head *
max_out_len * 1);
offset = 0;
for (int k = 0; k < max_out_len; k++) {
for (int j = 0; j < n_head; j++) {
for (int i = 0; i < max_out_len; i++) {
kv_padding_selection_data[offset++] = (i == k) ? 1.0f : 0.0f;
}
}
}
std::vector<int64_t> kv_padding_selection_shape{
max_out_len, n_head, max_out_len, 1};
SetTensorData<float>(kv_padding_selection_data,
kv_padding_selection_shape,
kv_padding_selection.get());
}
void CheckOutputData(
const std::shared_ptr<paddle::lite_api::PaddlePredictor> &predictor,
const std::vector<int64_t> &ref_seq_ids_data,
const std::vector<float> &ref_seq_scores_data) {
// seq_ids
auto seq_ids = predictor->GetOutput(0);
auto seq_ids_shape = seq_ids->shape();
auto seq_ids_size = std::accumulate(seq_ids_shape.begin(),
seq_ids_shape.end(),
1,
std::multiplies<int64_t>());
ASSERT_EQ(seq_ids_size, ref_seq_ids_data.size());
auto *seq_ids_data = seq_ids->data<int64_t>();
for (size_t i = 0; i < seq_ids_size; i++) {
EXPECT_EQ(seq_ids_data[i], ref_seq_ids_data[i]);
}
// seq_scores
auto seq_scores = predictor->GetOutput(1);
auto seq_scores_shape = seq_scores->shape();
auto seq_scores_size = std::accumulate(seq_scores_shape.begin(),
seq_scores_shape.end(),
1,
std::multiplies<int64_t>());
ASSERT_EQ(seq_scores_size, ref_seq_scores_data.size());
auto *seq_scores_data = seq_scores->data<float>();
for (size_t i = 0; i < seq_scores_size; i++) {
EXPECT_NEAR(seq_scores_data[i], ref_seq_scores_data[i], 1e-5);
}
}
TEST(TransformerWithMask, test_transformer_with_mask_fp32) {
// Save the optimized model by using full api with CxxConfig
lite_api::CxxConfig cxx_config;
cxx_config.set_model_dir(FLAGS_model_dir);
cxx_config.set_valid_places(
{lite_api::Place{TARGET(kARM), PRECISION(kFloat)},
lite_api::Place{TARGET(kARM), PRECISION(kInt64)}});
auto predictor = lite_api::CreatePaddlePredictor(cxx_config);
predictor->SaveOptimizedModel(FLAGS_model_dir + ".nb",
paddle::lite_api::LiteModelType::kNaiveBuffer);
// Load the optimized model and run inference by using light api with
// MobileConfig
paddle::lite_api::MobileConfig mobile_config;
mobile_config.set_model_from_file(FLAGS_model_dir + ".nb");
mobile_config.set_threads(1);
mobile_config.set_power_mode(paddle::lite_api::PowerMode::LITE_POWER_HIGH);
std::vector<std::pair<std::vector<int64_t>,
std::pair<std::vector<int64_t>, std::vector<float>>>>
test_cases = {
{{16, 16, 16, 1},
{{0, 16, 16, 16, 16, 16, 16, 1, 0, 16, 16, 16, 16, 16, 9, 1},
{0.0f,
-0.939061f,
-1.91494f,
-2.94378f,
-4.26457f,
-5.82675f,
-7.45856f,
-7.58065f,
0.0f,
-0.939061f,
-1.91494f,
-2.94378f,
-4.26457f,
-5.82675f,
-8.70994f,
-8.8053f}}},
{{16, 16, 16, 10, 1},
{{0, 6, 53, 11, 1, 0, 6, 53, 56, 4, 1},
{0.0f,
-2.36122f,
-4.1678f,
-6.19764f,
-7.69256f,
0.0f,
-2.36122f,
-4.1678f,
-6.20145f,
-7.66355f,
-8.63024f}}},
{{126, 4, 33, 1},
{{0, 68, 5, 17, 1, 0, 68, 5, 13, 14, 1},
{0.0f,
-0.829941f,
-1.20217f,
-2.23938f,
-2.98262f,
0.0f,
-0.829941f,
-1.20217f,
-2.25051f,
-3.07555f,
-3.57711f}}},
{{126, 4, 33, 99, 1},
{{0, 14, 242, 17, 1, 0, 93, 38, 27, 68, 1},
{0.f,
-1.8504f,
-2.66679f,
-3.09469f,
-3.63227f,
0.0f,
-1.33829f,
-1.41656f,
-3.1333f,
-3.27901f,
-3.88582f}}}};
for (auto &test_case : test_cases) {
PrepareInputData(predictor, test_case.first);
predictor->Run();
CheckOutputData(predictor, test_case.second.first, test_case.second.second);
}
}
} // namespace lite
} // namespace paddle
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
const int MALLOC_ALIGN = 64; const int MALLOC_ALIGN = 16;
void* fast_malloc(size_t size) { void* fast_malloc(size_t size) {
size_t offset = sizeof(void*) + MALLOC_ALIGN - 1; size_t offset = sizeof(void*) + MALLOC_ALIGN - 1;
......
...@@ -120,6 +120,10 @@ bool test_gemm_int8(bool tra, ...@@ -120,6 +120,10 @@ bool test_gemm_int8(bool tra,
auto dc_fp32 = tc_fp32.mutable_data<float>(); auto dc_fp32 = tc_fp32.mutable_data<float>();
auto dc_basic_int8 = tc_basic_int8.mutable_data<int8_t>(); auto dc_basic_int8 = tc_basic_int8.mutable_data<int8_t>();
auto dc_basic_fp32 = tc_basic_fp32.mutable_data<float>(); auto dc_basic_fp32 = tc_basic_fp32.mutable_data<float>();
// set intial input to be 0
memset(reinterpret_cast<char*>(dc_basic_fp32),
0,
tc_basic_fp32.numel() * sizeof(float));
auto dbias = tbias.mutable_data<float>(); auto dbias = tbias.mutable_data<float>();
if (FLAGS_check_result) { if (FLAGS_check_result) {
......
...@@ -108,6 +108,10 @@ bool test_gemv_int8(bool tra, ...@@ -108,6 +108,10 @@ bool test_gemv_int8(bool tra,
auto dc_basic_int8 = tc_basic_int8.mutable_data<int8_t>(); auto dc_basic_int8 = tc_basic_int8.mutable_data<int8_t>();
auto dc_basic_fp32 = tc_basic_fp32.mutable_data<float>(); auto dc_basic_fp32 = tc_basic_fp32.mutable_data<float>();
auto dbias = tbias.mutable_data<float>(); auto dbias = tbias.mutable_data<float>();
// set intial input to be 0
memset(reinterpret_cast<char*>(dc_basic_fp32),
0,
tc_basic_fp32.numel() * sizeof(float));
paddle::lite_api::ActivationType act = paddle::lite_api::ActivationType act =
paddle::lite_api::ActivationType::kIndentity; paddle::lite_api::ActivationType::kIndentity;
......
...@@ -92,6 +92,7 @@ bool test_sgemm_c4( ...@@ -92,6 +92,7 @@ bool test_sgemm_c4(
auto db_c4 = tb_c4.mutable_data<float>(); auto db_c4 = tb_c4.mutable_data<float>();
auto dc_basic = tc_basic.mutable_data<float>(); auto dc_basic = tc_basic.mutable_data<float>();
auto dbias = tbias.mutable_data<float>(); auto dbias = tbias.mutable_data<float>();
memset(reinterpret_cast<char*>(dc_basic), 0, tc_basic.numel());
// trans A, B to c4 // trans A, B to c4
basic_trans_mat_to_c4(da, da_c4, k, m, k, true); basic_trans_mat_to_c4(da, da_c4, k, m, k, true);
......
...@@ -84,6 +84,7 @@ bool test_sgemv(bool tra, ...@@ -84,6 +84,7 @@ bool test_sgemv(bool tra,
auto db = tb.mutable_data<float>(); auto db = tb.mutable_data<float>();
auto dc = tc.mutable_data<float>(); auto dc = tc.mutable_data<float>();
auto dc_basic = tc_basic.mutable_data<float>(); auto dc_basic = tc_basic.mutable_data<float>();
memset(reinterpret_cast<char*>(dc_basic), 0, tc_basic.numel());
auto dbias = tbias.mutable_data<float>(); auto dbias = tbias.mutable_data<float>();
paddle::lite_api::ActivationType act = paddle::lite_api::ActivationType act =
paddle::lite_api::ActivationType::kIndentity; paddle::lite_api::ActivationType::kIndentity;
......
...@@ -22,6 +22,7 @@ OPTMODEL_DIR="" ...@@ -22,6 +22,7 @@ OPTMODEL_DIR=""
BUILD_TAILOR=OFF BUILD_TAILOR=OFF
BUILD_CV=OFF BUILD_CV=OFF
WITH_LOG=ON WITH_LOG=ON
WITH_EXCEPTION=OFF
WITH_PROFILE=OFF WITH_PROFILE=OFF
BUILD_NPU=OFF BUILD_NPU=OFF
NPU_DDK_ROOT="$(pwd)/ai_ddk_lib/" # Download HiAI DDK from https://developer.huawei.com/consumer/cn/hiai/ NPU_DDK_ROOT="$(pwd)/ai_ddk_lib/" # Download HiAI DDK from https://developer.huawei.com/consumer/cn/hiai/
...@@ -126,6 +127,7 @@ function make_tiny_publish_so { ...@@ -126,6 +127,7 @@ function make_tiny_publish_so {
-DLITE_WITH_JAVA=$BUILD_JAVA \ -DLITE_WITH_JAVA=$BUILD_JAVA \
-DLITE_WITH_PYTHON=$BUILD_PYTHON \ -DLITE_WITH_PYTHON=$BUILD_PYTHON \
-DLITE_WITH_LOG=$WITH_LOG \ -DLITE_WITH_LOG=$WITH_LOG \
-DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \
-DLITE_ON_TINY_PUBLISH=ON \ -DLITE_ON_TINY_PUBLISH=ON \
-DANDROID_STL_TYPE=$android_stl \ -DANDROID_STL_TYPE=$android_stl \
-DLITE_BUILD_EXTRA=$BUILD_EXTRA \ -DLITE_BUILD_EXTRA=$BUILD_EXTRA \
...@@ -181,6 +183,7 @@ function make_opencl { ...@@ -181,6 +183,7 @@ function make_opencl {
-DWITH_TESTING=OFF \ -DWITH_TESTING=OFF \
-DLITE_BUILD_EXTRA=$BUILD_EXTRA \ -DLITE_BUILD_EXTRA=$BUILD_EXTRA \
-DLITE_WITH_LOG=$WITH_LOG \ -DLITE_WITH_LOG=$WITH_LOG \
-DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \
-DLITE_WITH_CV=$BUILD_CV \ -DLITE_WITH_CV=$BUILD_CV \
-DARM_TARGET_OS=$1 -DARM_TARGET_ARCH_ABI=$2 -DARM_TARGET_LANG=$3 -DARM_TARGET_OS=$1 -DARM_TARGET_ARCH_ABI=$2 -DARM_TARGET_LANG=$3
...@@ -219,6 +222,7 @@ function make_full_publish_so { ...@@ -219,6 +222,7 @@ function make_full_publish_so {
-DLITE_WITH_JAVA=$BUILD_JAVA \ -DLITE_WITH_JAVA=$BUILD_JAVA \
-DLITE_WITH_PYTHON=$BUILD_PYTHON \ -DLITE_WITH_PYTHON=$BUILD_PYTHON \
-DLITE_WITH_LOG=$WITH_LOG \ -DLITE_WITH_LOG=$WITH_LOG \
-DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \
-DLITE_WITH_PROFILE=${WITH_PROFILE} \ -DLITE_WITH_PROFILE=${WITH_PROFILE} \
-DANDROID_STL_TYPE=$android_stl \ -DANDROID_STL_TYPE=$android_stl \
-DLITE_BUILD_EXTRA=$BUILD_EXTRA \ -DLITE_BUILD_EXTRA=$BUILD_EXTRA \
...@@ -343,6 +347,8 @@ function make_cuda { ...@@ -343,6 +347,8 @@ function make_cuda {
-DLITE_WITH_STATIC_CUDA=OFF \ -DLITE_WITH_STATIC_CUDA=OFF \
-DLITE_WITH_PYTHON=${BUILD_PYTHON} \ -DLITE_WITH_PYTHON=${BUILD_PYTHON} \
-DLITE_BUILD_EXTRA=ON \ -DLITE_BUILD_EXTRA=ON \
-DLITE_WITH_LOG=${WITH_LOG} \
-DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \
-DLITE_WITH_XPU=$BUILD_XPU \ -DLITE_WITH_XPU=$BUILD_XPU \
-DLITE_WITH_XTCL=$BUILD_XTCL \ -DLITE_WITH_XTCL=$BUILD_XTCL \
-DXPU_SDK_ROOT=$XPU_SDK_ROOT -DXPU_SDK_ROOT=$XPU_SDK_ROOT
...@@ -379,6 +385,7 @@ function make_x86 { ...@@ -379,6 +385,7 @@ function make_x86 {
-DLITE_WITH_PYTHON=${BUILD_PYTHON} \ -DLITE_WITH_PYTHON=${BUILD_PYTHON} \
-DLITE_BUILD_EXTRA=ON \ -DLITE_BUILD_EXTRA=ON \
-DLITE_WITH_LOG=${WITH_LOG} \ -DLITE_WITH_LOG=${WITH_LOG} \
-DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \
-DLITE_WITH_PROFILE=${WITH_PROFILE} \ -DLITE_WITH_PROFILE=${WITH_PROFILE} \
-DLITE_WITH_XPU=$BUILD_XPU \ -DLITE_WITH_XPU=$BUILD_XPU \
-DLITE_WITH_XTCL=$BUILD_XTCL \ -DLITE_WITH_XTCL=$BUILD_XTCL \
...@@ -409,6 +416,7 @@ function print_usage { ...@@ -409,6 +416,7 @@ function print_usage {
echo echo
echo -e "optional argument:" echo -e "optional argument:"
echo -e "--with_log: (OFF|ON); controls whether to print log information, default is ON" echo -e "--with_log: (OFF|ON); controls whether to print log information, default is ON"
echo -e "--with_exception: (OFF|ON); controls whether to throw the exception when error occurs, default is OFF"
echo -e "--build_extra: (OFF|ON); controls whether to publish extra operators and kernels for (sequence-related model such as OCR or NLP)" echo -e "--build_extra: (OFF|ON); controls whether to publish extra operators and kernels for (sequence-related model such as OCR or NLP)"
echo -e "--build_train: (OFF|ON); controls whether to publish training operators and kernels, build_train is only for full_publish library now" echo -e "--build_train: (OFF|ON); controls whether to publish training operators and kernels, build_train is only for full_publish library now"
echo -e "--build_python: (OFF|ON); controls whether to publish python api lib (ANDROID and IOS is not supported)" echo -e "--build_python: (OFF|ON); controls whether to publish python api lib (ANDROID and IOS is not supported)"
...@@ -491,6 +499,17 @@ function main { ...@@ -491,6 +499,17 @@ function main {
WITH_LOG="${i#*=}" WITH_LOG="${i#*=}"
shift shift
;; ;;
--with_exception=*)
WITH_EXCEPTION="${i#*=}"
if [[ $WITH_EXCEPTION == "ON" && $ARM_OS=="android" && $ARM_ABI == "armv7" && $ARM_LANG != "clang" ]]; then
set +x
echo
echo -e "error: only clang provide C++ exception handling support for 32-bit ARM."
echo
exit 1
fi
shift
;;
--with_profile=*) --with_profile=*)
WITH_PROFILE="${i#*=}" WITH_PROFILE="${i#*=}"
shift shift
......
...@@ -17,6 +17,8 @@ WITH_JAVA=ON ...@@ -17,6 +17,8 @@ WITH_JAVA=ON
WITH_CV=OFF WITH_CV=OFF
# controls whether to hide log information, default is ON. # controls whether to hide log information, default is ON.
WITH_LOG=ON WITH_LOG=ON
# controls whether to throw the exception when error occurs, default is OFF
WITH_EXCEPTION=OFF
# options of striping lib according to input model. # options of striping lib according to input model.
OPTMODEL_DIR="" OPTMODEL_DIR=""
WITH_STRIP=OFF WITH_STRIP=OFF
...@@ -145,6 +147,7 @@ function make_tiny_publish_so { ...@@ -145,6 +147,7 @@ function make_tiny_publish_so {
local cmake_mutable_options=" local cmake_mutable_options="
-DLITE_BUILD_EXTRA=$WITH_EXTRA \ -DLITE_BUILD_EXTRA=$WITH_EXTRA \
-DLITE_WITH_LOG=$WITH_LOG \ -DLITE_WITH_LOG=$WITH_LOG \
-DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \
-DLITE_BUILD_TAILOR=$WITH_STRIP \ -DLITE_BUILD_TAILOR=$WITH_STRIP \
-DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \ -DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \
-DLITE_WITH_JAVA=$WITH_JAVA \ -DLITE_WITH_JAVA=$WITH_JAVA \
...@@ -194,6 +197,7 @@ function make_full_publish_so { ...@@ -194,6 +197,7 @@ function make_full_publish_so {
local cmake_mutable_options=" local cmake_mutable_options="
-DLITE_BUILD_EXTRA=$WITH_EXTRA \ -DLITE_BUILD_EXTRA=$WITH_EXTRA \
-DLITE_WITH_LOG=$WITH_LOG \ -DLITE_WITH_LOG=$WITH_LOG \
-DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \
-DLITE_BUILD_TAILOR=$WITH_STRIP \ -DLITE_BUILD_TAILOR=$WITH_STRIP \
-DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \ -DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \
-DLITE_WITH_JAVA=$WITH_JAVA \ -DLITE_WITH_JAVA=$WITH_JAVA \
...@@ -237,6 +241,7 @@ function print_usage { ...@@ -237,6 +241,7 @@ function print_usage {
echo -e "| --with_java: (OFF|ON); controls whether to publish java api lib, default is ON |" echo -e "| --with_java: (OFF|ON); controls whether to publish java api lib, default is ON |"
echo -e "| --with_cv: (OFF|ON); controls whether to compile cv functions into lib, default is OFF |" echo -e "| --with_cv: (OFF|ON); controls whether to compile cv functions into lib, default is OFF |"
echo -e "| --with_log: (OFF|ON); controls whether to print log information, default is ON |" echo -e "| --with_log: (OFF|ON); controls whether to print log information, default is ON |"
echo -e "| --with_exception: (OFF|ON); controls whether to throw the exception when error occurs, default is OFF |"
echo -e "| --with_extra: (OFF|ON); controls whether to publish extra operators and kernels for (sequence-related model such as OCR or NLP) |" echo -e "| --with_extra: (OFF|ON); controls whether to publish extra operators and kernels for (sequence-related model such as OCR or NLP) |"
echo -e "| |" echo -e "| |"
echo -e "| arguments of striping lib according to input model:(armv8, gcc, c++_static) |" echo -e "| arguments of striping lib according to input model:(armv8, gcc, c++_static) |"
...@@ -320,6 +325,18 @@ function main { ...@@ -320,6 +325,18 @@ function main {
WITH_LOG="${i#*=}" WITH_LOG="${i#*=}"
shift shift
;; ;;
# ON or OFF, default OFF
--with_exception=*)
WITH_EXCEPTION="${i#*=}"
if [[ $WITH_EXCEPTION == "ON" && $ARCH == "armv7" && $TOOLCHAIN != "clang" ]]; then
set +x
echo
echo -e "Error: only clang provide C++ exception handling support for 32-bit ARM."
echo
exit 1
fi
shift
;;
# compiling lib which can operate on opencl and cpu. # compiling lib which can operate on opencl and cpu.
--with_opencl=*) --with_opencl=*)
WITH_OPENCL="${i#*=}" WITH_OPENCL="${i#*=}"
......
...@@ -12,6 +12,8 @@ WITH_EXTRA=OFF ...@@ -12,6 +12,8 @@ WITH_EXTRA=OFF
WITH_CV=OFF WITH_CV=OFF
# controls whether to hide log information, default is ON. # controls whether to hide log information, default is ON.
WITH_LOG=ON WITH_LOG=ON
# controls whether to throw the exception when error occurs, default is OFF
WITH_EXCEPTION=OFF
# absolute path of Paddle-Lite. # absolute path of Paddle-Lite.
workspace=$PWD/$(dirname $0)/../../ workspace=$PWD/$(dirname $0)/../../
# options of striping lib according to input model. # options of striping lib according to input model.
...@@ -69,6 +71,7 @@ function make_ios { ...@@ -69,6 +71,7 @@ function make_ios {
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \ -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \
-DLITE_WITH_X86=OFF \ -DLITE_WITH_X86=OFF \
-DLITE_WITH_LOG=$WITH_LOG \ -DLITE_WITH_LOG=$WITH_LOG \
-DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \
-DLITE_BUILD_TAILOR=$WITH_STRIP \ -DLITE_BUILD_TAILOR=$WITH_STRIP \
-DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \ -DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \
-DARM_TARGET_ARCH_ABI=$arch \ -DARM_TARGET_ARCH_ABI=$arch \
...@@ -96,6 +99,7 @@ function print_usage { ...@@ -96,6 +99,7 @@ function print_usage {
echo -e "| --arch: (armv8|armv7), default is armv8 |" echo -e "| --arch: (armv8|armv7), default is armv8 |"
echo -e "| --with_cv: (OFF|ON); controls whether to compile cv functions into lib, default is OFF |" echo -e "| --with_cv: (OFF|ON); controls whether to compile cv functions into lib, default is OFF |"
echo -e "| --with_log: (OFF|ON); controls whether to print log information, default is ON |" echo -e "| --with_log: (OFF|ON); controls whether to print log information, default is ON |"
echo -e "| --with_exception: (OFF|ON); controls whether to throw the exception when error occurs, default is OFF |"
echo -e "| --with_extra: (OFF|ON); controls whether to publish extra operators and kernels for (sequence-related model such as OCR or NLP) |" echo -e "| --with_extra: (OFF|ON); controls whether to publish extra operators and kernels for (sequence-related model such as OCR or NLP) |"
echo -e "| |" echo -e "| |"
echo -e "| arguments of striping lib according to input model:(armv8, gcc, c++_static) |" echo -e "| arguments of striping lib according to input model:(armv8, gcc, c++_static) |"
...@@ -140,6 +144,10 @@ function main { ...@@ -140,6 +144,10 @@ function main {
WITH_LOG="${i#*=}" WITH_LOG="${i#*=}"
shift shift
;; ;;
--with_exception=*)
WITH_EXCEPTION="${i#*=}"
shift
;;
help) help)
print_usage print_usage
exit 0 exit 0
......
...@@ -17,6 +17,8 @@ PY_VERSION="" ...@@ -17,6 +17,8 @@ PY_VERSION=""
WITH_CV=OFF WITH_CV=OFF
# controls whether to print log information, default is ON. # controls whether to print log information, default is ON.
WITH_LOG=ON WITH_LOG=ON
# controls whether to throw the exception when error occurs, default is OFF
WITH_EXCEPTION=OFF
# options of striping lib according to input model. # options of striping lib according to input model.
WITH_STRIP=OFF WITH_STRIP=OFF
OPTMODEL_DIR="" OPTMODEL_DIR=""
...@@ -60,6 +62,7 @@ function init_cmake_mutable_options { ...@@ -60,6 +62,7 @@ function init_cmake_mutable_options {
-DPY_VERSION=$PY_VERSION \ -DPY_VERSION=$PY_VERSION \
-DLITE_WITH_CV=$WITH_CV \ -DLITE_WITH_CV=$WITH_CV \
-DLITE_WITH_LOG=$WITH_LOG \ -DLITE_WITH_LOG=$WITH_LOG \
-DLITE_WITH_EXCEPTION=$WITH_EXCEPTION \
-DLITE_BUILD_TAILOR=$WITH_STRIP \ -DLITE_BUILD_TAILOR=$WITH_STRIP \
-DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \ -DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \
-DLITE_WITH_OPENCL=$WITH_OPENCL \ -DLITE_WITH_OPENCL=$WITH_OPENCL \
...@@ -210,6 +213,7 @@ function print_usage { ...@@ -210,6 +213,7 @@ function print_usage {
echo -e "| --python_version: (2.7|3.5|3.7); controls python version to compile whl, default is None |" echo -e "| --python_version: (2.7|3.5|3.7); controls python version to compile whl, default is None |"
echo -e "| --with_cv: (OFF|ON); controls whether to compile cv functions into lib, default is OFF |" echo -e "| --with_cv: (OFF|ON); controls whether to compile cv functions into lib, default is OFF |"
echo -e "| --with_log: (OFF|ON); controls whether to print log information, default is ON |" echo -e "| --with_log: (OFF|ON); controls whether to print log information, default is ON |"
echo -e "| --with_exception: (OFF|ON); controls whether to throw the exception when error occurs, default is OFF |"
echo -e "| |" echo -e "| |"
echo -e "| arguments of striping lib according to input model: |" echo -e "| arguments of striping lib according to input model: |"
echo -e "| ./lite/tools/build_linux.sh --with_strip=ON --opt_model_dir=YourOptimizedModelDir |" echo -e "| ./lite/tools/build_linux.sh --with_strip=ON --opt_model_dir=YourOptimizedModelDir |"
...@@ -280,6 +284,11 @@ function main { ...@@ -280,6 +284,11 @@ function main {
shift shift
;; ;;
# ON or OFF, default OFF # ON or OFF, default OFF
--with_exception=*)
WITH_EXCEPTION="${i#*=}"
shift
;;
# ON or OFF, default OFF
--with_strip=*) --with_strip=*)
BUILD_TAILOR="${i#*=}" BUILD_TAILOR="${i#*=}"
shift shift
......
...@@ -415,7 +415,7 @@ function test_arm_android { ...@@ -415,7 +415,7 @@ function test_arm_android {
echo "test name: ${test_name}" echo "test name: ${test_name}"
adb_work_dir="/data/local/tmp" adb_work_dir="/data/local/tmp"
skip_list=("test_model_parser" "test_mobilenetv1" "test_mobilenetv2" "test_resnet50" "test_inceptionv4" "test_light_api" "test_apis" "test_paddle_api" "test_cxx_api" "test_gen_code" "test_mobilenetv1_int8" "test_subgraph_pass" "test_grid_sampler_image_opencl" "test_lrn_image_opencl" "test_pad2d_image_opencl") skip_list=("test_model_parser" "test_mobilenetv1" "test_mobilenetv2" "test_resnet50" "test_inceptionv4" "test_light_api" "test_apis" "test_paddle_api" "test_cxx_api" "test_gen_code" "test_mobilenetv1_int8" "test_subgraph_pass" "test_grid_sampler_image_opencl" "test_lrn_image_opencl" "test_pad2d_image_opencl" "test_transformer_with_mask_fp32_arm")
for skip_name in ${skip_list[@]} ; do for skip_name in ${skip_list[@]} ; do
[[ $skip_name =~ (^|[[:space:]])$test_name($|[[:space:]]) ]] && echo "skip $test_name" && return [[ $skip_name =~ (^|[[:space:]])$test_name($|[[:space:]]) ]] && echo "skip $test_name" && return
done done
...@@ -1199,6 +1199,7 @@ function main { ...@@ -1199,6 +1199,7 @@ function main {
build_test_arm_subtask_model test_mobilenetv2 mobilenet_v2_relu build_test_arm_subtask_model test_mobilenetv2 mobilenet_v2_relu
build_test_arm_subtask_model test_resnet50 resnet50 build_test_arm_subtask_model test_resnet50 resnet50
build_test_arm_subtask_model test_inceptionv4 inception_v4_simple build_test_arm_subtask_model test_inceptionv4 inception_v4_simple
build_test_arm_subtask_model test_transformer_with_mask_fp32_arm transformer_with_mask_fp32
shift shift
;; ;;
build_test_arm_subtask_armlinux) build_test_arm_subtask_armlinux)
......
...@@ -15,14 +15,23 @@ ...@@ -15,14 +15,23 @@
#pragma once #pragma once
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
#include <iostream> #include <iostream>
#include <string> #include <string>
// Specify the path of configuration file for the subgraph segmentation, an
// example is shown as below:
// op_type:in_var_name_0,in_var_name1:out_var_name_0,out_var_name1
// op_type::out_var_name_0
// op_type:in_var_name_0
// op_type
#define SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE \ #define SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE \
"SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE" "SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE"
#define SUBGRAPH_DISABLE_ONLINE_MODE "SUBGRAPH_DISABLE_ONLINE_MODE" // The original weight/local/unused variables in the subblock of the subgraph op
// will be saved only if 'SUBGRAPH_ONLINE_MODE' is set to true(default) during
// the analysis phase, it ensure the ops in the subblock can be converted to the
// target device model online during the execution phase.
#define SUBGRAPH_ONLINE_MODE "SUBGRAPH_ONLINE_MODE"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
......
...@@ -57,7 +57,7 @@ static int gettimeofday(struct timeval* tp, void* tzp) { ...@@ -57,7 +57,7 @@ static int gettimeofday(struct timeval* tp, void* tzp) {
#include "lite/utils/replace_stl/stream.h" #include "lite/utils/replace_stl/stream.h"
#include "lite/utils/string.h" #include "lite/utils/string.h"
#ifdef LITE_WITH_ANDROID #if defined(LITE_WITH_LOG) && defined(LITE_WITH_ANDROID)
#include <android/log.h> #include <android/log.h>
// Android log macors // Android log macors
#define ANDROID_LOG_TAG "Paddle-Lite" #define ANDROID_LOG_TAG "Paddle-Lite"
...@@ -143,8 +143,10 @@ class LogMessage { ...@@ -143,8 +143,10 @@ class LogMessage {
ANDROID_LOG_I(log_stream_.str().c_str()); ANDROID_LOG_I(log_stream_.str().c_str());
} else if (level_ == "W") { } else if (level_ == "W") {
ANDROID_LOG_W(log_stream_.str().c_str()); ANDROID_LOG_W(log_stream_.str().c_str());
} else if (level_ == "F") {
ANDROID_LOG_F(log_stream_.str().c_str());
} else { } else {
fprintf(stderr, "Unsupported log level: %s", level_.c_str()); fprintf(stderr, "Unsupported log level: %s\n", level_.c_str());
assert(false); assert(false);
} }
#endif #endif
...@@ -170,17 +172,25 @@ class LogMessageFatal : public LogMessage { ...@@ -170,17 +172,25 @@ class LogMessageFatal : public LogMessage {
const char* level = "F") const char* level = "F")
: LogMessage(file, func, lineno, level) {} : LogMessage(file, func, lineno, level) {}
~LogMessageFatal() { ~LogMessageFatal()
#ifdef LITE_WITH_EXCEPTION
noexcept(false)
#endif
{
log_stream_ << '\n'; log_stream_ << '\n';
#ifdef LITE_WITH_ANDROID #ifdef LITE_WITH_ANDROID
ANDROID_LOG_F(log_stream_.str().c_str()); ANDROID_LOG_F(log_stream_.str().c_str());
#endif #endif
fprintf(stderr, "%s", log_stream_.str().c_str()); fprintf(stderr, "%s", log_stream_.str().c_str());
#ifdef LITE_WITH_EXCEPTION
throw std::exception();
#else
#ifndef LITE_ON_TINY_PUBLISH #ifndef LITE_ON_TINY_PUBLISH
abort(); abort();
#else #else
assert(false); assert(false);
#endif
#endif #endif
} }
}; };
...@@ -237,7 +247,11 @@ class Voidify { ...@@ -237,7 +247,11 @@ class Voidify {
class VoidifyFatal : public Voidify { class VoidifyFatal : public Voidify {
public: public:
#ifdef LITE_WITH_EXCEPTION
~VoidifyFatal() noexcept(false) { throw std::exception(); }
#else
~VoidifyFatal() { assert(false); } ~VoidifyFatal() { assert(false); }
#endif
}; };
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册