未验证 提交 ac897177 编写于 作者: Q Qi Li 提交者: GitHub

[NPU] enhance cache offline model, test=develop (#3805) (#3931)

* [NPU] enhance cache offline model, test=develop
上级 813f17ba
......@@ -117,3 +117,6 @@ metal/paddle-mobile-demo/paddle-mobile-demo/Resources/models
metal/MobileNetDemo/MobileNetDemo/Resources
build*
# hiai libs
ai_ddk_lib*
......@@ -35,7 +35,11 @@ endif()
if(NOT DEFINED ANDROID_API_LEVEL)
set(ANDROID_API_LEVEL "23")
if(ARM_TARGET_ARCH_ABI STREQUAL "armv7")
set(ANDROID_API_LEVEL "22")
if(LITE_WITH_NPU AND NOT LITE_ON_TINY_PUBLISH)
set(ANDROID_API_LEVEL "24") # HIAI DDK depends on android-24
else()
set(ANDROID_API_LEVEL "22")
endif()
endif()
endif()
......
......@@ -70,6 +70,10 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
raw_predictor_.Build(config, places, passes);
mode_ = config.power_mode();
threads_ = config.threads();
#ifdef LITE_WITH_NPU
Context<TargetType::kNPU>::SetSubgraphModelCacheDir(
config.subgraph_model_cache_dir());
#endif
#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \
!(defined LITE_ON_MODEL_OPTIMIZE_TOOL) && !defined(__APPLE__)
int num_threads = config.x86_math_library_num_threads();
......
......@@ -36,6 +36,11 @@ void LightPredictorImpl::Init(const lite_api::MobileConfig& config) {
}
mode_ = config.power_mode();
threads_ = config.threads();
#ifdef LITE_WITH_NPU
Context<TargetType::kNPU>::SetSubgraphModelCacheDir(
config.subgraph_model_cache_dir());
#endif
}
std::unique_ptr<lite_api::Tensor> LightPredictorImpl::GetInput(int i) {
......
......@@ -118,6 +118,8 @@ class LITE_API ConfigBase {
std::string model_dir_;
int threads_{1};
PowerMode mode_{LITE_POWER_NO_BIND};
// to save subgraph model for npu/xpu/...
std::string subgraph_model_cache_dir_{""};
public:
explicit ConfigBase(PowerMode mode = LITE_POWER_NO_BIND, int threads = 1);
......@@ -130,6 +132,13 @@ class LITE_API ConfigBase {
// set Thread
void set_threads(int threads);
int threads() const { return threads_; }
// Set the directory used to cache subgraph models (see
// subgraph_model_cache_dir_). The given path is copied into the member.
void set_subgraph_model_cache_dir(std::string subgraph_model_cache_dir) {
subgraph_model_cache_dir_ = subgraph_model_cache_dir;
}
// Return a reference to the stored subgraph model cache directory. The member
// is initialized to an empty string, so that is the value until a directory
// has been set.
const std::string& subgraph_model_cache_dir() const {
return subgraph_model_cache_dir_;
}
};
/// CxxConfig is the config for the Full feature predictor.
......
......@@ -19,52 +19,122 @@ namespace paddle {
namespace lite {
namespace npu {
std::shared_ptr<hiai::AiModelMngerClient> Device::Build(
const std::string model_name, // NOLINT
std::vector<ge::Operator>& input_nodes, // NOLINT
std::vector<ge::Operator>& output_nodes // NOLINT
) {
VLOG(3) << "[NPU] Build model";
// Build the HiAI IR graph to the HiAI om model
ge::Graph ir_graph("graph");
ir_graph.SetInputs(input_nodes).SetOutputs(output_nodes);
ge::Model om_model("model", "model");
om_model.SetGraph(ir_graph);
domi::HiaiIrBuild ir_build;
domi::ModelBufferData om_model_buf;
if (!ir_build.CreateModelBuff(om_model, om_model_buf)) {
LOG(WARNING) << "[NPU] CreateModelBuff failed!";
return nullptr;
}
if (!ir_build.BuildIRModel(om_model, om_model_buf)) {
LOG(WARNING) << "[NPU] BuildIRModel failed!";
ir_build.ReleaseModelBuff(om_model_buf);
return nullptr;
}
std::shared_ptr<hiai::AiModelMngerClient> Device::Load(
const std::string& model_name,
std::vector<char>* model_buffer,
bool* model_comp) {
// Create a HiAI model manager client to load the HiAI om model
std::shared_ptr<hiai::AiModelMngerClient> model_client(
new hiai::AiModelMngerClient());
auto model_client = std::make_shared<hiai::AiModelMngerClient>();
if (model_client->Init(nullptr) != hiai::AI_SUCCESS) {
LOG(WARNING) << "[NPU] AiModelMngerClient init failed)!";
ir_build.ReleaseModelBuff(om_model_buf);
LOG(WARNING) << "[NPU] Init hiai model client failed!";
return nullptr;
}
// Check HiAI DDK version
const char* ddk_version = model_client->GetVersion();
if (ddk_version) {
LOG(INFO) << "[NPU] HiAI DDK version: " << ddk_version;
} else {
LOG(WARNING) << "[NPU] Unable to get HiAI DDK version!";
}
// Check model compatibility
auto model_desc = std::make_shared<hiai::AiModelDescription>(
model_name, freq_level(), framework_type(), model_type(), device_type());
model_desc->SetModelBuffer(om_model_buf.data, om_model_buf.length);
std::vector<std::shared_ptr<hiai::AiModelDescription>> model_descs;
model_descs.push_back(model_desc);
model_desc->SetModelBuffer(
reinterpret_cast<const void*>(model_buffer->data()),
model_buffer->size());
if (!*model_comp &&
model_client->CheckModelCompatibility(*model_desc, *model_comp) !=
hiai::AI_SUCCESS) {
*model_comp = false;
VLOG(3) << "[NPU] model is NOT compatiblitiable, setting model_comp to "
<< *model_comp;
} else {
*model_comp = true;
VLOG(3) << "[NPU] model is compatiblitiable, setting model_comp to "
<< *model_comp;
}
// Rebuild and write the data of the compatible model to the model buffer
if (!*model_comp) {
std::shared_ptr<hiai::AiModelBuilder> model_builder =
std::make_shared<hiai::AiModelBuilder>(model_client);
hiai::MemBuffer* org_model_buffer = model_builder->InputMemBufferCreate(
reinterpret_cast<void*>(model_buffer->data()), model_buffer->size());
if (org_model_buffer) {
std::vector<hiai::MemBuffer*> org_model_buffers;
org_model_buffers.push_back(org_model_buffer);
hiai::MemBuffer* new_model_buffer = model_builder->OutputMemBufferCreate(
framework_type(), org_model_buffers);
// VLOG(3) << "[NPU] new model buffer memeory size is " <<
// new_model_buffer->GetMemBufferSize();
if (new_model_buffer) {
uint32_t new_model_size = 0;
if (model_builder->BuildModel(org_model_buffers,
new_model_buffer,
new_model_size) == hiai::AI_SUCCESS) {
// need to change to new_model_size as GetMemBufferSize is not
// correct.
model_buffer->resize(new_model_size);
memcpy(reinterpret_cast<void*>(model_buffer->data()),
new_model_buffer->GetMemBufferData(),
new_model_size);
// Reset the model buffer
model_desc->SetModelBuffer(
reinterpret_cast<const void*>(model_buffer->data()),
model_buffer->size());
VLOG(3) << "[NPU] Rebuild the compatible model done.";
} else {
LOG(WARNING) << "[NPU] Rebuild the compatible model failed!";
}
model_builder->MemBufferDestroy(new_model_buffer);
} else {
LOG(WARNING) << "[NPU] OutputMemBufferCreate failed!";
}
model_builder->MemBufferDestroy(org_model_buffer);
} else {
LOG(WARNING) << "[NPU] InputMemBufferCreate failed!";
}
}
// Load the compatible model
std::vector<std::shared_ptr<hiai::AiModelDescription>> model_descs{
model_desc};
if (model_client->Load(model_descs) != hiai::AI_SUCCESS) {
LOG(WARNING) << "[NPU] AiModelMngerClient load model failed!";
ir_build.ReleaseModelBuff(om_model_buf);
return nullptr;
}
ir_build.ReleaseModelBuff(om_model_buf);
VLOG(3) << "[NPU] Build done";
VLOG(3) << "[NPU] Load model done.";
return model_client;
}
// Compile a HiAI IR graph, given by its input and output operators, into a
// serialized HiAI om model.
//
// input_nodes / output_nodes: operators set as the inputs and outputs of the
//   IR graph.
// model_buffer: receives the bytes of the built om model; it is resized to
//   the model length on success and left untouched on failure.
//
// Returns true when the om model was built and copied out, false when either
// CreateModelBuff or BuildIRModel fails.
bool Device::Build(std::vector<ge::Operator>& input_nodes,   // NOLINT
                   std::vector<ge::Operator>& output_nodes,  // NOLINT
                   std::vector<char>* model_buffer) {
  // Wrap the IR graph in a model object for the builder.
  ge::Graph graph("graph");
  graph.SetInputs(input_nodes).SetOutputs(output_nodes);
  ge::Model model("model", "model");
  model.SetGraph(graph);

  // Create the builder-owned buffer that will hold the serialized om model.
  domi::HiaiIrBuild builder;
  domi::ModelBufferData built;
  if (!builder.CreateModelBuff(model, built)) {
    LOG(WARNING) << "[NPU] CreateModelBuff failed!";
    return false;
  }

  const bool succeeded = builder.BuildIRModel(model, built);
  if (succeeded) {
    // Copy the model out before the builder-owned buffer is released.
    model_buffer->resize(built.length);
    memcpy(reinterpret_cast<void*>(model_buffer->data()),
           reinterpret_cast<void*>(built.data),
           built.length);
  } else {
    LOG(WARNING) << "[NPU] BuildIRModel failed!";
  }

  // The buffer created above is released on both the success and failure path.
  builder.ReleaseModelBuff(built);
  if (succeeded) {
    VLOG(3) << "[NPU] Build model done.";
  }
  return succeeded;
}
} // namespace npu
} // namespace lite
} // namespace paddle
......@@ -38,13 +38,18 @@ class Device {
int model_type() { return model_type_; }
int device_type() { return device_type_; }
// Load the HiAI om model from buffer, rebuild the model if it's incompatible
// with the current device, then create a HiAI model manager client(from HiAI
// Server) to run inference
std::shared_ptr<hiai::AiModelMngerClient> Load(
const std::string& model_name,
std::vector<char>* model_buffer,
bool* model_comp);
// Build the HiAI IR graph to om model, return HiAI model manager client to
// load om model and run inference.
std::shared_ptr<hiai::AiModelMngerClient> Build(
const std::string model_name, // NOLINT
std::vector<ge::Operator>& input_nodes, // NOLINT
std::vector<ge::Operator>& output_nodes // NOLINT
); // NOLINT
bool Build(std::vector<ge::Operator>& input_nodes, // NOLINT
std::vector<ge::Operator>& output_nodes, // NOLINT
std::vector<char>* model_buffer);
private:
int freq_level_{3};
......
......@@ -17,6 +17,10 @@
namespace paddle {
namespace lite {
#ifdef LITE_WITH_NPU
std::string Context<TargetType::kNPU>::subgraph_model_cache_dir_{""}; // NOLINT
#endif
#ifdef LITE_WITH_XPU
thread_local xdnn::Context* Context<TargetType::kXPU>::_tls_raw_ctx{nullptr};
int Context<TargetType::kXPU>::_workspace_l3_size_per_thread{0};
......
......@@ -85,6 +85,16 @@ class Context<TargetType::kNPU> {
// Copy assignment. Nothing is copied from ctx (the only data member visible in
// this class is static), but the operator must still return *this: flowing off
// the end of a value-returning function is undefined behavior.
NPUContext& operator=(const NPUContext& ctx) { return *this; }
std::string name() const { return "NPUContext"; }
// Store the subgraph model cache directory in the static member
// subgraph_model_cache_dir_, so one value is shared by all users of this
// context type.
// NOTE(review): no synchronization is visible here — confirm this is only
// called during predictor initialization.
static void SetSubgraphModelCacheDir(std::string subgraph_model_cache_dir) {
subgraph_model_cache_dir_ = subgraph_model_cache_dir;
}
// Return a copy of the shared (static) subgraph model cache directory.
static std::string SubgraphModelCacheDir() {
return subgraph_model_cache_dir_;
}
private:
static std::string subgraph_model_cache_dir_;
};
#endif
......
......@@ -426,73 +426,51 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
subgraph_op_desc.SetAttr<int32_t>("sub_block", sub_block_idx);
// Extract input and output nodes from the target subgraph
std::unordered_set<Node *> input_var_nodes;
std::unordered_set<Node *> idata_var_nodes;
std::unordered_set<Node *> weight_var_nodes;
std::unordered_set<Node *> output_var_nodes;
std::unordered_set<Node *> odata_var_nodes;
std::unordered_set<Node *> local_var_nodes;
std::unordered_set<Node *> unused_var_nodes;
ExtractInputsOutputs(subgraph_nodes,
&input_var_nodes,
&idata_var_nodes,
&weight_var_nodes,
&output_var_nodes,
&odata_var_nodes,
&local_var_nodes,
&unused_var_nodes);
// A simplified model without the original weight/local/unused nodes on the
// subgraph ops will be saved only if 'SUBGRAPH_DISABLE_ONLINE_MODE' is set to
// true and Predictor->Run(...), Predictor->Save(...) is called.
std::unordered_set<Node *> input_var_nodes(idata_var_nodes.begin(),
idata_var_nodes.end());
std::unordered_set<Node *> output_var_nodes(odata_var_nodes.begin(),
odata_var_nodes.end());
if (!GetBoolFromEnv(SUBGRAPH_DISABLE_ONLINE_MODE)) {
input_var_nodes.insert(weight_var_nodes.begin(), weight_var_nodes.end());
output_var_nodes.insert(local_var_nodes.begin(), local_var_nodes.end());
output_var_nodes.insert(unused_var_nodes.begin(), unused_var_nodes.end());
}
// Set input and output name mapping which stores the real inputs and
// outputs
std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names;
for (auto &var_node : input_var_nodes) {
input_var_names.push_back(var_node->AsArg().name);
std::vector<std::string> idata_var_names;
std::vector<std::string> odata_var_names;
for (auto &var_node : idata_var_nodes) {
idata_var_names.push_back(var_node->AsArg().name);
}
for (auto &var_node : output_var_nodes) {
output_var_names.push_back(var_node->AsArg().name);
for (auto &var_node : odata_var_nodes) {
odata_var_names.push_back(var_node->AsArg().name);
}
subgraph_op_desc.SetAttr<std::vector<std::string>>("input_data_names",
input_var_names);
idata_var_names);
subgraph_op_desc.SetAttr<std::vector<std::string>>("output_data_names",
output_var_names);
// Set input/output scale values of input/output var nodes for
// type_precision_cast_pass.
std::vector<float> input_data_scales;
std::vector<float> output_data_scales;
for (auto &var_node : input_var_nodes) {
auto any_op_node = var_node->outlinks.front();
CHECK(any_op_node->IsStmt());
auto &any_inst = any_op_node->AsStmt();
if (any_inst.op_info()->HasAttr("input_scale")) {
input_data_scales.push_back(
any_inst.op_info()->GetAttr<float>("input_scale"));
}
}
for (auto &var_node : output_var_nodes) {
auto any_op_node = var_node->inlinks.front();
CHECK(any_op_node->IsStmt());
auto &any_inst = any_op_node->AsStmt();
if (any_inst.op_info()->HasAttr("output_scale")) {
output_data_scales.push_back(
any_inst.op_info()->GetAttr<float>("output_scale"));
}
}
if (input_data_scales.size() > 0) {
subgraph_op_desc.SetAttr<std::vector<float>>("input_data_scales",
input_data_scales);
}
if (output_data_scales.size() > 0) {
subgraph_op_desc.SetAttr<std::vector<float>>("output_data_scales",
output_data_scales);
}
odata_var_names);
// Set all of the inputs and outputs to the target subgraph op
// To prevent vars are removed in RuntimeProgram::UpdateVarsOfProgram()
for (auto &var_node : weight_var_nodes) {
std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names;
for (auto &var_node : input_var_nodes) {
input_var_names.push_back(var_node->AsArg().name);
}
for (auto &var_node : local_var_nodes) {
output_var_names.push_back(var_node->AsArg().name);
}
for (auto &var_node : unused_var_nodes) {
for (auto &var_node : output_var_nodes) {
output_var_names.push_back(var_node->AsArg().name);
}
subgraph_op_desc.SetInput("Inputs", input_var_names);
......@@ -509,26 +487,13 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
for (auto &var_node : input_var_nodes) {
IR_NODE_LINK_TO(var_node, subgraph_op_node);
}
for (auto &var_node : weight_var_nodes) {
IR_NODE_LINK_TO(var_node, subgraph_op_node);
}
for (auto &var_node : output_var_nodes) {
IR_OP_VAR_LINK(subgraph_op_node, var_node);
}
for (auto &var_node : local_var_nodes) {
IR_OP_VAR_LINK(subgraph_op_node, var_node);
}
for (auto &var_node : unused_var_nodes) {
IR_OP_VAR_LINK(subgraph_op_node, var_node);
}
// Remove subgraph nodes and unused var nodes
auto nodes2rm = GetNodes2RM(subgraph_nodes,
{input_var_nodes,
weight_var_nodes,
output_var_nodes,
local_var_nodes,
unused_var_nodes});
auto nodes2rm =
GetNodes2RM(subgraph_nodes, {input_var_nodes, output_var_nodes});
GraphSafeRemoveNodes(graph, nodes2rm);
}
......@@ -603,7 +568,17 @@ std::unordered_set<const Node *> GetNodes2RM(
std::unordered_set<const Node *> nodes2rm(op_nodes.begin(), op_nodes.end());
for (auto &op_node : op_nodes) {
for (auto &var_node : op_node->inlinks) {
if (!nodes2rm.count(var_node)) {
bool skip = false;
// skip the var node which is used by any other ops that doesn't belong to
// the subgraph ops.
for (auto &out_op_node : var_node->outlinks) {
if (std::find(op_nodes.begin(), op_nodes.end(), out_op_node) !=
op_nodes.end()) {
skip = true;
break;
}
}
if (!skip && !nodes2rm.count(var_node)) {
nodes2rm.insert(var_node);
}
}
......
......@@ -25,6 +25,7 @@ DEFINE_string(optimized_model_dir, "", "path of optimized naive buffer model");
DEFINE_string(input_tensor_shape, "1,3,224,224", "shape of input tensors");
DEFINE_string(input_tensor_type, "float32", "data type of input tensors");
DEFINE_string(output_tensor_type, "float32", "data type of output tensors");
DEFINE_string(subgraph_model_cache_dir, "", "dir of subgraph model cache");
namespace paddle {
namespace lite {
......@@ -132,6 +133,7 @@ std::shared_ptr<lite_api::PaddlePredictor> TestModel(
mobile_config.set_model_from_file(optimized_model_dir + ".nb");
mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH);
mobile_config.set_threads(1);
mobile_config.set_subgraph_model_cache_dir(FLAGS_subgraph_model_cache_dir);
predictor = lite_api::CreatePaddlePredictor(mobile_config);
FillInputTensors(predictor, input_tensor_shape, input_tensor_type, 1);
// Run optimized model
......
......@@ -15,6 +15,7 @@
#include "lite/kernels/npu/bridges/engine.h"
#include <sys/time.h>
#include <time.h>
#include <algorithm>
#include <utility>
#include "lite/kernels/npu/bridges/registry.h"
......@@ -22,11 +23,50 @@ namespace paddle {
namespace lite {
namespace subgraph {
int Engine::BuildDeviceProgram() { return FAILED; }
// Construct an engine for one block of the program.
//
// ctx, block_idx, block_desc and scope are stored as given. The input and
// output names are copied and then sorted.
Engine::Engine(KernelContext *ctx,
               int block_idx,
               cpp::BlockDesc *block_desc,
               const std::vector<std::string> &input_names,
               const std::vector<std::string> &output_names,
               lite::Scope *scope)
    : ctx_(ctx),
      block_idx_(block_idx),
      block_desc_(block_desc),
      input_names_(input_names),
      output_names_(output_names),
      scope_(scope) {
  // Keep the tensor names in sorted order. The subgraph pass may produce a
  // different (but still correct) division on each call, so a fixed order
  // makes it convenient to fetch the info of the input and output tensors
  // from the device program in the same order every time.
  std::stable_sort(input_names_.begin(), input_names_.end());
  std::stable_sort(output_names_.begin(), output_names_.end());
}
// Run the subgraph once: prepare the workspace on the first call, rebuild the
// device program whenever InputShapeChanged() reports a change, then launch.
// Returns the result of LaunchDeviceProgram(); the return values of the
// prepare and build steps are not checked here.
bool Engine::Run() {
if (is_first_epoch_) {
// One-time setup, guarded by is_first_epoch_.
PrepareWorkspaceForDeviceProgram();
is_first_epoch_ = false;
}
if (InputShapeChanged()) {
BuildDeviceProgram();
}
return LaunchDeviceProgram();
}
int Engine::LaunchDeviceProgram() { return 0; }
// Look up the input and output tensors of the subgraph in scope_ by name and
// cache the pointers in origin_itensors_ / origin_otensors_. origin_idims_ is
// only resized here (one entry per input); its entries are not filled in.
//
// Every name must resolve to a tensor (enforced with CHECK). Always returns
// true.
bool Engine::PrepareWorkspaceForOriginProgram() {
  origin_idims_.resize(input_names_.size());
  origin_itensors_.resize(input_names_.size());
  // size_t indices avoid a signed/unsigned comparison against size().
  for (size_t i = 0; i < input_names_.size(); i++) {
    origin_itensors_[i] = scope_->FindMutableTensor(input_names_[i]);
    CHECK(origin_itensors_[i]);
  }
  origin_otensors_.resize(output_names_.size());
  for (size_t i = 0; i < output_names_.size(); i++) {
    origin_otensors_[i] = scope_->FindMutableTensor(output_names_[i]);
    CHECK(origin_otensors_[i]);
  }
  return true;
}
int Engine::BuildOriginProgram() {
bool Engine::BuildOriginProgram() {
// TODO(hong19860320) The block_desc needs to be divided into subgraphs at
// execution time; for now the whole block is treated as a single subgraph.
origin_program_.clear();
......@@ -34,11 +74,14 @@ int Engine::BuildOriginProgram() {
auto op_desc = block_desc_->GetOp<cpp::OpDesc>(op_idx);
CHECK(op_desc);
std::string op_type = op_desc->Type();
// Create op and pick up the best kernel
auto op = LiteOpRegistry::Global().Create(op_desc->Type());
CHECK(op) << "no Op found for " << op_type;
op->Attach(*op_desc, scope_);
std::unique_ptr<KernelBase> picked_kernel;
if (op_desc->HasAttr(kKernelTypeAttr)) {
// Create op and pick up kernel according to the kKernelTypeAttr attribute
// Create op and pick up the best kernel according to the
// kKernelTypeAttr attribute
auto kernel_type = op_desc->GetAttr<std::string>(kKernelTypeAttr);
std::string alias;
Place place;
......@@ -48,12 +91,14 @@ int Engine::BuildOriginProgram() {
auto kernels = op->CreateKernels({place});
CHECK_GT(kernels.size(), 0u) << "No kernels found for " << op_type;
auto it = std::find_if(
kernels.begin(), kernels.end(), [&](std::unique_ptr<KernelBase>& it) {
kernels.begin(), kernels.end(), [&](std::unique_ptr<KernelBase> &it) {
return it->alias() == alias;
});
CHECK(it != kernels.end());
picked_kernel = std::move(*it);
} else {
// TODO(hong19860320) add kernel picking according to the type of input
// and output tensors
VLOG(3) << "The attr '" << kKernelTypeAttr
<< "' not found, pick the first kernel for " << op_type;
std::vector<std::unique_ptr<KernelBase>> kernels;
......@@ -74,49 +119,41 @@ int Engine::BuildOriginProgram() {
}
origin_program_.emplace_back(std::move(op), std::move(picked_kernel));
}
return 0;
CHECK(!origin_program_.empty()) << "no instructions";
return true;
}
int Engine::LaunchOriginProgram() {
for (auto& inst : origin_program_) {
auto op_type = inst.op()->op_info()->Type();
if (op_type == "feed" || op_type == "fetch") continue;
inst.Run();
bool Engine::LaunchOriginProgram() {
if (origin_program_.empty()) {
BuildOriginProgram();
}
if (!origin_program_.empty()) {
for (auto &inst : origin_program_) {
auto op_type = inst.op()->op_info()->Type();
if (op_type == "feed" || op_type == "fetch") continue;
inst.Run();
}
return true;
}
return 0;
return false;
}
int Engine::Build() {
// In order to attach all of the ops of the block desc, we need to build the
// original program firstly.
BuildOriginProgram();
// Run InferShape() of all of ops, and convert Paddle ops to NPU/XPU IR graph
build_device_program_status_ = BuildDeviceProgram();
return build_device_program_status_;
// Base implementation of the device workspace preparation: it simply delegates
// to PrepareWorkspaceForOriginProgram() and returns its result.
bool Engine::PrepareWorkspaceForDeviceProgram() {
return PrepareWorkspaceForOriginProgram();
}
// Base implementation: building the device program just builds the origin
// program and returns its result.
bool Engine::BuildDeviceProgram() { return BuildOriginProgram(); }
// Base implementation: launching the device program just launches the origin
// program and returns its result.
bool Engine::LaunchDeviceProgram() { return LaunchOriginProgram(); }
bool Engine::InputShapeChanged() {
bool changed = false;
for (size_t i = 0; i < origin_itensors_.size(); i++) {
if (origin_itensors_[i]->dims() != origin_idims_[i]) {
return true;
}
}
return false;
}
int Engine::Launch() {
// Rebuild device program when the shapes of input tensors have been changed.
if (CHECK_SUCCESS(build_device_program_status_) &&
CHECK_REBUILD_WHEN_SHAPE_CHANGED(build_device_program_status_) &&
InputShapeChanged()) {
Build();
}
if (CHECK_FAILED(build_device_program_status_)) {
LaunchOriginProgram();
} else {
LaunchDeviceProgram();
auto origin_idim = origin_itensors_[i]->dims().Vectorize();
changed |= origin_idim != origin_idims_[i];
origin_idims_[i] = origin_idim;
}
return 0;
return changed;
}
} // namespace subgraph
......
......@@ -33,42 +33,33 @@ class Engine {
cpp::BlockDesc *block_desc,
const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names,
lite::Scope *scope)
: ctx_(ctx),
block_idx_(block_idx),
block_desc_(block_desc),
input_names_(input_names),
output_names_(output_names),
scope_(scope) {}
lite::Scope *scope);
virtual ~Engine() = default;
virtual int Build();
virtual int Launch();
virtual bool Run();
private:
Engine(const Engine &) = delete;
protected:
virtual int BuildDeviceProgram();
virtual int LaunchDeviceProgram();
virtual bool PrepareWorkspaceForOriginProgram();
virtual bool BuildOriginProgram();
virtual bool LaunchOriginProgram();
virtual int BuildOriginProgram();
virtual int LaunchOriginProgram();
virtual bool PrepareWorkspaceForDeviceProgram();
virtual bool BuildDeviceProgram();
virtual bool LaunchDeviceProgram();
virtual bool InputShapeChanged();
KernelContext *ctx_{nullptr};
int block_idx_;
cpp::BlockDesc *block_desc_;
int block_idx_{-1};
cpp::BlockDesc *block_desc_{nullptr};
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
Scope *scope_{nullptr};
// SUCCESS: device program build successed. FAILED: device program build
// failed. REBUILD_WHEN_SHAPE_CHANGED: device program build successed but need
// to rebuild when input shape changed.
int build_device_program_status_{0};
std::vector<DDim> origin_idims_;
std::vector<DDim> origin_odims_;
bool is_first_epoch_{true};
std::vector<std::vector<int64_t>> origin_idims_;
std::vector<Tensor *> origin_itensors_;
std::vector<Tensor *> origin_otensors_;
std::vector<Instruction> origin_program_;
......
......@@ -19,7 +19,7 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include "graph/op/all_ops.h"
#include "graph/compatible/all_ops.h"
#include "lite/core/op_lite.h"
#include "lite/core/tensor.h"
......
......@@ -94,10 +94,10 @@ int MatMulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} else {
matmul_node = graph->Add<ge::op::BatchMatMul>(out_name);
auto matmul_op = matmul_node->data<ge::op::BatchMatMul>();
matmul_op->set_input_x(*x_node->data());
matmul_op->set_input_y(*y_node->data());
matmul_op->set_attr_adj_x(transpose_x);
matmul_op->set_attr_adj_y(transpose_y);
matmul_op->set_input_x1(*x_node->data());
matmul_op->set_input_x2(*y_node->data());
matmul_op->set_attr_adj_x1(transpose_x);
matmul_op->set_attr_adj_x2(transpose_y);
}
if (fabs(alpha - 1.f) > 1e-6f) {
......
......@@ -20,11 +20,11 @@
#include <unordered_map>
#include <vector>
#include "graph/buffer.h"
#include "graph/compatible/operator_reg.h"
#include "graph/graph.h"
#include "graph/model.h"
#include "graph/op/all_ops.h"
#include "graph/operator.h"
#include "graph/operator_reg.h"
#include "lite/core/op_lite.h"
#include "lite/utils/macros.h"
......@@ -97,25 +97,26 @@ REG_OP(Pad)
/*
* Multiplies slices of two tensors in batches.
* <Input>
* x : The input tensor
* y : The input tensor
* x1 : The input tensor
* x2 : The input tensor
* <Output>
* z : The output tensor
* y : The output tensor
* <Attr>
* adj_x : adj_x is true, the input tensor x is transposed, otherwise
* it will not be transposed. Default is false (The current version only
* supports false).
* adj_y : adj_y is true, the input tensor y is transposed, otherwise
* it will not be transposed. Default is false.
* adj_x1 : adj_x1 is true, the input tensor x1 is transposed,
* otherwise it will not be transposed.
* Default is false (The current version only supports false).
* adj_x2 : adj_x2 is true, the input tensor x2 is transposed,
* otherwise it will not be transposed.
* Default is false.
* <Added in HiAI version>
* 100.320.010.010
* 100.320.010.010
*/
REG_OP(BatchMatMul)
.INPUT(x, TensorType({DT_FLOAT}))
.INPUT(y, TensorType({DT_FLOAT}))
.OUTPUT(z, TensorType({DT_FLOAT}))
.ATTR(adj_x, AttrValue::BOOL{false})
.ATTR(adj_y, AttrValue::BOOL{false})
.INPUT(x1, TensorType({DT_FLOAT}))
.INPUT(x2, TensorType({DT_FLOAT}))
.OUTPUT(y, TensorType({DT_FLOAT}))
.ATTR(adj_x1, AttrValue::BOOL{false})
.ATTR(adj_x2, AttrValue::BOOL{false})
.OP_END()
} // namespace ge
......
......@@ -15,6 +15,8 @@
#include "lite/kernels/npu/subgraph_compute.h"
#include <sys/time.h>
#include <time.h>
#include <algorithm>
#include <functional>
#include <utility>
#include "hiai_ir_build.h" // NOLINT
#include "lite/backends/npu/device.h"
......@@ -22,192 +24,276 @@
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/paddle_use_bridges.h"
#include "lite/kernels/npu/bridges/utility.h"
#include "lite/utils/io.h"
#include "lite/utils/md5.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
int SubgraphEngine::BuildDeviceProgram() {
// Generate the model name by using md5 hashes based on:
// 1. the sorted variable input names
// 2. the shapes of the origin input tensors
// 3. the sorted variable output names
//
// input_names and origin_idims must have the same length (enforced with
// CHECK_EQ); origin_idims[i] is the shape of the tensor named input_names[i].
// This function does not sort anything itself: the names are hashed in the
// order given, so callers are expected to pass them already sorted.
//
// NOTE(review): names and dims are concatenated without separators, so
// distinct inputs could in principle produce the same text before hashing.
// The format is left unchanged so that existing cache file names stay valid.
std::string DeviceProgram::GenerateModelName(
    const std::vector<std::string>& input_names,
    const std::vector<std::string>& output_names,
    const std::vector<std::vector<int64_t>>& origin_idims) {
  std::ostringstream os;
  CHECK_EQ(input_names.size(), origin_idims.size());
  // size_t index avoids a signed/unsigned comparison against size().
  for (size_t i = 0; i < input_names.size(); i++) {
    os << input_names[i];
    for (const auto dim : origin_idims[i]) {
      os << dim;
    }
  }
  // Iterate by const reference to avoid copying every name.
  for (const auto& output_name : output_names) {
    os << output_name;
  }
  return MD5(os.str());
}
// Load the cached model and deserialize the precisions and dimensions of the
// origin output tensors of the subgraph op from the cache files.
//
// Two files are read from model_cache_dir, both named after model_name_:
//   <model_name_>.om   the HiAI om model
//   <model_name_>.cfg  one "precision:dim,dim,..." entry per output tensor,
//                      entries separated by ';'
//
// On success model_client_, origin_otypes_ and origin_odims_ are filled in and
// true is returned. false is returned when either file cannot be read or the
// model cannot be loaded. The number of entries in the .cfg file must match
// output_names.size() (enforced with CHECK_EQ).
//
// Note: model_client_ is assigned before the .cfg file is read, so a failed
// configuration read returns false with model_client_ already set.
bool DeviceProgram::LoadFromCacheFile(
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::vector<int64_t>>& origin_idims,
const std::string& model_cache_dir) {
// Generate the model name if not initialized
if (model_name_.empty()) {
model_name_ = GenerateModelName(input_names, output_names, origin_idims);
}
// Load from the cached model file, return a HiAI model manager client for
// inference
auto model_path = model_cache_dir + "/" + model_name_ + ".om";
VLOG(3) << "[NPU] Load model from " << model_path;
std::vector<char> model_buffer;
if (!ReadFile(model_path, &model_buffer)) {
LOG(WARNING) << "[NPU] read from " << model_path << " failed!";
return false;
}
// model_comp starts as false and is passed by pointer so that Load() can
// report back whether the cached model was compatible.
bool model_comp = false;
model_client_ =
lite::npu::Device::Global().Load(model_name_, &model_buffer, &model_comp);
if (!model_client_) {
LOG(WARNING) << "[NPU] Load model failed!";
return false;
}
// Rewrite with the compatible model data if the cached
// model file is incompatible with the current device
if (!model_comp) {
VLOG(3) << "[NPU] Export the compatible model to " << model_path;
// A failed write is only logged; loading continues with the in-memory model.
if (!WriteFile(model_path, model_buffer)) {
LOG(WARNING) << "[NPU] Open " << model_path << " for writting failed!";
}
}
// Deserialize the precisions and shapes of the origin output tensors from the
// cached configuration file
auto config_path = model_cache_dir + "/" + model_name_ + ".cfg";
VLOG(3) << "[NPU] Load configuration from " << config_path;
std::vector<char> config_buffer;
if (!ReadFile(config_path, &config_buffer)) {
LOG(WARNING) << "[NPU] read from " << config_path << " failed!";
return false;
}
std::string config_str(config_buffer.begin(), config_buffer.end());
// Parse the precision and shapes of the output tensors
auto output_options = Split<std::string>(config_str, ";");
CHECK_EQ(output_options.size(), output_names.size());
origin_otypes_.resize(output_names.size());
origin_odims_.resize(output_names.size());
// NOTE(review): 'int i' is compared against an unsigned size() here.
for (int i = 0; i < output_names.size(); i++) {
auto items = Split<std::string>(output_options[i], ":");
CHECK_EQ(items.size(), 2); // precision and shapes
// items[0] is the integer value of the PrecisionType, items[1] the
// comma-separated dimensions.
origin_otypes_[i] = static_cast<PrecisionType>(std::stoi(items[0]));
origin_odims_[i] = Split<int64_t>(items[1], ",");
}
return true;
}
bool DeviceProgram::BuildGraphAndCacheToFile(
const std::vector<Instruction>& origin_program,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::vector<int64_t>>& origin_idims,
const std::vector<Tensor*>& origin_otensors,
const std::string& model_cache_dir) {
// Generate the model name if not initialized
if (model_name_.empty()) {
model_name_ = GenerateModelName(input_names, output_names, origin_idims);
}
// Convert all of ops and their input vars and weights to HiAI IR nodes,
// then added them into the HiAI IR graph
int status = 0;
// Convert all of ops and their input vars and weights and added into the NPU
// HiAI IR graph
CHECK(!origin_program.empty()) << "no instructions";
subgraph::npu::Graph graph;
const auto& bridges = subgraph::Registry::Instance();
for (auto& inst : origin_program_) {
for (auto& inst : origin_program) {
auto op = const_cast<OpLite*>(inst.op());
CHECK(op);
op->CheckShape();
op->InferShape();
std::string op_type = op->op_info()->Type();
if (!bridges.Exists(op_type, TARGET(kNPU))) {
return subgraph::FAILED;
return false;
}
auto kernel = inst.kernel();
status |= bridges.Select(op_type, TARGET(kNPU))(
reinterpret_cast<void*>(&graph), op, const_cast<KernelBase*>(kernel));
if (subgraph::CHECK_FAILED(status)) {
return subgraph::FAILED;
return false;
}
}
// Collect the valid input and output nodes in the HiAI IR graph and update
// the input and output names
device_inames_.clear();
device_onames_.clear();
// Collect the input and output nodes of the HiAI IR graph
std::vector<ge::Operator> device_inodes;
std::vector<ge::Operator> device_onodes;
for (auto& input_name : input_names_) {
if (graph.Has(input_name)) {
if (graph.Get(input_name)->is_data()) {
device_inodes.push_back(*graph.Get(input_name)->data());
device_inames_.push_back(input_name);
} else {
LOG(WARNING) << "[NPU] Input node " << input_name
<< " is ignored because it is not a data node.";
}
} else {
LOG(WARNING) << "[NPU] Input node " << input_name
<< " is ignored because it does not exist.";
}
for (size_t i = 0; i < input_names.size(); i++) {
CHECK(graph.Has(input_names[i]) && graph.Get(input_names[i])->is_data());
device_inodes.push_back(*graph.Get(input_names[i])->data());
}
for (auto& output_name : output_names_) {
if (graph.Has(output_name)) {
device_onodes.push_back(*graph.Get(output_name)->data());
device_onames_.push_back(output_name);
} else {
LOG(WARNING) << "[NPU] Output node " << output_name
<< " is ignored because it does not exist.";
}
}
CHECK(!device_inames_.empty())
<< "[NPU] No input nodes found for building NPU model";
CHECK(!device_onames_.empty())
<< "[NPU] No output nodes found for building NPU model";
// Build the HiAI IR graph to HiAI om model as the device program
if (device_program_map_.count(inputs_shape_) > 0) {
return status;
std::vector<ge::Operator> device_onodes;
for (size_t i = 0; i < output_names.size(); i++) {
CHECK(graph.Has(output_names[i]));
device_onodes.push_back(*graph.Get(output_names[i])->data());
}
auto device_client = lite::npu::Device::Global().Build(
model_name_, device_inodes, device_onodes);
if (device_client == nullptr) {
// Build the HiAI IR graph to the HiAI om model
std::vector<char> model_buffer;
if (!lite::npu::Device::Global().Build(
device_inodes, device_onodes, &model_buffer)) {
LOG(WARNING) << "[NPU] Build model failed!";
return subgraph::FAILED;
return false;
}
auto device_program = std::make_shared<device_program_t>(device_client);
device_program_map_[inputs_shape_] = device_program;
// Query and check the dimensions of valid input and output tensors
std::vector<hiai::TensorDimension> device_idims, device_odims;
if (device_program->client->GetModelIOTensorDim(
model_name_, device_idims, device_odims) != hiai::AI_SUCCESS) {
LOG(WARNING)
<< "[NPU] Get the dimensions of input and output tensors failed!";
return subgraph::FAILED;
// Load the HiAI om model and create a HiAI model manager client(from HiAI
// Service) to run inference.
bool model_comp = true;
model_client_ =
lite::npu::Device::Global().Load(model_name_, &model_buffer, &model_comp);
if (!model_client_) {
LOG(WARNING) << "[NPU] Load model failed!";
return false;
}
device_program->device_idims = device_idims;
device_program->device_odims = device_odims;
CHECK_EQ(device_idims.size(), device_inames_.size());
CHECK_EQ(device_odims.size(), device_onames_.size());
origin_idims_.resize(device_inames_.size());
origin_itensors_.resize(device_inames_.size());
device_itensors_.resize(device_inames_.size());
origin_odims_.resize(device_onames_.size());
origin_otensors_.resize(device_onames_.size());
device_otensors_.resize(device_onames_.size());
for (int i = 0; i < device_inames_.size(); i++) {
auto node = graph.Get(device_inames_[i]);
auto precision = node->precision();
auto layout = node->layout();
origin_itensors_[i] = scope_->FindMutableTensor(device_inames_[i]);
CHECK(origin_itensors_[i]);
origin_idims_[i] = origin_itensors_[i]->dims();
VLOG(3) << "[NPU] Inputs[" << i << "] name: " << device_inames_[i]
<< " precision: " << PrecisionToStr(precision)
<< " layout: " << DataLayoutToStr(layout) << " dims: {"
<< device_idims[i].GetNumber() << ","
<< device_idims[i].GetChannel() << ","
<< device_idims[i].GetHeight() << "," << device_idims[i].GetWidth()
<< "}";
// Prepare the device input tensors
CHECK_EQ(origin_idims_[i].production(),
device_idims[i].GetNumber() * device_idims[i].GetChannel() *
device_idims[i].GetHeight() * device_idims[i].GetWidth());
device_itensors_[i].reset(new hiai::AiTensor);
device_itensors_[i]->Init(&(device_idims[i]));
// Update the precison and dimensions of the origin output tensors
CHECK_EQ(origin_otensors.size(), output_names.size());
origin_otypes_.resize(output_names.size());
origin_odims_.resize(output_names.size());
for (size_t i = 0; i < output_names.size(); i++) {
origin_otypes_[i] = graph.Get(output_names[i])->precision();
origin_odims_[i] = origin_otensors[i]->dims().Vectorize();
}
device_program->origin_idims = origin_idims_;
for (int i = 0; i < device_onames_.size(); i++) {
auto node = graph.Get(device_onames_[i]);
auto precision = node->precision();
auto layout = node->layout();
origin_otensors_[i] = scope_->FindMutableTensor(device_onames_[i]);
CHECK(origin_otensors_[i]);
origin_odims_[i] = origin_otensors_[i]->dims();
VLOG(3) << "[NPU] Outputs[" << i << "] name: " << device_onames_[i]
<< " precision: " << PrecisionToStr(precision)
<< " layout: " << DataLayoutToStr(layout) << " dims: {"
<< device_odims[i].GetNumber() << ","
<< device_odims[i].GetChannel() << ","
<< device_odims[i].GetHeight() << "," << device_odims[i].GetWidth()
<< "}";
// Prepare the device output tensors
switch (precision) {
case PRECISION(kFloat):
origin_otensors_[i]->mutable_data<float>();
break;
case PRECISION(kBool):
origin_otensors_[i]->mutable_data<bool>();
break;
case PRECISION(kInt8):
origin_otensors_[i]->mutable_data<int8_t>();
break;
case PRECISION(kInt16):
origin_otensors_[i]->mutable_data<int16_t>();
break;
case PRECISION(kInt32):
origin_otensors_[i]->mutable_data<int32_t>();
break;
case PRECISION(kInt64):
origin_otensors_[i]->mutable_data<int64_t>();
break;
default:
LOG(FATAL) << "[NPU] " << device_onames_[i]
<< " can't mutable data with precision type "
<< PrecisionToStr(precision);
break;
if (!model_cache_dir.empty()) {
// Save the generated model to file, used for the model caching or the
// offline model generation
auto model_path = model_cache_dir + "/" + model_name_ + ".om";
VLOG(3) << "[NPU] Save model to " << model_path;
if (!WriteFile(model_path, model_buffer)) {
LOG(WARNING) << "[NPU] Open " << model_path << " for writting failed!";
}
// Serialize the precisions and shapes of the origin output tensors into the
// configuration file
std::ostringstream os;
for (int i = 0; i < output_names.size(); i++) {
os << static_cast<int32_t>(origin_otypes_[i]) << ":";
for (auto dim : origin_odims_[i]) {
os << dim << ",";
}
os << ";";
}
auto str = os.str();
std::vector<char> config_buffer(str.begin(), str.end());
auto config_path = model_cache_dir + "/" + model_name_ + ".cfg";
VLOG(3) << "[NPU] Save configuration to " << config_path;
if (!WriteFile(config_path, config_buffer)) {
LOG(WARNING) << "[NPU] Open " << config_path << " for writting failed!";
}
device_program->origin_odims = origin_odims_;
CHECK_EQ(origin_odims_[i].production(),
device_odims[i].GetNumber() * device_odims[i].GetChannel() *
device_odims[i].GetHeight() * device_odims[i].GetWidth());
device_otensors_[i].reset(new hiai::AiTensor);
device_otensors_[i]->Init(&(device_odims[i]));
}
return status;
return true;
}
int SubgraphEngine::LaunchDeviceProgram() {
// Copy the data of origin input tensors to the buffer of input HiAI tensors
// init device_itensors_, device_otensors_, origin_otensors_
auto device_program = device_program_map_[inputs_shape_];
for (size_t i = 0; i < device_itensors_.size(); i++) {
device_itensors_[i]->Init(&(device_program->device_idims[i]));
std::memcpy(device_itensors_[i]->GetBuffer(),
origin_itensors_[i]->raw_data(),
origin_itensors_[i]->memory_size());
// Initializes the device (HiAI) input/output tensors with the model's I/O
// dimensions and makes the origin tensors use the device tensors' buffers, so
// that ZeroCopyRun() can run without extra copies.
//
// input_names / output_names: names of the subgraph inputs and outputs; only
//   their count and, for logging, their text are used here.
// origin_itensors / origin_otensors: framework tensors whose buffers are reset
//   to the device buffers.
// device_itensors / device_otensors: already-created HiAI tensors; they are
//   (re)initialized here with device_idims_ / device_odims_.
//
// Returns false when the model I/O dimensions can't be queried from the model
// client, true otherwise. Size or element-count mismatches trigger CHECK
// failures.
bool DeviceProgram::ShareBufferWithOriginTensors(
    const std::vector<std::string>& input_names,
    const std::vector<std::string>& output_names,
    std::vector<Tensor*>* origin_itensors,
    std::vector<Tensor*>* origin_otensors,
    std::vector<std::shared_ptr<hiai::AiTensor>>* device_itensors,
    std::vector<std::shared_ptr<hiai::AiTensor>>* device_otensors) {
  CHECK(!model_name_.empty() && model_client_);
  // Query the dimensions of the device input and output tensors if not
  // initialized
  if (device_idims_.empty() || device_odims_.empty()) {
    if (model_client_->GetModelIOTensorDim(
            model_name_, device_idims_, device_odims_) != hiai::AI_SUCCESS) {
      LOG(WARNING)
          << "[NPU] Get the dimensions of input and output tensors failed!";
      return false;
    }
  }
  // Check the dimensions of the device tensors and the origin tensors
  CHECK_EQ(device_itensors->size(), input_names.size());
  CHECK_EQ(device_otensors->size(), output_names.size());
  CHECK_EQ(origin_otypes_.size(), output_names.size());
  CHECK_EQ(origin_odims_.size(), output_names.size());
  CHECK_EQ(device_idims_.size(), input_names.size());
  CHECK_EQ(device_odims_.size(), output_names.size());
  for (size_t i = 0; i < input_names.size(); i++) {
    VLOG(3) << "[NPU] Inputs[" << i << "] name: " << input_names[i]
            << " origin dims:" << (*origin_itensors)[i]->dims().repr()
            << " device dims: {" << device_idims_[i].GetNumber() << ","
            << device_idims_[i].GetChannel() << ","
            << device_idims_[i].GetHeight() << ","
            << device_idims_[i].GetWidth() << "}";
    CHECK_EQ((*origin_itensors)[i]->dims().production(),
             device_idims_[i].GetNumber() * device_idims_[i].GetChannel() *
                 device_idims_[i].GetHeight() * device_idims_[i].GetWidth());
    VLOG(3) << "[NPU] Init the input tensors for the device program and share "
               "their buffers with the origin input tensors";
    // reinit device tensor will free shared buffer, so copy data to a tmp
    // tensor
    Tensor tmp;
    tmp.CopyDataFrom(*(*origin_itensors)[i]);
    (*device_itensors)[i]->Init(&(device_idims_[i]));
    std::memcpy(
        (*device_itensors)[i]->GetBuffer(), tmp.raw_data(), tmp.memory_size());
    // Share the data buffer between the device input tensor and the origin
    // input tensor
    std::shared_ptr<Buffer> buffer =
        std::make_shared<Buffer>((*device_itensors)[i]->GetBuffer(),
                                 lite_api::TargetType::kHost,
                                 (*device_itensors)[i]->GetSize());
    (*origin_itensors)[i]->ResetBuffer(buffer,
                                       (*device_itensors)[i]->GetSize());
  }
  for (size_t i = 0; i < output_names.size(); i++) {
    // Restore the recorded precision and shape of the origin output tensor
    // before comparing it with the device dimensions.
    (*origin_otensors)[i]->set_precision(origin_otypes_[i]);
    (*origin_otensors)[i]->Resize(origin_odims_[i]);
    VLOG(3) << "[NPU] Outputs[" << i << "] name: " << output_names[i]
            << " origin dims:" << (*origin_otensors)[i]->dims().repr()
            << " device dims: {" << device_odims_[i].GetNumber() << ","
            << device_odims_[i].GetChannel() << ","
            << device_odims_[i].GetHeight() << ","
            << device_odims_[i].GetWidth() << "}";
    CHECK_EQ((*origin_otensors)[i]->dims().production(),
             device_odims_[i].GetNumber() * device_odims_[i].GetChannel() *
                 device_odims_[i].GetHeight() * device_odims_[i].GetWidth());
    (*device_otensors)[i]->Init(&(device_odims_[i]));
    VLOG(3) << "[NPU] Init the output tensors for the device program and share "
               "their buffers with the origin output tensors";
    // Share the data buffer between the device output tensor and the origin
    // output tensor
    std::shared_ptr<Buffer> buffer =
        std::make_shared<Buffer>((*device_otensors)[i]->GetBuffer(),
                                 lite_api::TargetType::kHost,
                                 (*device_otensors)[i]->GetSize());
    (*origin_otensors)[i]->ResetBuffer(buffer,
                                       (*device_otensors)[i]->GetSize());
  }
  return true;
}
bool DeviceProgram::ZeroCopyRun(
std::vector<std::shared_ptr<hiai::AiTensor>>* device_itensors,
std::vector<std::shared_ptr<hiai::AiTensor>>* device_otensors) {
CHECK(!model_name_.empty() && model_client_);
// Run the HiAI model by name
std::string key = "model_name"; // Note: key seems must be model_name
hiai::AiContext model_context;
......@@ -219,30 +305,87 @@ int SubgraphEngine::LaunchDeviceProgram() {
};
int istamp;
auto start_time = GetCurrentUS();
CHECK_EQ(device_program->client->Process(
model_context, device_itensors_, device_otensors_, 1000, istamp),
CHECK_EQ(model_client_->Process(
model_context, *device_itensors, *device_otensors, 1000, istamp),
hiai::AI_SUCCESS);
VLOG(3) << "[NPU] Process cost " << GetCurrentUS() - start_time << " us";
return true;
}
// Sets up the tensors needed before a device program can be built: the origin
// tensors via PrepareWorkspaceForOriginProgram(), then one empty HiAI tensor
// per subgraph input and output. Always returns true.
bool SubgraphEngine::PrepareWorkspaceForDeviceProgram() {
  // The origin input tensors are obtained and the origin output tensors are
  // created here; they must not be accessed before the device program or the
  // origin program has been launched.
  PrepareWorkspaceForOriginProgram();

  // Allocate the device-side tensors only; their dimensions are assigned
  // later, once the model's I/O dimensions are known.
  device_itensors_.resize(input_names_.size());
  for (auto& device_itensor : device_itensors_) {
    device_itensor.reset(new hiai::AiTensor);
    CHECK(device_itensor);
  }

  device_otensors_.resize(output_names_.size());
  for (auto& device_otensor : device_otensors_) {
    device_otensor.reset(new hiai::AiTensor);
    CHECK(device_otensor);
  }
  return true;
}
// Copy the data of output HiAI tensor to the buffer of origin output tensors
for (size_t i = 0; i < device_otensors_.size(); i++) {
std::memcpy(const_cast<void*>(origin_otensors_[i]->raw_data()),
device_otensors_[i]->GetBuffer(),
device_otensors_[i]->GetSize());
// Makes sure a device program exists for the current input dimensions
// (origin_idims_) and binds its device tensors to the origin tensors.
//
// A program is first looked up in device_programs_. If absent, it is loaded
// from the model cache directory when possible, otherwise built online from
// the origin program. Returns false when the online build fails, when no model
// client is available afterwards, or when sharing the buffers fails.
bool SubgraphEngine::BuildDeviceProgram() {
  // Check if the cache device program exists
  if (!device_programs_.count(origin_idims_)) {
    auto device_program = std::make_shared<DeviceProgram>();
    // Obtain the model cache dir from the NPU Context of the subgraph op
    auto model_cache_dir = ctx_->As<NPUContext>().SubgraphModelCacheDir();
    VLOG(3) << "[NPU] Getting subgraph model_cache_dir is: " << model_cache_dir;
    // Check and load if the cached model and configuration file exists
    if (model_cache_dir.empty() ||
        !device_program->LoadFromCacheFile(
            input_names_, output_names_, origin_idims_, model_cache_dir)) {
      // Build the model online, including converting the paddle ops to the HiAI
      // IR nodes, building the HiAI IR graph to the om model, then load it as a
      // new HiAI model manager client for inference.
      if (origin_program_.empty()) {
        BuildOriginProgram();
      }
      CHECK(!origin_program_.empty()) << "no instructions";
      if (!device_program->BuildGraphAndCacheToFile(origin_program_,
                                                    input_names_,
                                                    output_names_,
                                                    origin_idims_,
                                                    origin_otensors_,
                                                    model_cache_dir)) {
        return false;
      }
    }
    // Neither the cache nor the online build produced a usable client.
    if (device_program->model_client_ == nullptr) {
      return false;
    }
    device_programs_[origin_idims_] = device_program;
  }
  // The program for this shape is now guaranteed to be registered; hook its
  // device tensors up to the origin tensors.
  auto device_program = device_programs_[origin_idims_];
  CHECK(device_program && device_program->model_client_);
  return device_program->ShareBufferWithOriginTensors(input_names_,
                                                      output_names_,
                                                      &origin_itensors_,
                                                      &origin_otensors_,
                                                      &device_itensors_,
                                                      &device_otensors_);
}
bool SubgraphEngine::InputShapeChanged() {
std::vector<std::vector<int64_t>> new_shape;
for (auto origin_itensor : origin_itensors_) {
new_shape.push_back(origin_itensor->dims().Vectorize());
// Runs the device program registered for the current input dimensions
// (origin_idims_).
//
// Rolls back to LaunchOriginProgram() when no device program is registered for
// this shape or its model client isn't initialized; otherwise returns the
// result of DeviceProgram::ZeroCopyRun().
bool SubgraphEngine::LaunchDeviceProgram() {
  // Use find() so that a miss doesn't insert an empty entry into the map and
  // the program is looked up only once.
  const auto it = device_programs_.find(origin_idims_);
  if (it == device_programs_.end() || !it->second ||
      it->second->model_client_ == nullptr) {
    return LaunchOriginProgram();
  }
  return it->second->ZeroCopyRun(&device_itensors_, &device_otensors_);
}
void SubgraphCompute::PrepareForRun() {
......@@ -254,12 +397,11 @@ void SubgraphCompute::PrepareForRun() {
param.output_data_names,
param.scope));
CHECK(engine_);
engine_->Build();
}
// Executes the subgraph through engine_, which must already be set (a CHECK
// failure is raised otherwise). The engine is run exactly once per call.
void SubgraphCompute::Run() {
  CHECK(engine_);
  engine_->Run();
}
} // namespace npu
......
......@@ -28,40 +28,65 @@ namespace lite {
namespace kernels {
namespace npu {
// Holds one HiAI model (identified by model_name_ and served by
// model_client_) together with the output precisions/shapes and the device
// I/O dimensions needed to run it for a particular set of input dimensions.
class DeviceProgram {
public:
DeviceProgram() {}
~DeviceProgram() {}
// Produces a model name from the input/output names and the input
// dimensions. The definition is not part of this header.
// NOTE(review): presumably this is the stem of the cached ".om"/".cfg" files
// - confirm against the definition.
std::string GenerateModelName(
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::vector<int64_t>>& origin_idims);
// Tries to load a previously cached model and its configuration from
// model_cache_dir. SubgraphEngine::BuildDeviceProgram() falls back to an
// online build when this returns false.
bool LoadFromCacheFile(const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::vector<int64_t>>& origin_idims,
const std::string& model_cache_dir);
// Builds the model online from origin_program and loads it into
// model_client_. When model_cache_dir is not empty, the model is also written
// to "<model_cache_dir>/<model_name_>.om" and the output precisions/shapes to
// "<model_cache_dir>/<model_name_>.cfg". Returns false if building or loading
// the model fails.
bool BuildGraphAndCacheToFile(
const std::vector<Instruction>& origin_program,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::vector<int64_t>>& origin_idims,
const std::vector<Tensor*>& origin_otensors,
const std::string& model_cache_dir);
// Initializes the device tensors with device_idims_/device_odims_ (queried
// from model_client_ on first use) and resets the origin tensors' buffers to
// the device tensors' buffers. Returns false if the dimension query fails.
bool ShareBufferWithOriginTensors(
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
std::vector<Tensor*>* origin_itensors,
std::vector<Tensor*>* origin_otensors,
std::vector<std::shared_ptr<hiai::AiTensor>>* device_itensors,
std::vector<std::shared_ptr<hiai::AiTensor>>* device_otensors);
// Runs the model through model_client_->Process() on the given device
// tensors; a non-success status triggers a CHECK failure.
bool ZeroCopyRun(
std::vector<std::shared_ptr<hiai::AiTensor>>* device_itensors,
std::vector<std::shared_ptr<hiai::AiTensor>>* device_otensors);
// Data members are public; SubgraphEngine reads model_client_ directly.
public:
// Name under which the model is registered with the client and cached.
std::string model_name_{""};
// HiAI model manager client used for inference; nullptr until a model has
// been loaded.
std::shared_ptr<hiai::AiModelMngerClient> model_client_{nullptr};
// Shapes and precisions of the origin output tensors; they are applied to the
// origin output tensors in ShareBufferWithOriginTensors().
std::vector<std::vector<int64_t>> origin_odims_;
std::vector<PrecisionType> origin_otypes_;
// Device-side I/O dimensions, filled in by GetModelIOTensorDim() when empty.
std::vector<hiai::TensorDimension> device_idims_{};
std::vector<hiai::TensorDimension> device_odims_{};
};
// NPU implementation of subgraph::Engine. It forwards its constructor
// arguments to the base class and overrides the device-program hooks.
// NOTE(review): this span interleaves pre-change and post-change lines of the
// same declaration (e.g. both `KernelContext *ctx` and `KernelContext* ctx`,
// and both int- and bool-returning overrides), so it is not compilable as
// shown - confirm the final declaration against the committed header.
class SubgraphEngine : public subgraph::Engine {
public:
SubgraphEngine(KernelContext *ctx,
SubgraphEngine(KernelContext* ctx,
int block_idx,
cpp::BlockDesc *block_desc,
const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names,
Scope *scope)
cpp::BlockDesc* block_desc,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
Scope* scope)
: subgraph::Engine(
ctx, block_idx, block_desc, input_names, output_names, scope) {}
// Per-shape bundle of a HiAI client with its origin/device dimensions.
// NOTE(review): looks like the predecessor of the standalone DeviceProgram
// class - verify whether it is still needed.
struct device_program_t {
explicit device_program_t(std::shared_ptr<hiai::AiModelMngerClient> _client)
: client(_client) {}
std::shared_ptr<hiai::AiModelMngerClient> client{nullptr};
std::vector<DDim> origin_idims{};
std::vector<DDim> origin_odims{};
std::vector<hiai::TensorDimension> device_idims{};
std::vector<hiai::TensorDimension> device_odims{};
};
protected:
// Hooks overridden from subgraph::Engine.
int BuildDeviceProgram() override;
int LaunchDeviceProgram() override;
bool InputShapeChanged() override;
bool PrepareWorkspaceForDeviceProgram() override;
bool BuildDeviceProgram() override;
bool LaunchDeviceProgram() override;
std::string model_name_{"model.om"};
std::vector<std::vector<int64_t>> inputs_shape_{};
std::map<std::vector<std::vector<int64_t>>, std::shared_ptr<device_program_t>>
device_program_map_{};
std::vector<std::string> device_inames_{};
std::vector<std::string> device_onames_{};
// HiAI tensors handed to the device program, one per subgraph input/output.
std::vector<std::shared_ptr<hiai::AiTensor>> device_itensors_{};
std::vector<std::shared_ptr<hiai::AiTensor>> device_otensors_{};
// Device programs keyed by the dimensions of all origin input tensors.
std::map<std::vector<std::vector<int64_t>>, std::shared_ptr<DeviceProgram>>
device_programs_;
};
class SubgraphCompute : public KernelLite<TARGET(kNPU), PRECISION(kAny)> {
......
......@@ -22,6 +22,8 @@
#define SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE \
"SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE"
#define SUBGRAPH_DISABLE_ONLINE_MODE "SUBGRAPH_DISABLE_ONLINE_MODE"
namespace paddle {
namespace lite {
......
......@@ -119,5 +119,40 @@ static std::vector<std::string> ListDir(const std::string& path,
return paths;
}
// Reads the whole binary file `filename` into `*contents`.
//
// Returns true when the file was opened and every byte was read; an empty
// file yields an empty vector and true. Returns false when the file can't be
// opened, its size can't be determined, or fewer bytes than expected could be
// read (in which case `*contents` is left empty).
static bool ReadFile(const std::string& filename, std::vector<char>* contents) {
  FILE* fp = fopen(filename.c_str(), "rb");
  if (!fp) return false;
  bool ok = false;
  // Determine the file size by seeking to the end; every step may fail (e.g.
  // on non-seekable streams), so each result is checked.
  if (fseek(fp, 0, SEEK_END) == 0) {
    const auto end = ftell(fp);
    if (end >= 0 && fseek(fp, 0, SEEK_SET) == 0) {
      const size_t size = static_cast<size_t>(end);
      contents->assign(size, 0);
      size_t offset = 0;
      while (offset < size) {
        const size_t already_read =
            fread(contents->data() + offset, 1, size - offset, fp);
        // A zero-length read means EOF or an error; stop instead of spinning
        // forever.
        if (already_read == 0) break;
        offset += already_read;
      }
      ok = (offset == size);
      if (!ok) contents->clear();
    }
  }
  fclose(fp);
  return ok;
}
// Writes `contents` to the binary file `filename`, replacing any existing
// file.
//
// Returns true when the file was opened, every byte was written and the file
// was closed successfully; empty `contents` produces an empty file. Returns
// false when the file can't be opened or a write/close error occurs.
static bool WriteFile(const std::string& filename,
                      const std::vector<char>& contents) {
  FILE* fp = fopen(filename.c_str(), "wb");
  if (!fp) return false;
  const size_t size = contents.size();
  const char* ptr = contents.data();
  size_t offset = 0;
  while (offset < size) {
    const size_t already_written = fwrite(ptr + offset, 1, size - offset, fp);
    // No progress means a write error (e.g. disk full); stop instead of
    // looping forever.
    if (already_written == 0) break;
    offset += already_written;
  }
  // Buffered data is flushed on close, so a failing fclose() also means the
  // file is incomplete.
  const bool closed = (fclose(fp) == 0);
  return closed && offset == size;
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>
namespace paddle {
namespace lite {
// Computes the MD5 digest (RFC 1321) of `message` and returns it as a
// 32-character lowercase hexadecimal string.
//
// The function is `inline` because it is defined in a header; without it,
// including the header from more than one translation unit produces
// multiple-definition link errors. Words are decoded and encoded explicitly
// as little-endian, so the result doesn't depend on the host byte order.
//
// NOTE: MD5 is not collision resistant; use it for naming/fingerprinting only,
// never for security purposes.
inline std::string MD5(std::string message) {
  // Per-step left-rotation amounts.
  static const uint32_t kShiftAmounts[64] = {
      7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22,
      5, 9,  14, 20, 5, 9,  14, 20, 5, 9,  14, 20, 5, 9,  14, 20,
      4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23,
      6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21};
  // Per-step additive constants: floor(2^32 * abs(sin(i + 1))).
  static const uint32_t kPartsOfSines[64] = {
      0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee, 0xf57c0faf, 0x4787c62a,
      0xa8304613, 0xfd469501, 0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be,
      0x6b901122, 0xfd987193, 0xa679438e, 0x49b40821, 0xf61e2562, 0xc040b340,
      0x265e5a51, 0xe9b6c7aa, 0xd62f105d, 0x02441453, 0xd8a1e681, 0xe7d3fbc8,
      0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed, 0xa9e3e905, 0xfcefa3f8,
      0x676f02d9, 0x8d2a4c8a, 0xfffa3942, 0x8771f681, 0x6d9d6122, 0xfde5380c,
      0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70, 0x289b7ec6, 0xeaa127fa,
      0xd4ef3085, 0x04881d05, 0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665,
      0xf4292244, 0x432aff97, 0xab9423a7, 0xfc93a039, 0x655b59c3, 0x8f0ccc92,
      0xffeff47d, 0x85845dd1, 0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1,
      0xf7537e82, 0xbd3af235, 0x2ad7d2bb, 0xeb86d391};

  uint32_t state[4] = {0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476};

  // Pad the message: a single 0x80 byte, zeros up to 56 (mod 64) bytes, then
  // the message length in bits as a 64-bit little-endian integer.
  const size_t length = message.length();
  const size_t padded_size = ((length + 8) / 64 + 1) * 64;
  std::vector<uint8_t> buf(padded_size, 0);
  if (length > 0) {
    memcpy(buf.data(), message.data(), length);
  }
  buf[length] = 0x80;
  const uint64_t bits = static_cast<uint64_t>(length) * 8;
  for (size_t i = 0; i < 8; ++i) {
    buf[padded_size - 8 + i] = static_cast<uint8_t>(bits >> (8 * i));
  }

  // Process each 512-bit (64-byte) chunk
  for (size_t offset = 0; offset < padded_size; offset += 64) {
    // Decode the chunk into sixteen little-endian 32-bit words.
    uint32_t W[16];
    for (size_t j = 0; j < 16; ++j) {
      const uint8_t* p = &buf[offset + 4 * j];
      W[j] = static_cast<uint32_t>(p[0]) |
             (static_cast<uint32_t>(p[1]) << 8) |
             (static_cast<uint32_t>(p[2]) << 16) |
             (static_cast<uint32_t>(p[3]) << 24);
    }
    uint32_t A = state[0];
    uint32_t B = state[1];
    uint32_t C = state[2];
    uint32_t D = state[3];
    for (uint32_t i = 0; i < 64; i++) {
      uint32_t F, g;
      if (i < 16) {
        F = (B & C) | ((~B) & D);
        g = i;
      } else if (i < 32) {
        F = (D & B) | ((~D) & C);
        g = (5 * i + 1) % 16;
      } else if (i < 48) {
        F = B ^ C ^ D;
        g = (3 * i + 5) % 16;
      } else {
        F = C ^ (B | (~D));
        g = (7 * i) % 16;
      }
      const uint32_t sum = A + F + kPartsOfSines[i] + W[g];
      // Rotation amounts are always in [4, 23], so both shifts are defined.
      const uint32_t shift = kShiftAmounts[i];
      const uint32_t T = D;
      D = C;
      C = B;
      B = B + ((sum << shift) | (sum >> (32 - shift)));
      A = T;
    }
    state[0] += A;
    state[1] += B;
    state[2] += C;
    state[3] += D;
  }

  // Convert the digest to a hex string, emitting each state word in
  // little-endian byte order.
  static const char kHexDigits[] = "0123456789abcdef";
  std::string res;
  res.reserve(16 << 1);
  for (size_t i = 0; i < 4; ++i) {
    for (size_t j = 0; j < 4; ++j) {
      const uint8_t byte = static_cast<uint8_t>(state[i] >> (8 * j));
      res.push_back(kHexDigits[byte >> 4]);
      res.push_back(kHexDigits[byte & 0x0f]);
    }
  }
  return res;
}
} // namespace lite
} // namespace paddle
......@@ -60,6 +60,38 @@ static std::string to_string(const T& v) {
return ss.str();
}
// Returns the decimal text form of `index` (with a leading '-' for negative
// values), formatted through snprintf.
static std::string to_string(int index) {
  // 15 bytes hold the sign, the digits of a 32-bit int and the terminator.
  char digits[15];
  snprintf(digits, sizeof(digits), "%d", index);
  std::string text(digits);
  return text;
}
// Converts the text `v` to a value of type T. The primary template simply
// hands the string back (T defaults to std::string); the specializations below
// parse numbers with the std::sto* family and therefore throw
// std::invalid_argument / std::out_of_range on malformed or oversized input.
template <typename T = std::string>
static T parse_string(const std::string& v) {
  return v;
}

// 32-bit signed integer, parsed as base 10.
template <>
int32_t parse_string<int32_t>(const std::string& v) {
  return std::stoi(v, nullptr, 10);
}

// 64-bit signed integer, parsed as base 10.
template <>
int64_t parse_string<int64_t>(const std::string& v) {
  return std::stoll(v, nullptr, 10);
}

// Single-precision floating point.
template <>
float parse_string<float>(const std::string& v) {
  return std::stof(v, nullptr);
}

// Double-precision floating point.
template <>
double parse_string<double>(const std::string& v) {
  return std::stod(v, nullptr);
}
template <typename T>
std::string Join(const std::vector<T>& vec, const std::string& delim) {
if (vec.empty()) return "";
......@@ -84,19 +116,20 @@ static std::string Repr(const std::vector<std::string>& v) {
return "{" + Join(tmp, ",") + "}";
}
static std::vector<std::string> Split(const std::string& original,
const std::string& separator) {
std::vector<std::string> results;
template <class T = std::string>
static std::vector<T> Split(const std::string& original,
const std::string& separator) {
std::vector<T> results;
std::string::size_type pos1, pos2;
pos2 = original.find(separator);
pos1 = 0;
while (std::string::npos != pos2) {
results.push_back(original.substr(pos1, pos2 - pos1));
results.push_back(parse_string<T>(original.substr(pos1, pos2 - pos1)));
pos1 = pos2 + separator.size();
pos2 = original.find(separator, pos1);
}
if (pos1 != original.length()) {
results.push_back(original.substr(pos1));
results.push_back(parse_string<T>(original.substr(pos1)));
}
return results;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册