From ac897177b73108de565cc8b5a8038229252f0218 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Mon, 13 Jul 2020 14:42:48 +0800 Subject: [PATCH] [NPU] enhance cache offline model, test=develop (#3805) (#3931) * [NPU] enhance cache offline model, test=develop --- .gitignore | 3 + cmake/cross_compiling/android.cmake | 6 +- lite/api/cxx_api_impl.cc | 4 + lite/api/light_api_impl.cc | 5 + lite/api/paddle_api.h | 9 + lite/backends/npu/device.cc | 136 ++++-- lite/backends/npu/device.h | 15 +- lite/core/context.cc | 4 + lite/core/context.h | 10 + lite/core/mir/subgraph/subgraph_detector.cc | 107 ++--- lite/core/mir/subgraph/subgraph_pass_test.cc | 2 + lite/kernels/npu/bridges/engine.cc | 113 +++-- lite/kernels/npu/bridges/engine.h | 33 +- lite/kernels/npu/bridges/graph.h | 2 +- lite/kernels/npu/bridges/matmul_op.cc | 8 +- lite/kernels/npu/bridges/utility.h | 31 +- lite/kernels/npu/subgraph_compute.cc | 476 ++++++++++++------- lite/kernels/npu/subgraph_compute.h | 73 ++- lite/utils/env.h | 2 + lite/utils/io.h | 35 ++ lite/utils/md5.h | 104 ++++ lite/utils/string.h | 43 +- 22 files changed, 841 insertions(+), 380 deletions(-) create mode 100644 lite/utils/md5.h diff --git a/.gitignore b/.gitignore index dc0a38edcb..7fe204306f 100644 --- a/.gitignore +++ b/.gitignore @@ -117,3 +117,6 @@ metal/paddle-mobile-demo/paddle-mobile-demo/Resources/models metal/MobileNetDemo/MobileNetDemo/Resources build* + +# hiai libs +ai_ddk_lib* diff --git a/cmake/cross_compiling/android.cmake b/cmake/cross_compiling/android.cmake index 4fc59ccd62..e6193e0bb3 100644 --- a/cmake/cross_compiling/android.cmake +++ b/cmake/cross_compiling/android.cmake @@ -35,7 +35,11 @@ endif() if(NOT DEFINED ANDROID_API_LEVEL) set(ANDROID_API_LEVEL "23") if(ARM_TARGET_ARCH_ABI STREQUAL "armv7") - set(ANDROID_API_LEVEL "22") + if(LITE_WITH_NPU AND NOT LITE_ON_TINY_PUBLISH) + set(ANDROID_API_LEVEL "24") # HIAI DDK depends on android-24 + else() + set(ANDROID_API_LEVEL "22") + endif() endif() endif() diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index d85ed3b644..facea37425 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -70,6 +70,10 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { raw_predictor_.Build(config, places, passes); mode_ = config.power_mode(); threads_ = config.threads(); +#ifdef LITE_WITH_NPU + Context::SetSubgraphModelCacheDir( + config.subgraph_model_cache_dir()); +#endif #if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \ !(defined LITE_ON_MODEL_OPTIMIZE_TOOL) && !defined(__APPLE__) int num_threads = config.x86_math_library_num_threads(); diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc index cdf5b7fb06..e76e89af43 100644 --- a/lite/api/light_api_impl.cc +++ b/lite/api/light_api_impl.cc @@ -36,6 +36,11 @@ void LightPredictorImpl::Init(const lite_api::MobileConfig& config) { } mode_ = config.power_mode(); threads_ = config.threads(); + +#ifdef LITE_WITH_NPU + Context::SetSubgraphModelCacheDir( + config.subgraph_model_cache_dir()); +#endif } std::unique_ptr LightPredictorImpl::GetInput(int i) { diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index a0a7a9b139..aa61e22796 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -118,6 +118,8 @@ class LITE_API ConfigBase { std::string model_dir_; int threads_{1}; PowerMode mode_{LITE_POWER_NO_BIND}; + // to save subgraph model for npu/xpu/... 
+ std::string subgraph_model_cache_dir_{""}; public: explicit ConfigBase(PowerMode mode = LITE_POWER_NO_BIND, int threads = 1); @@ -130,6 +132,13 @@ class LITE_API ConfigBase { // set Thread void set_threads(int threads); int threads() const { return threads_; } + // set subgraph_model_dir + void set_subgraph_model_cache_dir(std::string subgraph_model_cache_dir) { + subgraph_model_cache_dir_ = subgraph_model_cache_dir; + } + const std::string& subgraph_model_cache_dir() const { + return subgraph_model_cache_dir_; + } }; /// CxxConfig is the config for the Full feature predictor. diff --git a/lite/backends/npu/device.cc b/lite/backends/npu/device.cc index 345b239c32..48145daacf 100644 --- a/lite/backends/npu/device.cc +++ b/lite/backends/npu/device.cc @@ -19,52 +19,122 @@ namespace paddle { namespace lite { namespace npu { -std::shared_ptr Device::Build( - const std::string model_name, // NOLINT - std::vector& input_nodes, // NOLINT - std::vector& output_nodes // NOLINT - ) { - VLOG(3) << "[NPU] Build model"; - // Build the HiAI IR graph to the HiAI om model - ge::Graph ir_graph("graph"); - ir_graph.SetInputs(input_nodes).SetOutputs(output_nodes); - ge::Model om_model("model", "model"); - om_model.SetGraph(ir_graph); - domi::HiaiIrBuild ir_build; - domi::ModelBufferData om_model_buf; - if (!ir_build.CreateModelBuff(om_model, om_model_buf)) { - LOG(WARNING) << "[NPU] CreateModelBuff failed!"; - return nullptr; - } - if (!ir_build.BuildIRModel(om_model, om_model_buf)) { - LOG(WARNING) << "[NPU] BuildIRModel failed!"; - ir_build.ReleaseModelBuff(om_model_buf); - return nullptr; - } - +std::shared_ptr Device::Load( + const std::string& model_name, + std::vector* model_buffer, + bool* model_comp) { // Create a HiAI model manager client to load the HiAI om model - std::shared_ptr model_client( - new hiai::AiModelMngerClient()); + auto model_client = std::make_shared(); if (model_client->Init(nullptr) != hiai::AI_SUCCESS) { - LOG(WARNING) << "[NPU] AiModelMngerClient init failed)!"; - ir_build.ReleaseModelBuff(om_model_buf); + LOG(WARNING) << "[NPU] Init hiai model client failed!"; return nullptr; } + // Check HiAI DDK version + const char* ddk_version = model_client->GetVersion(); + if (ddk_version) { + LOG(INFO) << "[NPU] HiAI DDK version: " << ddk_version; + } else { + LOG(WARNING) << "[NPU] Unable to get HiAI DDK version!"; + } + // Check model compatibility auto model_desc = std::make_shared( model_name, freq_level(), framework_type(), model_type(), device_type()); - model_desc->SetModelBuffer(om_model_buf.data, om_model_buf.length); - std::vector> model_descs; - model_descs.push_back(model_desc); + model_desc->SetModelBuffer( + reinterpret_cast(model_buffer->data()), + model_buffer->size()); + if (!*model_comp && + model_client->CheckModelCompatibility(*model_desc, *model_comp) != + hiai::AI_SUCCESS) { + *model_comp = false; + VLOG(3) << "[NPU] model is NOT compatiblitiable, setting model_comp to " + << *model_comp; + } else { + *model_comp = true; + VLOG(3) << "[NPU] model is compatiblitiable, setting model_comp to " + << *model_comp; + } + // Rebuild and write the data of the compatible model to the model buffer + if (!*model_comp) { + std::shared_ptr model_builder = + std::make_shared(model_client); + hiai::MemBuffer* org_model_buffer = model_builder->InputMemBufferCreate( + reinterpret_cast(model_buffer->data()), model_buffer->size()); + if (org_model_buffer) { + std::vector org_model_buffers; + org_model_buffers.push_back(org_model_buffer); + hiai::MemBuffer* new_model_buffer = 
model_builder->OutputMemBufferCreate( + framework_type(), org_model_buffers); + // VLOG(3) << "[NPU] new model buffer memeory size is " << + // new_model_buffer->GetMemBufferSize(); + if (new_model_buffer) { + uint32_t new_model_size = 0; + if (model_builder->BuildModel(org_model_buffers, + new_model_buffer, + new_model_size) == hiai::AI_SUCCESS) { + // need to change to new_model_size as GetMemBufferSize is not + // correct. + model_buffer->resize(new_model_size); + memcpy(reinterpret_cast(model_buffer->data()), + new_model_buffer->GetMemBufferData(), + new_model_size); + // Reset the model buffer + model_desc->SetModelBuffer( + reinterpret_cast(model_buffer->data()), + model_buffer->size()); + VLOG(3) << "[NPU] Rebuild the compatible model done."; + } else { + LOG(WARNING) << "[NPU] Rebuild the compatible model failed!"; + } + model_builder->MemBufferDestroy(new_model_buffer); + } else { + LOG(WARNING) << "[NPU] OutputMemBufferCreate failed!"; + } + model_builder->MemBufferDestroy(org_model_buffer); + } else { + LOG(WARNING) << "[NPU] InputMemBufferCreate failed!"; + } + } + // Load the compatible model + std::vector> model_descs{ + model_desc}; if (model_client->Load(model_descs) != hiai::AI_SUCCESS) { LOG(WARNING) << "[NPU] AiModelMngerClient load model failed!"; - ir_build.ReleaseModelBuff(om_model_buf); return nullptr; } - ir_build.ReleaseModelBuff(om_model_buf); - VLOG(3) << "[NPU] Build done"; + VLOG(3) << "[NPU] Load model done."; return model_client; } +bool Device::Build(std::vector& input_nodes, // NOLINT + std::vector& output_nodes, // NOLINT + std::vector* model_buffer) { + // Convert the HiAI IR graph to the HiAI om model + ge::Graph ir_graph("graph"); + ir_graph.SetInputs(input_nodes).SetOutputs(output_nodes); + ge::Model om_model("model", "model"); + om_model.SetGraph(ir_graph); + + // Build the HiAI om model, serialize and output it to the om buffer + domi::HiaiIrBuild ir_build; + domi::ModelBufferData om_buffer; + if (!ir_build.CreateModelBuff(om_model, om_buffer)) { + LOG(WARNING) << "[NPU] CreateModelBuff failed!"; + return false; + } + if (!ir_build.BuildIRModel(om_model, om_buffer)) { + LOG(WARNING) << "[NPU] BuildIRModel failed!"; + ir_build.ReleaseModelBuff(om_buffer); + return false; + } + model_buffer->resize(om_buffer.length); + memcpy(reinterpret_cast(model_buffer->data()), + reinterpret_cast(om_buffer.data), + om_buffer.length); + ir_build.ReleaseModelBuff(om_buffer); + VLOG(3) << "[NPU] Build model done."; + return true; +} + } // namespace npu } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/device.h b/lite/backends/npu/device.h index 6733a7f6df..70e371f7c0 100644 --- a/lite/backends/npu/device.h +++ b/lite/backends/npu/device.h @@ -38,13 +38,18 @@ class Device { int model_type() { return model_type_; } int device_type() { return device_type_; } + // Load the HiAI om model from buffer, rebuild the model if it's incompatible + // with the current device, then create a HiAI model manager client(from HiAI + // Server) to run inference + std::shared_ptr Load( + const std::string& model_name, + std::vector* model_buffer, + bool* model_comp); // Build the HiAI IR graph to om model, return HiAI model manager client to // load om model and run inference. 
- std::shared_ptr Build( - const std::string model_name, // NOLINT - std::vector& input_nodes, // NOLINT - std::vector& output_nodes // NOLINT - ); // NOLINT + bool Build(std::vector& input_nodes, // NOLINT + std::vector& output_nodes, // NOLINT + std::vector* model_buffer); private: int freq_level_{3}; diff --git a/lite/core/context.cc b/lite/core/context.cc index 711c67f8b7..00c5ebf9d0 100644 --- a/lite/core/context.cc +++ b/lite/core/context.cc @@ -17,6 +17,10 @@ namespace paddle { namespace lite { +#ifdef LITE_WITH_NPU +std::string Context::subgraph_model_cache_dir_{""}; // NOLINT +#endif + #ifdef LITE_WITH_XPU thread_local xdnn::Context* Context::_tls_raw_ctx{nullptr}; int Context::_workspace_l3_size_per_thread{0}; diff --git a/lite/core/context.h b/lite/core/context.h index d0c1bd93cc..ad38b2d7f2 100644 --- a/lite/core/context.h +++ b/lite/core/context.h @@ -85,6 +85,16 @@ class Context { NPUContext& operator=(const NPUContext& ctx) {} std::string name() const { return "NPUContext"; } + + static void SetSubgraphModelCacheDir(std::string subgraph_model_cache_dir) { + subgraph_model_cache_dir_ = subgraph_model_cache_dir; + } + static std::string SubgraphModelCacheDir() { + return subgraph_model_cache_dir_; + } + + private: + static std::string subgraph_model_cache_dir_; }; #endif diff --git a/lite/core/mir/subgraph/subgraph_detector.cc b/lite/core/mir/subgraph/subgraph_detector.cc index 6bab454c42..33e1521917 100644 --- a/lite/core/mir/subgraph/subgraph_detector.cc +++ b/lite/core/mir/subgraph/subgraph_detector.cc @@ -426,73 +426,51 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, subgraph_op_desc.SetAttr("sub_block", sub_block_idx); // Extract input and output nodes from the target subgraph - std::unordered_set input_var_nodes; + std::unordered_set idata_var_nodes; std::unordered_set weight_var_nodes; - std::unordered_set output_var_nodes; + std::unordered_set odata_var_nodes; std::unordered_set local_var_nodes; std::unordered_set unused_var_nodes; ExtractInputsOutputs(subgraph_nodes, - &input_var_nodes, + &idata_var_nodes, &weight_var_nodes, - &output_var_nodes, + &odata_var_nodes, &local_var_nodes, &unused_var_nodes); - + // A simplified model without the original weight/local/unused nodes on the + // subgraph ops will be saved only if 'SUBGRAPH_DISABLE_ONLINE_MODE' is set to + // true and Predictor->Run(...), Predictor->Save(...) is called. 
+ std::unordered_set input_var_nodes(idata_var_nodes.begin(), + idata_var_nodes.end()); + std::unordered_set output_var_nodes(odata_var_nodes.begin(), + odata_var_nodes.end()); + if (!GetBoolFromEnv(SUBGRAPH_DISABLE_ONLINE_MODE)) { + input_var_nodes.insert(weight_var_nodes.begin(), weight_var_nodes.end()); + output_var_nodes.insert(local_var_nodes.begin(), local_var_nodes.end()); + output_var_nodes.insert(unused_var_nodes.begin(), unused_var_nodes.end()); + } // Set input and output name mapping which stores the real inputs and // outputs - std::vector input_var_names; - std::vector output_var_names; - for (auto &var_node : input_var_nodes) { - input_var_names.push_back(var_node->AsArg().name); + std::vector idata_var_names; + std::vector odata_var_names; + for (auto &var_node : idata_var_nodes) { + idata_var_names.push_back(var_node->AsArg().name); } - for (auto &var_node : output_var_nodes) { - output_var_names.push_back(var_node->AsArg().name); + for (auto &var_node : odata_var_nodes) { + odata_var_names.push_back(var_node->AsArg().name); } subgraph_op_desc.SetAttr>("input_data_names", - input_var_names); + idata_var_names); subgraph_op_desc.SetAttr>("output_data_names", - output_var_names); - - // Set input/output scale values of input/output var nodes for - // type_precision_cast_pass. - std::vector input_data_scales; - std::vector output_data_scales; - for (auto &var_node : input_var_nodes) { - auto any_op_node = var_node->outlinks.front(); - CHECK(any_op_node->IsStmt()); - auto &any_inst = any_op_node->AsStmt(); - if (any_inst.op_info()->HasAttr("input_scale")) { - input_data_scales.push_back( - any_inst.op_info()->GetAttr("input_scale")); - } - } - for (auto &var_node : output_var_nodes) { - auto any_op_node = var_node->inlinks.front(); - CHECK(any_op_node->IsStmt()); - auto &any_inst = any_op_node->AsStmt(); - if (any_inst.op_info()->HasAttr("output_scale")) { - output_data_scales.push_back( - any_inst.op_info()->GetAttr("output_scale")); - } - } - if (input_data_scales.size() > 0) { - subgraph_op_desc.SetAttr>("input_data_scales", - input_data_scales); - } - if (output_data_scales.size() > 0) { - subgraph_op_desc.SetAttr>("output_data_scales", - output_data_scales); - } - + odata_var_names); // Set all of the inputs and outputs to the target subgraph op // To prevent vars are removed in RuntimeProgram::UpdateVarsOfProgram() - for (auto &var_node : weight_var_nodes) { + std::vector input_var_names; + std::vector output_var_names; + for (auto &var_node : input_var_nodes) { input_var_names.push_back(var_node->AsArg().name); } - for (auto &var_node : local_var_nodes) { - output_var_names.push_back(var_node->AsArg().name); - } - for (auto &var_node : unused_var_nodes) { + for (auto &var_node : output_var_nodes) { output_var_names.push_back(var_node->AsArg().name); } subgraph_op_desc.SetInput("Inputs", input_var_names); @@ -509,26 +487,13 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, for (auto &var_node : input_var_nodes) { IR_NODE_LINK_TO(var_node, subgraph_op_node); } - for (auto &var_node : weight_var_nodes) { - IR_NODE_LINK_TO(var_node, subgraph_op_node); - } for (auto &var_node : output_var_nodes) { IR_OP_VAR_LINK(subgraph_op_node, var_node); } - for (auto &var_node : local_var_nodes) { - IR_OP_VAR_LINK(subgraph_op_node, var_node); - } - for (auto &var_node : unused_var_nodes) { - IR_OP_VAR_LINK(subgraph_op_node, var_node); - } // Remove subgraph nodes and unused var nodes - auto nodes2rm = GetNodes2RM(subgraph_nodes, - {input_var_nodes, - weight_var_nodes, - 
output_var_nodes, - local_var_nodes, - unused_var_nodes}); + auto nodes2rm = + GetNodes2RM(subgraph_nodes, {input_var_nodes, output_var_nodes}); GraphSafeRemoveNodes(graph, nodes2rm); } @@ -603,7 +568,17 @@ std::unordered_set GetNodes2RM( std::unordered_set nodes2rm(op_nodes.begin(), op_nodes.end()); for (auto &op_node : op_nodes) { for (auto &var_node : op_node->inlinks) { - if (!nodes2rm.count(var_node)) { + bool skip = false; + // skip the var node which is used by any other ops that doesn't belong to + // the subgraph ops. + for (auto &out_op_node : var_node->outlinks) { + if (std::find(op_nodes.begin(), op_nodes.end(), out_op_node) != + op_nodes.end()) { + skip = true; + break; + } + } + if (!skip && !nodes2rm.count(var_node)) { nodes2rm.insert(var_node); } } diff --git a/lite/core/mir/subgraph/subgraph_pass_test.cc b/lite/core/mir/subgraph/subgraph_pass_test.cc index c638793c08..aa582dd933 100644 --- a/lite/core/mir/subgraph/subgraph_pass_test.cc +++ b/lite/core/mir/subgraph/subgraph_pass_test.cc @@ -25,6 +25,7 @@ DEFINE_string(optimized_model_dir, "", "path of optimized naive buffer model"); DEFINE_string(input_tensor_shape, "1,3,224,224", "shape of input tensors"); DEFINE_string(input_tensor_type, "float32", "data type of input tensors"); DEFINE_string(output_tensor_type, "float32", "data type of output tensors"); +DEFINE_string(subgraph_model_cache_dir, "", "dir of subgraph model cache"); namespace paddle { namespace lite { @@ -132,6 +133,7 @@ std::shared_ptr TestModel( mobile_config.set_model_from_file(optimized_model_dir + ".nb"); mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH); mobile_config.set_threads(1); + mobile_config.set_subgraph_model_cache_dir(FLAGS_subgraph_model_cache_dir); predictor = lite_api::CreatePaddlePredictor(mobile_config); FillInputTensors(predictor, input_tensor_shape, input_tensor_type, 1); // Run optimized model diff --git a/lite/kernels/npu/bridges/engine.cc b/lite/kernels/npu/bridges/engine.cc index 6e639a37ba..884ab1acce 100644 --- a/lite/kernels/npu/bridges/engine.cc +++ b/lite/kernels/npu/bridges/engine.cc @@ -15,6 +15,7 @@ #include "lite/kernels/npu/bridges/engine.h" #include #include +#include #include #include "lite/kernels/npu/bridges/registry.h" @@ -22,11 +23,50 @@ namespace paddle { namespace lite { namespace subgraph { -int Engine::BuildDeviceProgram() { return FAILED; } +Engine::Engine(KernelContext *ctx, + int block_idx, + cpp::BlockDesc *block_desc, + const std::vector &input_names, + const std::vector &output_names, + lite::Scope *scope) + : ctx_(ctx), block_idx_(block_idx), block_desc_(block_desc), scope_(scope) { + input_names_ = input_names; + output_names_ = output_names; + // Sort the name of input and output tensors, it's convenient for us to get + // the info of input and output tensors in the same order from the device + // program, because the result of subgraph division may be different but right + // at each call of the subgraph pass. 
+ std::stable_sort(input_names_.begin(), input_names_.end()); + std::stable_sort(output_names_.begin(), output_names_.end()); +} + +bool Engine::Run() { + if (is_first_epoch_) { + PrepareWorkspaceForDeviceProgram(); + is_first_epoch_ = false; + } + if (InputShapeChanged()) { + BuildDeviceProgram(); + } + return LaunchDeviceProgram(); +} -int Engine::LaunchDeviceProgram() { return 0; } +bool Engine::PrepareWorkspaceForOriginProgram() { + origin_idims_.resize(input_names_.size()); + origin_itensors_.resize(input_names_.size()); + for (int i = 0; i < input_names_.size(); i++) { + origin_itensors_[i] = scope_->FindMutableTensor(input_names_[i]); + CHECK(origin_itensors_[i]); + } + origin_otensors_.resize(output_names_.size()); + for (int i = 0; i < output_names_.size(); i++) { + origin_otensors_[i] = scope_->FindMutableTensor(output_names_[i]); + CHECK(origin_otensors_[i]); + } + return true; +} -int Engine::BuildOriginProgram() { +bool Engine::BuildOriginProgram() { // TODO(hong19860320) The block_desc need to be divided into subgraphs during // the exection time. But only see them as a subgraph now. origin_program_.clear(); @@ -34,11 +74,14 @@ int Engine::BuildOriginProgram() { auto op_desc = block_desc_->GetOp(op_idx); CHECK(op_desc); std::string op_type = op_desc->Type(); + // Create op and pick up the best kernel auto op = LiteOpRegistry::Global().Create(op_desc->Type()); + CHECK(op) << "no Op found for " << op_type; op->Attach(*op_desc, scope_); std::unique_ptr picked_kernel; if (op_desc->HasAttr(kKernelTypeAttr)) { - // Create op and pick up kernel according to the kKernelTypeAttr attribute + // Create op and pick up the best kernel according to the + // kKernelTypeAttr attribute auto kernel_type = op_desc->GetAttr(kKernelTypeAttr); std::string alias; Place place; @@ -48,12 +91,14 @@ int Engine::BuildOriginProgram() { auto kernels = op->CreateKernels({place}); CHECK_GT(kernels.size(), 0u) << "No kernels found for " << op_type; auto it = std::find_if( - kernels.begin(), kernels.end(), [&](std::unique_ptr& it) { + kernels.begin(), kernels.end(), [&](std::unique_ptr &it) { return it->alias() == alias; }); CHECK(it != kernels.end()); picked_kernel = std::move(*it); } else { + // TODO(hong19860320) add kernel picking according to the type of input + // and output tensors VLOG(3) << "The attr '" << kKernelTypeAttr << "' not found, pick the first kernel for " << op_type; std::vector> kernels; @@ -74,49 +119,41 @@ int Engine::BuildOriginProgram() { } origin_program_.emplace_back(std::move(op), std::move(picked_kernel)); } - return 0; + CHECK(!origin_program_.empty()) << "no instructions"; + return true; } -int Engine::LaunchOriginProgram() { - for (auto& inst : origin_program_) { - auto op_type = inst.op()->op_info()->Type(); - if (op_type == "feed" || op_type == "fetch") continue; - inst.Run(); +bool Engine::LaunchOriginProgram() { + if (origin_program_.empty()) { + BuildOriginProgram(); + } + if (!origin_program_.empty()) { + for (auto &inst : origin_program_) { + auto op_type = inst.op()->op_info()->Type(); + if (op_type == "feed" || op_type == "fetch") continue; + inst.Run(); + } + return true; } - return 0; + return false; } -int Engine::Build() { - // In order to attach all of the ops of the block desc, we need to build the - // original program firstly. 
- BuildOriginProgram(); - // Run InferShape() of all of ops, and convert Paddle ops to NPU/XPU IR graph - build_device_program_status_ = BuildDeviceProgram(); - return build_device_program_status_; +bool Engine::PrepareWorkspaceForDeviceProgram() { + return PrepareWorkspaceForOriginProgram(); } +bool Engine::BuildDeviceProgram() { return BuildOriginProgram(); } + +bool Engine::LaunchDeviceProgram() { return LaunchOriginProgram(); } + bool Engine::InputShapeChanged() { + bool changed = false; for (size_t i = 0; i < origin_itensors_.size(); i++) { - if (origin_itensors_[i]->dims() != origin_idims_[i]) { - return true; - } - } - return false; -} - -int Engine::Launch() { - // Rebuild device program when the shapes of input tensors have been changed. - if (CHECK_SUCCESS(build_device_program_status_) && - CHECK_REBUILD_WHEN_SHAPE_CHANGED(build_device_program_status_) && - InputShapeChanged()) { - Build(); - } - if (CHECK_FAILED(build_device_program_status_)) { - LaunchOriginProgram(); - } else { - LaunchDeviceProgram(); + auto origin_idim = origin_itensors_[i]->dims().Vectorize(); + changed |= origin_idim != origin_idims_[i]; + origin_idims_[i] = origin_idim; } - return 0; + return changed; } } // namespace subgraph diff --git a/lite/kernels/npu/bridges/engine.h b/lite/kernels/npu/bridges/engine.h index 61a4e12cf3..0f7a58f01a 100644 --- a/lite/kernels/npu/bridges/engine.h +++ b/lite/kernels/npu/bridges/engine.h @@ -33,42 +33,33 @@ class Engine { cpp::BlockDesc *block_desc, const std::vector &input_names, const std::vector &output_names, - lite::Scope *scope) - : ctx_(ctx), - block_idx_(block_idx), - block_desc_(block_desc), - input_names_(input_names), - output_names_(output_names), - scope_(scope) {} + lite::Scope *scope); virtual ~Engine() = default; - virtual int Build(); - virtual int Launch(); + virtual bool Run(); private: Engine(const Engine &) = delete; protected: - virtual int BuildDeviceProgram(); - virtual int LaunchDeviceProgram(); + virtual bool PrepareWorkspaceForOriginProgram(); + virtual bool BuildOriginProgram(); + virtual bool LaunchOriginProgram(); - virtual int BuildOriginProgram(); - virtual int LaunchOriginProgram(); + virtual bool PrepareWorkspaceForDeviceProgram(); + virtual bool BuildDeviceProgram(); + virtual bool LaunchDeviceProgram(); virtual bool InputShapeChanged(); KernelContext *ctx_{nullptr}; - int block_idx_; - cpp::BlockDesc *block_desc_; + int block_idx_{-1}; + cpp::BlockDesc *block_desc_{nullptr}; std::vector input_names_; std::vector output_names_; Scope *scope_{nullptr}; - // SUCCESS: device program build successed. FAILED: device program build - // failed. REBUILD_WHEN_SHAPE_CHANGED: device program build successed but need - // to rebuild when input shape changed. 
- int build_device_program_status_{0}; - std::vector origin_idims_; - std::vector origin_odims_; + bool is_first_epoch_{true}; + std::vector> origin_idims_; std::vector origin_itensors_; std::vector origin_otensors_; std::vector origin_program_; diff --git a/lite/kernels/npu/bridges/graph.h b/lite/kernels/npu/bridges/graph.h index 67d8a2b1cc..b615460ae7 100644 --- a/lite/kernels/npu/bridges/graph.h +++ b/lite/kernels/npu/bridges/graph.h @@ -19,7 +19,7 @@ #include #include #include -#include "graph/op/all_ops.h" +#include "graph/compatible/all_ops.h" #include "lite/core/op_lite.h" #include "lite/core/tensor.h" diff --git a/lite/kernels/npu/bridges/matmul_op.cc b/lite/kernels/npu/bridges/matmul_op.cc index 32af191689..79ba82d94f 100644 --- a/lite/kernels/npu/bridges/matmul_op.cc +++ b/lite/kernels/npu/bridges/matmul_op.cc @@ -94,10 +94,10 @@ int MatMulConverter(void* ctx, OpLite* op, KernelBase* kernel) { } else { matmul_node = graph->Add(out_name); auto matmul_op = matmul_node->data(); - matmul_op->set_input_x(*x_node->data()); - matmul_op->set_input_y(*y_node->data()); - matmul_op->set_attr_adj_x(transpose_x); - matmul_op->set_attr_adj_y(transpose_y); + matmul_op->set_input_x1(*x_node->data()); + matmul_op->set_input_x2(*y_node->data()); + matmul_op->set_attr_adj_x1(transpose_x); + matmul_op->set_attr_adj_x2(transpose_y); } if (fabs(alpha - 1.f) > 1e-6f) { diff --git a/lite/kernels/npu/bridges/utility.h b/lite/kernels/npu/bridges/utility.h index 6d7dc5891f..ccb832317a 100644 --- a/lite/kernels/npu/bridges/utility.h +++ b/lite/kernels/npu/bridges/utility.h @@ -20,11 +20,11 @@ #include #include #include "graph/buffer.h" +#include "graph/compatible/operator_reg.h" #include "graph/graph.h" #include "graph/model.h" #include "graph/op/all_ops.h" #include "graph/operator.h" -#include "graph/operator_reg.h" #include "lite/core/op_lite.h" #include "lite/utils/macros.h" @@ -97,25 +97,26 @@ REG_OP(Pad) /* * Multiplies slices of two tensors in batches. * - * x : The input tensor - * y : The input tensor + * x1 : The input tensor + * x2 : The input tensor * - * z : The output tensor + * y : The output tensor * - * adj_x : adj_x is true, the input tensor x is transposed, otherwise - * it will not be transposed. Default is false (The current version only - * supports false). - * adj_y : adj_y is true, the input tensor y is transposed, otherwise - * it will not be transposed. Default is false. + * adj_x1 : adj_x1 is true, the input tensor x1 is transposed, + * otherwise it will not be transposed. + * Default is false (The current version only supports false). + * adj_x2 : adj_x2 is true, the input tensor x2 is transposed, + * otherwise it will not be transposed. + * Default is false. 
* - * 100.320.010.010 + * 100.320.010.010 */ REG_OP(BatchMatMul) - .INPUT(x, TensorType({DT_FLOAT})) - .INPUT(y, TensorType({DT_FLOAT})) - .OUTPUT(z, TensorType({DT_FLOAT})) - .ATTR(adj_x, AttrValue::BOOL{false}) - .ATTR(adj_y, AttrValue::BOOL{false}) + .INPUT(x1, TensorType({DT_FLOAT})) + .INPUT(x2, TensorType({DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT})) + .ATTR(adj_x1, AttrValue::BOOL{false}) + .ATTR(adj_x2, AttrValue::BOOL{false}) .OP_END() } // namespace ge diff --git a/lite/kernels/npu/subgraph_compute.cc b/lite/kernels/npu/subgraph_compute.cc index 1baa5a0de4..6afb445e0e 100644 --- a/lite/kernels/npu/subgraph_compute.cc +++ b/lite/kernels/npu/subgraph_compute.cc @@ -15,6 +15,8 @@ #include "lite/kernels/npu/subgraph_compute.h" #include #include +#include +#include #include #include "hiai_ir_build.h" // NOLINT #include "lite/backends/npu/device.h" @@ -22,192 +24,276 @@ #include "lite/kernels/npu/bridges/graph.h" #include "lite/kernels/npu/bridges/paddle_use_bridges.h" #include "lite/kernels/npu/bridges/utility.h" +#include "lite/utils/io.h" +#include "lite/utils/md5.h" namespace paddle { namespace lite { namespace kernels { namespace npu { -int SubgraphEngine::BuildDeviceProgram() { +// Generate the model name by using md5 hashes based on: +// 1. the sorted variable input names +// 2. the shapes of the origin input tensors +// 3. the sorted variable output names +std::string DeviceProgram::GenerateModelName( + const std::vector& input_names, + const std::vector& output_names, + const std::vector>& origin_idims) { + std::ostringstream os; + CHECK_EQ(input_names.size(), origin_idims.size()); + for (int i = 0; i < input_names.size(); i++) { + os << input_names[i]; + for (auto dim : origin_idims[i]) { + os << dim; + } + } + for (auto output_name : output_names) { + os << output_name; + } + return MD5(os.str()); +} + +// Deserialize the generated model, the precisions and dimensions of the origin +// output tensors of the subgraph op into files +bool DeviceProgram::LoadFromCacheFile( + const std::vector& input_names, + const std::vector& output_names, + const std::vector>& origin_idims, + const std::string& model_cache_dir) { + // Generate the model name if not initialized + if (model_name_.empty()) { + model_name_ = GenerateModelName(input_names, output_names, origin_idims); + } + // Load from the cached model file, return a HiAI model manager client for + // inference + auto model_path = model_cache_dir + "/" + model_name_ + ".om"; + VLOG(3) << "[NPU] Load model from " << model_path; + std::vector model_buffer; + if (!ReadFile(model_path, &model_buffer)) { + LOG(WARNING) << "[NPU] read from " << model_path << " failed!"; + return false; + } + bool model_comp = false; + model_client_ = + lite::npu::Device::Global().Load(model_name_, &model_buffer, &model_comp); + if (!model_client_) { + LOG(WARNING) << "[NPU] Load model failed!"; + return false; + } + // Rewrite with the compatible model data if the cached + // model file is incompatible with the current device + if (!model_comp) { + VLOG(3) << "[NPU] Export the compatible model to " << model_path; + if (!WriteFile(model_path, model_buffer)) { + LOG(WARNING) << "[NPU] Open " << model_path << " for writting failed!"; + } + } + // Deserialize the precisions and shapes of the origin output tensors from the + // cached configuration file + auto config_path = model_cache_dir + "/" + model_name_ + ".cfg"; + VLOG(3) << "[NPU] Load configuration from " << config_path; + std::vector config_buffer; + if (!ReadFile(config_path, &config_buffer)) 
{ + LOG(WARNING) << "[NPU] read from " << config_path << " failed!"; + return false; + } + std::string config_str(config_buffer.begin(), config_buffer.end()); + // Parse the precision and shapes of the output tensors + auto output_options = Split(config_str, ";"); + CHECK_EQ(output_options.size(), output_names.size()); + origin_otypes_.resize(output_names.size()); + origin_odims_.resize(output_names.size()); + for (int i = 0; i < output_names.size(); i++) { + auto items = Split(output_options[i], ":"); + CHECK_EQ(items.size(), 2); // precision and shapes + origin_otypes_[i] = static_cast(std::stoi(items[0])); + origin_odims_[i] = Split(items[1], ","); + } + return true; +} + +bool DeviceProgram::BuildGraphAndCacheToFile( + const std::vector& origin_program, + const std::vector& input_names, + const std::vector& output_names, + const std::vector>& origin_idims, + const std::vector& origin_otensors, + const std::string& model_cache_dir) { + // Generate the model name if not initialized + if (model_name_.empty()) { + model_name_ = GenerateModelName(input_names, output_names, origin_idims); + } + // Convert all of ops and their input vars and weights to HiAI IR nodes, + // then added them into the HiAI IR graph int status = 0; - // Convert all of ops and their input vars and weights and added into the NPU - // HiAI IR graph + CHECK(!origin_program.empty()) << "no instructions"; subgraph::npu::Graph graph; const auto& bridges = subgraph::Registry::Instance(); - for (auto& inst : origin_program_) { + for (auto& inst : origin_program) { auto op = const_cast(inst.op()); CHECK(op); op->CheckShape(); op->InferShape(); std::string op_type = op->op_info()->Type(); if (!bridges.Exists(op_type, TARGET(kNPU))) { - return subgraph::FAILED; + return false; } auto kernel = inst.kernel(); status |= bridges.Select(op_type, TARGET(kNPU))( reinterpret_cast(&graph), op, const_cast(kernel)); if (subgraph::CHECK_FAILED(status)) { - return subgraph::FAILED; + return false; } } - // Collect the valid input and output nodes in the HiAI IR graph and update - // the input and output names - device_inames_.clear(); - device_onames_.clear(); + // Collect the input and output nodes of the HiAI IR graph std::vector device_inodes; - std::vector device_onodes; - for (auto& input_name : input_names_) { - if (graph.Has(input_name)) { - if (graph.Get(input_name)->is_data()) { - device_inodes.push_back(*graph.Get(input_name)->data()); - device_inames_.push_back(input_name); - } else { - LOG(WARNING) << "[NPU] Input node " << input_name - << " is ignored because it is not a data node."; - } - } else { - LOG(WARNING) << "[NPU] Input node " << input_name - << " is ignored because it does not exist."; - } + for (size_t i = 0; i < input_names.size(); i++) { + CHECK(graph.Has(input_names[i]) && graph.Get(input_names[i])->is_data()); + device_inodes.push_back(*graph.Get(input_names[i])->data()); } - for (auto& output_name : output_names_) { - if (graph.Has(output_name)) { - device_onodes.push_back(*graph.Get(output_name)->data()); - device_onames_.push_back(output_name); - } else { - LOG(WARNING) << "[NPU] Output node " << output_name - << " is ignored because it does not exist."; - } - } - CHECK(!device_inames_.empty()) - << "[NPU] No input nodes found for building NPU model"; - CHECK(!device_onames_.empty()) - << "[NPU] No output nodes found for building NPU model"; - - // Build the HiAI IR graph to HiAI om model as the device program - if (device_program_map_.count(inputs_shape_) > 0) { - return status; + std::vector device_onodes; 
+ for (size_t i = 0; i < output_names.size(); i++) { + CHECK(graph.Has(output_names[i])); + device_onodes.push_back(*graph.Get(output_names[i])->data()); } - auto device_client = lite::npu::Device::Global().Build( - model_name_, device_inodes, device_onodes); - if (device_client == nullptr) { + // Build the HiAI IR graph to the HiAI om model + std::vector model_buffer; + if (!lite::npu::Device::Global().Build( + device_inodes, device_onodes, &model_buffer)) { LOG(WARNING) << "[NPU] Build model failed!"; - return subgraph::FAILED; + return false; } - auto device_program = std::make_shared(device_client); - device_program_map_[inputs_shape_] = device_program; - - // Query and check the dimensions of valid input and output tensors - std::vector device_idims, device_odims; - if (device_program->client->GetModelIOTensorDim( - model_name_, device_idims, device_odims) != hiai::AI_SUCCESS) { - LOG(WARNING) - << "[NPU] Get the dimensions of input and output tensors failed!"; - return subgraph::FAILED; + // Load the HiAI om model and create a HiAI model manager client(from HiAI + // Service) to run inference. + bool model_comp = true; + model_client_ = + lite::npu::Device::Global().Load(model_name_, &model_buffer, &model_comp); + if (!model_client_) { + LOG(WARNING) << "[NPU] Load model failed!"; + return false; } - device_program->device_idims = device_idims; - device_program->device_odims = device_odims; - - CHECK_EQ(device_idims.size(), device_inames_.size()); - CHECK_EQ(device_odims.size(), device_onames_.size()); - origin_idims_.resize(device_inames_.size()); - origin_itensors_.resize(device_inames_.size()); - device_itensors_.resize(device_inames_.size()); - origin_odims_.resize(device_onames_.size()); - origin_otensors_.resize(device_onames_.size()); - device_otensors_.resize(device_onames_.size()); - - for (int i = 0; i < device_inames_.size(); i++) { - auto node = graph.Get(device_inames_[i]); - auto precision = node->precision(); - auto layout = node->layout(); - origin_itensors_[i] = scope_->FindMutableTensor(device_inames_[i]); - CHECK(origin_itensors_[i]); - origin_idims_[i] = origin_itensors_[i]->dims(); - VLOG(3) << "[NPU] Inputs[" << i << "] name: " << device_inames_[i] - << " precision: " << PrecisionToStr(precision) - << " layout: " << DataLayoutToStr(layout) << " dims: {" - << device_idims[i].GetNumber() << "," - << device_idims[i].GetChannel() << "," - << device_idims[i].GetHeight() << "," << device_idims[i].GetWidth() - << "}"; - // Prepare the device input tensors - CHECK_EQ(origin_idims_[i].production(), - device_idims[i].GetNumber() * device_idims[i].GetChannel() * - device_idims[i].GetHeight() * device_idims[i].GetWidth()); - device_itensors_[i].reset(new hiai::AiTensor); - device_itensors_[i]->Init(&(device_idims[i])); + // Update the precison and dimensions of the origin output tensors + CHECK_EQ(origin_otensors.size(), output_names.size()); + origin_otypes_.resize(output_names.size()); + origin_odims_.resize(output_names.size()); + for (size_t i = 0; i < output_names.size(); i++) { + origin_otypes_[i] = graph.Get(output_names[i])->precision(); + origin_odims_[i] = origin_otensors[i]->dims().Vectorize(); } - device_program->origin_idims = origin_idims_; - - for (int i = 0; i < device_onames_.size(); i++) { - auto node = graph.Get(device_onames_[i]); - auto precision = node->precision(); - auto layout = node->layout(); - origin_otensors_[i] = scope_->FindMutableTensor(device_onames_[i]); - CHECK(origin_otensors_[i]); - origin_odims_[i] = origin_otensors_[i]->dims(); - 
VLOG(3) << "[NPU] Outputs[" << i << "] name: " << device_onames_[i] - << " precision: " << PrecisionToStr(precision) - << " layout: " << DataLayoutToStr(layout) << " dims: {" - << device_odims[i].GetNumber() << "," - << device_odims[i].GetChannel() << "," - << device_odims[i].GetHeight() << "," << device_odims[i].GetWidth() - << "}"; - // Prepare the device output tensors - switch (precision) { - case PRECISION(kFloat): - origin_otensors_[i]->mutable_data(); - break; - case PRECISION(kBool): - origin_otensors_[i]->mutable_data(); - break; - case PRECISION(kInt8): - origin_otensors_[i]->mutable_data(); - break; - case PRECISION(kInt16): - origin_otensors_[i]->mutable_data(); - break; - case PRECISION(kInt32): - origin_otensors_[i]->mutable_data(); - break; - case PRECISION(kInt64): - origin_otensors_[i]->mutable_data(); - break; - default: - LOG(FATAL) << "[NPU] " << device_onames_[i] - << " can't mutable data with precision type " - << PrecisionToStr(precision); - break; + if (!model_cache_dir.empty()) { + // Save the generated model to file, used for the model caching or the + // offline model generation + auto model_path = model_cache_dir + "/" + model_name_ + ".om"; + VLOG(3) << "[NPU] Save model to " << model_path; + if (!WriteFile(model_path, model_buffer)) { + LOG(WARNING) << "[NPU] Open " << model_path << " for writting failed!"; + } + // Serialize the precisions and shapes of the origin output tensors into the + // configuration file + std::ostringstream os; + for (int i = 0; i < output_names.size(); i++) { + os << static_cast(origin_otypes_[i]) << ":"; + for (auto dim : origin_odims_[i]) { + os << dim << ","; + } + os << ";"; + } + auto str = os.str(); + std::vector config_buffer(str.begin(), str.end()); + auto config_path = model_cache_dir + "/" + model_name_ + ".cfg"; + VLOG(3) << "[NPU] Save configuration to " << config_path; + if (!WriteFile(config_path, config_buffer)) { + LOG(WARNING) << "[NPU] Open " << config_path << " for writting failed!"; } - device_program->origin_odims = origin_odims_; - - CHECK_EQ(origin_odims_[i].production(), - device_odims[i].GetNumber() * device_odims[i].GetChannel() * - device_odims[i].GetHeight() * device_odims[i].GetWidth()); - device_otensors_[i].reset(new hiai::AiTensor); - device_otensors_[i]->Init(&(device_odims[i])); } - return status; + return true; } -int SubgraphEngine::LaunchDeviceProgram() { - // Copy the data of origin input tensors to the buffer of input HiAI tensors - // init device_itensors_, device_otensors_, origin_otensors_ - auto device_program = device_program_map_[inputs_shape_]; - for (size_t i = 0; i < device_itensors_.size(); i++) { - device_itensors_[i]->Init(&(device_program->device_idims[i])); - std::memcpy(device_itensors_[i]->GetBuffer(), - origin_itensors_[i]->raw_data(), - origin_itensors_[i]->memory_size()); +bool DeviceProgram::ShareBufferWithOriginTensors( + const std::vector& input_names, + const std::vector& output_names, + std::vector* origin_itensors, + std::vector* origin_otensors, + std::vector>* device_itensors, + std::vector>* device_otensors) { + CHECK(!model_name_.empty() && model_client_); + // Query the dimensions of the device input and output tensors if not + // initialized + if (device_idims_.empty() || device_odims_.empty()) { + if (model_client_->GetModelIOTensorDim( + model_name_, device_idims_, device_odims_) != hiai::AI_SUCCESS) { + LOG(WARNING) + << "[NPU] Get the dimensions of input and output tensors failed!"; + return false; + } } - for (size_t i = 0; i < device_otensors_.size(); i++) { 
- device_otensors_[i]->Init(&(device_program->device_odims[i])); + // Check the dimensions of the device tensors and the origin tensors + CHECK_EQ(device_itensors->size(), input_names.size()); + CHECK_EQ(device_otensors->size(), output_names.size()); + CHECK_EQ(origin_otypes_.size(), output_names.size()); + CHECK_EQ(origin_odims_.size(), output_names.size()); + CHECK_EQ(device_idims_.size(), input_names.size()); + CHECK_EQ(device_odims_.size(), output_names.size()); + for (int i = 0; i < input_names.size(); i++) { + VLOG(3) << "[NPU] Inputs[" << i << "] name: " << input_names[i] + << " origin dims:" << (*origin_itensors)[i]->dims().repr() + << " device dims: {" << device_idims_[i].GetNumber() << "," + << device_idims_[i].GetChannel() << "," + << device_idims_[i].GetHeight() << "," + << device_idims_[i].GetWidth() << "}"; + CHECK_EQ((*origin_itensors)[i]->dims().production(), + device_idims_[i].GetNumber() * device_idims_[i].GetChannel() * + device_idims_[i].GetHeight() * device_idims_[i].GetWidth()); + VLOG(3) << "[NPU] Init the input tensors for the device program and share " + "their buffers with the origin input tensors"; + // reinit device tensor will free shared buffer, so copy data to a tmp + // tensor + Tensor tmp; + tmp.CopyDataFrom(*(*origin_itensors)[i]); + (*device_itensors)[i]->Init(&(device_idims_[i])); + + std::memcpy( + (*device_itensors)[i]->GetBuffer(), tmp.raw_data(), tmp.memory_size()); + + // Share data buf between device_itensor and origin_itensor + std::shared_ptr buffer = + std::make_shared((*device_itensors)[i]->GetBuffer(), + lite_api::TargetType::kHost, + (*device_itensors)[i]->GetSize()); + (*origin_itensors)[i]->ResetBuffer(buffer, + (*device_itensors)[i]->GetSize()); } - for (size_t i = 0; i < origin_otensors_.size(); i++) { - origin_otensors_[i]->Resize(device_program->origin_odims[i]); + for (int i = 0; i < output_names.size(); i++) { + (*origin_otensors)[i]->set_precision(origin_otypes_[i]); + (*origin_otensors)[i]->Resize(origin_odims_[i]); + VLOG(3) << "[NPU] Outputs[" << i << "] name: " << output_names[i] + << " origin dims:" << (*origin_otensors)[i]->dims().repr() + << " device dims: {" << device_odims_[i].GetNumber() << "," + << device_odims_[i].GetChannel() << "," + << device_odims_[i].GetHeight() << "," + << device_odims_[i].GetWidth() << "}"; + CHECK_EQ((*origin_otensors)[i]->dims().production(), + device_odims_[i].GetNumber() * device_odims_[i].GetChannel() * + device_odims_[i].GetHeight() * device_odims_[i].GetWidth()); + (*device_otensors)[i]->Init(&(device_odims_[i])); + VLOG(3) << "[NPU] Init the output tensors for the device program and share " + "their buffers with the origin output tensors"; + // Share data buf between device_itensor and origin_itensor + std::shared_ptr buffer = + std::make_shared((*device_otensors)[i]->GetBuffer(), + lite_api::TargetType::kHost, + (*device_otensors)[i]->GetSize()); + (*origin_otensors)[i]->ResetBuffer(buffer, + (*device_otensors)[i]->GetSize()); } + return true; +} +bool DeviceProgram::ZeroCopyRun( + std::vector>* device_itensors, + std::vector>* device_otensors) { + CHECK(!model_name_.empty() && model_client_); // Run the HiAI model by name std::string key = "model_name"; // Note: key seems must be model_name hiai::AiContext model_context; @@ -219,30 +305,87 @@ int SubgraphEngine::LaunchDeviceProgram() { }; int istamp; auto start_time = GetCurrentUS(); - CHECK_EQ(device_program->client->Process( - model_context, device_itensors_, device_otensors_, 1000, istamp), + CHECK_EQ(model_client_->Process( + 
model_context, *device_itensors, *device_otensors, 1000, istamp), hiai::AI_SUCCESS); VLOG(3) << "[NPU] Process cost " << GetCurrentUS() - start_time << " us"; + return true; +} + +bool SubgraphEngine::PrepareWorkspaceForDeviceProgram() { + // Obtain the origin input tensors, and create the origin output + // tensors(Don't try to access them before launch the device program or the + // origin program) + PrepareWorkspaceForOriginProgram(); + // Create the device input and output tensors, but don't initialize them + // with the dimensions + device_itensors_.resize(input_names_.size()); + for (int i = 0; i < input_names_.size(); i++) { + device_itensors_[i].reset(new hiai::AiTensor); + CHECK(device_itensors_[i]); + } + device_otensors_.resize(output_names_.size()); + for (int i = 0; i < output_names_.size(); i++) { + device_otensors_[i].reset(new hiai::AiTensor); + CHECK(device_otensors_[i]); + } + return true; +} - // Copy the data of output HiAI tensor to the buffer of origin output tensors - for (size_t i = 0; i < device_otensors_.size(); i++) { - std::memcpy(const_cast(origin_otensors_[i]->raw_data()), - device_otensors_[i]->GetBuffer(), - device_otensors_[i]->GetSize()); +bool SubgraphEngine::BuildDeviceProgram() { + // Check if the cache device program exists + if (!device_programs_.count(origin_idims_)) { + auto device_program = std::make_shared(); + // Obtain the model cache dir from the NPU Context of the subgraph op + auto model_cache_dir = ctx_->As().SubgraphModelCacheDir(); + VLOG(3) << "[NPU] Getting subgraph model_cache_dir is: " << model_cache_dir; + // Check and load if the cached model and configuration file exists + if (model_cache_dir.empty() || + !device_program->LoadFromCacheFile( + input_names_, output_names_, origin_idims_, model_cache_dir)) { + // Build the model online, including converting the paddle ops to the HiAI + // IR nodes, building the HiAI IR graph to the om model, then load it as a + // new HiAI model manager client for inference. + if (origin_program_.empty()) { + BuildOriginProgram(); + } + CHECK(!origin_program_.empty()) << "no instructions"; + if (!device_program->BuildGraphAndCacheToFile(origin_program_, + input_names_, + output_names_, + origin_idims_, + origin_otensors_, + model_cache_dir)) { + return false; + } + } + if (device_program->model_client_ == nullptr) { + return false; + } + device_programs_[origin_idims_] = device_program; } - return 0; + auto device_program = device_programs_[origin_idims_]; + CHECK(device_program && device_program->model_client_); + return device_program->ShareBufferWithOriginTensors(input_names_, + output_names_, + &origin_itensors_, + &origin_otensors_, + &device_itensors_, + &device_otensors_); } -bool SubgraphEngine::InputShapeChanged() { - std::vector> new_shape; - for (auto origin_itensor : origin_itensors_) { - new_shape.push_back(origin_itensor->dims().Vectorize()); +bool SubgraphEngine::LaunchDeviceProgram() { + // Roll back to launch the origin program if the device program can't be + // found or the model client isn't initialized. 
+ if (device_programs_.count(origin_idims_) == 0 || + device_programs_[origin_idims_]->model_client_ == nullptr) { + return LaunchOriginProgram(); } - inputs_shape_ = new_shape; - if (device_program_map_.count(inputs_shape_) > 0) { - return false; + auto device_program = device_programs_[origin_idims_]; + if (!device_program->model_client_) { + return LaunchOriginProgram(); } - return true; + return device_program->ZeroCopyRun(&device_itensors_, &device_otensors_); } void SubgraphCompute::PrepareForRun() { @@ -254,12 +397,11 @@ void SubgraphCompute::PrepareForRun() { param.output_data_names, param.scope)); CHECK(engine_); - engine_->Build(); } void SubgraphCompute::Run() { CHECK(engine_); - engine_->Launch(); + engine_->Run(); } } // namespace npu diff --git a/lite/kernels/npu/subgraph_compute.h b/lite/kernels/npu/subgraph_compute.h index 801f61b036..33321a7789 100644 --- a/lite/kernels/npu/subgraph_compute.h +++ b/lite/kernels/npu/subgraph_compute.h @@ -28,40 +28,65 @@ namespace lite { namespace kernels { namespace npu { +class DeviceProgram { + public: + DeviceProgram() {} + ~DeviceProgram() {} + std::string GenerateModelName( + const std::vector& input_names, + const std::vector& output_names, + const std::vector>& origin_idims); + bool LoadFromCacheFile(const std::vector& input_names, + const std::vector& output_names, + const std::vector>& origin_idims, + const std::string& model_cache_dir); + bool BuildGraphAndCacheToFile( + const std::vector& origin_program, + const std::vector& input_names, + const std::vector& output_names, + const std::vector>& origin_idims, + const std::vector& origin_otensors, + const std::string& model_cache_dir); + bool ShareBufferWithOriginTensors( + const std::vector& input_names, + const std::vector& output_names, + std::vector* origin_itensors, + std::vector* origin_otensors, + std::vector>* device_itensors, + std::vector>* device_otensors); + bool ZeroCopyRun( + std::vector>* device_itensors, + std::vector>* device_otensors); + + public: + std::string model_name_{""}; + std::shared_ptr model_client_{nullptr}; + std::vector> origin_odims_; + std::vector origin_otypes_; + std::vector device_idims_{}; + std::vector device_odims_{}; +}; + class SubgraphEngine : public subgraph::Engine { public: - SubgraphEngine(KernelContext *ctx, + SubgraphEngine(KernelContext* ctx, int block_idx, - cpp::BlockDesc *block_desc, - const std::vector &input_names, - const std::vector &output_names, - Scope *scope) + cpp::BlockDesc* block_desc, + const std::vector& input_names, + const std::vector& output_names, + Scope* scope) : subgraph::Engine( ctx, block_idx, block_desc, input_names, output_names, scope) {} - struct device_program_t { - explicit device_program_t(std::shared_ptr _client) - : client(_client) {} - std::shared_ptr client{nullptr}; - std::vector origin_idims{}; - std::vector origin_odims{}; - std::vector device_idims{}; - std::vector device_odims{}; - }; - protected: - int BuildDeviceProgram() override; - int LaunchDeviceProgram() override; - bool InputShapeChanged() override; + bool PrepareWorkspaceForDeviceProgram() override; + bool BuildDeviceProgram() override; + bool LaunchDeviceProgram() override; - std::string model_name_{"model.om"}; - std::vector> inputs_shape_{}; - std::map>, std::shared_ptr> - device_program_map_{}; - std::vector device_inames_{}; - std::vector device_onames_{}; std::vector> device_itensors_{}; std::vector> device_otensors_{}; + std::map>, std::shared_ptr> + device_programs_; }; class SubgraphCompute : public KernelLite { diff --git 
a/lite/utils/env.h b/lite/utils/env.h index 3048c84b42..f3bb8b58e1 100644 --- a/lite/utils/env.h +++ b/lite/utils/env.h @@ -22,6 +22,8 @@ #define SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE \ "SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE" +#define SUBGRAPH_DISABLE_ONLINE_MODE "SUBGRAPH_DISABLE_ONLINE_MODE" + namespace paddle { namespace lite { diff --git a/lite/utils/io.h b/lite/utils/io.h index 506901bad5..8d1f0fe52c 100644 --- a/lite/utils/io.h +++ b/lite/utils/io.h @@ -119,5 +119,40 @@ static std::vector ListDir(const std::string& path, return paths; } +static bool ReadFile(const std::string& filename, std::vector* contents) { + FILE* fp = fopen(filename.c_str(), "rb"); + if (!fp) return false; + fseek(fp, 0, SEEK_END); + size_t size = ftell(fp); + fseek(fp, 0, SEEK_SET); + contents->clear(); + contents->resize(size); + size_t offset = 0; + char* ptr = reinterpret_cast(&(contents->at(0))); + while (offset < size) { + size_t already_read = fread(ptr, 1, size - offset, fp); + offset += already_read; + ptr += already_read; + } + fclose(fp); + return true; +} + +static bool WriteFile(const std::string& filename, + const std::vector& contents) { + FILE* fp = fopen(filename.c_str(), "wb"); + if (!fp) return false; + size_t size = contents.size(); + size_t offset = 0; + const char* ptr = reinterpret_cast(&(contents.at(0))); + while (offset < size) { + size_t already_written = fwrite(ptr, 1, size - offset, fp); + offset += already_written; + ptr += already_written; + } + fclose(fp); + return true; +} + } // namespace lite } // namespace paddle diff --git a/lite/utils/md5.h b/lite/utils/md5.h new file mode 100644 index 0000000000..c2e972dd80 --- /dev/null +++ b/lite/utils/md5.h @@ -0,0 +1,104 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include + +namespace paddle { +namespace lite { + +std::string MD5(std::string message) { + const uint32_t shiftAmounts[] = { + 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, + 5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20, + 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, + 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21}; + const uint32_t partsOfSines[] = { + 0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee, 0xf57c0faf, 0x4787c62a, + 0xa8304613, 0xfd469501, 0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be, + 0x6b901122, 0xfd987193, 0xa679438e, 0x49b40821, 0xf61e2562, 0xc040b340, + 0x265e5a51, 0xe9b6c7aa, 0xd62f105d, 0x02441453, 0xd8a1e681, 0xe7d3fbc8, + 0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed, 0xa9e3e905, 0xfcefa3f8, + 0x676f02d9, 0x8d2a4c8a, 0xfffa3942, 0x8771f681, 0x6d9d6122, 0xfde5380c, + 0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70, 0x289b7ec6, 0xeaa127fa, + 0xd4ef3085, 0x04881d05, 0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665, + 0xf4292244, 0x432aff97, 0xab9423a7, 0xfc93a039, 0x655b59c3, 0x8f0ccc92, + 0xffeff47d, 0x85845dd1, 0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1, + 0xf7537e82, 0xbd3af235, 0x2ad7d2bb, 0xeb86d391}; + + uint32_t state[4]; + state[0] = 0x67452301; + state[1] = 0xefcdab89; + state[2] = 0x98badcfe; + state[3] = 0x10325476; + + // Pad with zeros + int size = ((((message.length() + 8) / 64) + 1) * 64) - 8; + uint8_t *buf = reinterpret_cast(calloc(size + 64, 1)); + memcpy(buf, message.c_str(), message.length()); + buf[message.length()] = 128; + uint32_t bits = 8 * message.length(); + memcpy(buf + size, &bits, 4); + +// Process at each 512-bit(64 bytes) chunk +#define LEFTROTATE(x, c) (((x) << (c)) | ((x) >> (32 - (c)))) + for (int offset = 0; offset < size; offset += 64) { + uint32_t A = state[0]; + uint32_t B = state[1]; + uint32_t C = state[2]; + uint32_t D = state[3]; + uint32_t *W = reinterpret_cast(buf + offset); + for (uint32_t i = 0; i < 64; i++) { + uint32_t F, g; + if (i < 16) { + F = (B & C) | ((~B) & D); + g = i; + } else if (i < 32) { + F = (D & B) | ((~D) & C); + g = (5 * i + 1) % 16; + } else if (i < 48) { + F = B ^ C ^ D; + g = (3 * i + 5) % 16; + } else { + F = C ^ (B | (~D)); + g = (7 * i) % 16; + } + uint32_t T = D; + D = C; + C = B; + B = B + LEFTROTATE((A + F + partsOfSines[i] + W[g]), shiftAmounts[i]); + A = T; + } + state[0] += A; + state[1] += B; + state[2] += C; + state[3] += D; + } +#undef LEFTROTATE + free(buf); + + // Convert digest to string + std::string res; + res.reserve(16 << 1); + const uint8_t *digest = reinterpret_cast(state); + char hex[3]; + for (size_t i = 0; i < 16; i++) { + snprintf(hex, sizeof(hex), "%02x", digest[i]); + res.append(hex); + } + return res; +} + +} // namespace lite +} // namespace paddle diff --git a/lite/utils/string.h b/lite/utils/string.h index ada51d0b85..b1aaf5d6c5 100644 --- a/lite/utils/string.h +++ b/lite/utils/string.h @@ -60,6 +60,38 @@ static std::string to_string(const T& v) { return ss.str(); } +static std::string to_string(int index) { + const int BUFFER_LENGTH = 15; + char buffer[BUFFER_LENGTH]; + snprintf(buffer, sizeof(buffer), "%d", index); + return std::string(buffer); +} + +template +static T parse_string(const std::string& v) { + return v; +} + +template <> +int32_t parse_string(const std::string& v) { + return std::stoi(v); +} + +template <> +int64_t parse_string(const std::string& v) { + return std::stoll(v); +} + +template <> +float parse_string(const std::string& v) { + return std::stof(v); +} + +template <> +double 
parse_string(const std::string& v) { + return std::stod(v); +} + template std::string Join(const std::vector& vec, const std::string& delim) { if (vec.empty()) return ""; @@ -84,19 +116,20 @@ static std::string Repr(const std::vector& v) { return "{" + Join(tmp, ",") + "}"; } -static std::vector Split(const std::string& original, - const std::string& separator) { - std::vector results; +template +static std::vector Split(const std::string& original, + const std::string& separator) { + std::vector results; std::string::size_type pos1, pos2; pos2 = original.find(separator); pos1 = 0; while (std::string::npos != pos2) { - results.push_back(original.substr(pos1, pos2 - pos1)); + results.push_back(parse_string(original.substr(pos1, pos2 - pos1))); pos1 = pos2 + separator.size(); pos2 = original.find(separator, pos1); } if (pos1 != original.length()) { - results.push_back(original.substr(pos1)); + results.push_back(parse_string(original.substr(pos1))); } return results; } -- GitLab
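A minimal usage sketch of the caching API this patch adds. It is not part of the patch itself; the include path, model file, and cache directory are illustrative placeholders, and the call sequence mirrors the one used in lite/core/mir/subgraph/subgraph_pass_test.cc above.

#include <memory>
#include "lite/api/paddle_api.h"  // assumed public header for MobileConfig

int main() {
  paddle::lite_api::MobileConfig config;
  config.set_model_from_file("./mobilenet_v1.nb");  // optimized .nb model (placeholder path)
  config.set_power_mode(paddle::lite_api::PowerMode::LITE_POWER_HIGH);
  config.set_threads(1);
  // New in this patch: directory where the generated HiAI .om model and its
  // .cfg companion (output precisions/shapes) are cached and reloaded from.
  config.set_subgraph_model_cache_dir("/data/local/tmp/npu_cache");
  auto predictor =
      paddle::lite_api::CreatePaddlePredictor<paddle::lite_api::MobileConfig>(config);
  // Fill inputs, then call predictor->Run(); the first run builds and caches
  // the om model, later runs (or later processes) load it from the cache dir.
  return 0;
}

Per the comment added in subgraph_detector.cc, setting the SUBGRAPH_DISABLE_ONLINE_MODE environment variable to true additionally keeps the weight/local/unused nodes off the subgraph op, so the cached model can be shipped as an offline model.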
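A sketch, under stated assumptions, of how the cache files are keyed, mirroring DeviceProgram::GenerateModelName above: the file stem is an MD5 digest over the sorted input names, their dimensions, and the sorted output names, giving <stem>.om for the HiAI model and <stem>.cfg for one "<precision>:<d0>,<d1>,...;" record per output. The MD5 stand-in below exists only so the sketch compiles on its own; the patch's real digest is the MD5() added in lite/utils/md5.h, and the names here are hypothetical.

#include <cstdint>
#include <functional>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Stand-in for lite::MD5(); illustration only, not the real digest.
static std::string MD5(const std::string& s) {
  std::ostringstream os;
  os << std::hex << std::hash<std::string>()(s);
  return os.str();
}

std::string CacheStem(const std::vector<std::string>& input_names,
                      const std::vector<std::vector<int64_t>>& input_dims,
                      const std::vector<std::string>& output_names) {
  std::ostringstream os;
  for (size_t i = 0; i < input_names.size(); i++) {
    os << input_names[i];
    for (auto d : input_dims[i]) os << d;  // shapes are part of the key
  }
  for (const auto& n : output_names) os << n;
  return MD5(os.str());  // cached files: <stem>.om and <stem>.cfg
}

int main() {
  auto stem = CacheStem({"x"}, {{1, 3, 224, 224}}, {"y"});
  std::cout << stem << ".om / " << stem << ".cfg" << std::endl;
  return 0;
}

Because the input shapes feed the key, each distinct input-shape combination of a subgraph gets its own cached model, which matches the std::map keyed on origin_idims_ in SubgraphEngine.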