未验证 提交 ac897177 编写于 作者: Q Qi Li 提交者: GitHub

[NPU] enhance cache offline model, test=develop (#3805) (#3931)

* [NPU] enhance cache offline model, test=develop
上级 813f17ba
...@@ -117,3 +117,6 @@ metal/paddle-mobile-demo/paddle-mobile-demo/Resources/models ...@@ -117,3 +117,6 @@ metal/paddle-mobile-demo/paddle-mobile-demo/Resources/models
metal/MobileNetDemo/MobileNetDemo/Resources metal/MobileNetDemo/MobileNetDemo/Resources
build* build*
# hiai libs
ai_ddk_lib*
...@@ -35,7 +35,11 @@ endif() ...@@ -35,7 +35,11 @@ endif()
if(NOT DEFINED ANDROID_API_LEVEL) if(NOT DEFINED ANDROID_API_LEVEL)
set(ANDROID_API_LEVEL "23") set(ANDROID_API_LEVEL "23")
if(ARM_TARGET_ARCH_ABI STREQUAL "armv7") if(ARM_TARGET_ARCH_ABI STREQUAL "armv7")
set(ANDROID_API_LEVEL "22") if(LITE_WITH_NPU AND NOT LITE_ON_TINY_PUBLISH)
set(ANDROID_API_LEVEL "24") # HIAI DDK depends on android-24
else()
set(ANDROID_API_LEVEL "22")
endif()
endif() endif()
endif() endif()
......
...@@ -70,6 +70,10 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { ...@@ -70,6 +70,10 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
raw_predictor_.Build(config, places, passes); raw_predictor_.Build(config, places, passes);
mode_ = config.power_mode(); mode_ = config.power_mode();
threads_ = config.threads(); threads_ = config.threads();
#ifdef LITE_WITH_NPU
Context<TargetType::kNPU>::SetSubgraphModelCacheDir(
config.subgraph_model_cache_dir());
#endif
#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \ #if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \
!(defined LITE_ON_MODEL_OPTIMIZE_TOOL) && !defined(__APPLE__) !(defined LITE_ON_MODEL_OPTIMIZE_TOOL) && !defined(__APPLE__)
int num_threads = config.x86_math_library_num_threads(); int num_threads = config.x86_math_library_num_threads();
......
...@@ -36,6 +36,11 @@ void LightPredictorImpl::Init(const lite_api::MobileConfig& config) { ...@@ -36,6 +36,11 @@ void LightPredictorImpl::Init(const lite_api::MobileConfig& config) {
} }
mode_ = config.power_mode(); mode_ = config.power_mode();
threads_ = config.threads(); threads_ = config.threads();
#ifdef LITE_WITH_NPU
Context<TargetType::kNPU>::SetSubgraphModelCacheDir(
config.subgraph_model_cache_dir());
#endif
} }
std::unique_ptr<lite_api::Tensor> LightPredictorImpl::GetInput(int i) { std::unique_ptr<lite_api::Tensor> LightPredictorImpl::GetInput(int i) {
......
...@@ -118,6 +118,8 @@ class LITE_API ConfigBase { ...@@ -118,6 +118,8 @@ class LITE_API ConfigBase {
std::string model_dir_; std::string model_dir_;
int threads_{1}; int threads_{1};
PowerMode mode_{LITE_POWER_NO_BIND}; PowerMode mode_{LITE_POWER_NO_BIND};
// to save subgraph model for npu/xpu/...
std::string subgraph_model_cache_dir_{""};
public: public:
explicit ConfigBase(PowerMode mode = LITE_POWER_NO_BIND, int threads = 1); explicit ConfigBase(PowerMode mode = LITE_POWER_NO_BIND, int threads = 1);
...@@ -130,6 +132,13 @@ class LITE_API ConfigBase { ...@@ -130,6 +132,13 @@ class LITE_API ConfigBase {
// set Thread // set Thread
void set_threads(int threads); void set_threads(int threads);
int threads() const { return threads_; } int threads() const { return threads_; }
// set subgraph_model_dir
void set_subgraph_model_cache_dir(std::string subgraph_model_cache_dir) {
subgraph_model_cache_dir_ = subgraph_model_cache_dir;
}
const std::string& subgraph_model_cache_dir() const {
return subgraph_model_cache_dir_;
}
}; };
/// CxxConfig is the config for the Full feature predictor. /// CxxConfig is the config for the Full feature predictor.
......
...@@ -19,52 +19,122 @@ namespace paddle { ...@@ -19,52 +19,122 @@ namespace paddle {
namespace lite { namespace lite {
namespace npu { namespace npu {
std::shared_ptr<hiai::AiModelMngerClient> Device::Build( std::shared_ptr<hiai::AiModelMngerClient> Device::Load(
const std::string model_name, // NOLINT const std::string& model_name,
std::vector<ge::Operator>& input_nodes, // NOLINT std::vector<char>* model_buffer,
std::vector<ge::Operator>& output_nodes // NOLINT bool* model_comp) {
) {
VLOG(3) << "[NPU] Build model";
// Build the HiAI IR graph to the HiAI om model
ge::Graph ir_graph("graph");
ir_graph.SetInputs(input_nodes).SetOutputs(output_nodes);
ge::Model om_model("model", "model");
om_model.SetGraph(ir_graph);
domi::HiaiIrBuild ir_build;
domi::ModelBufferData om_model_buf;
if (!ir_build.CreateModelBuff(om_model, om_model_buf)) {
LOG(WARNING) << "[NPU] CreateModelBuff failed!";
return nullptr;
}
if (!ir_build.BuildIRModel(om_model, om_model_buf)) {
LOG(WARNING) << "[NPU] BuildIRModel failed!";
ir_build.ReleaseModelBuff(om_model_buf);
return nullptr;
}
// Create a HiAI model manager client to load the HiAI om model // Create a HiAI model manager client to load the HiAI om model
std::shared_ptr<hiai::AiModelMngerClient> model_client( auto model_client = std::make_shared<hiai::AiModelMngerClient>();
new hiai::AiModelMngerClient());
if (model_client->Init(nullptr) != hiai::AI_SUCCESS) { if (model_client->Init(nullptr) != hiai::AI_SUCCESS) {
LOG(WARNING) << "[NPU] AiModelMngerClient init failed)!"; LOG(WARNING) << "[NPU] Init hiai model client failed!";
ir_build.ReleaseModelBuff(om_model_buf);
return nullptr; return nullptr;
} }
// Check HiAI DDK version
const char* ddk_version = model_client->GetVersion();
if (ddk_version) {
LOG(INFO) << "[NPU] HiAI DDK version: " << ddk_version;
} else {
LOG(WARNING) << "[NPU] Unable to get HiAI DDK version!";
}
// Check model compatibility
auto model_desc = std::make_shared<hiai::AiModelDescription>( auto model_desc = std::make_shared<hiai::AiModelDescription>(
model_name, freq_level(), framework_type(), model_type(), device_type()); model_name, freq_level(), framework_type(), model_type(), device_type());
model_desc->SetModelBuffer(om_model_buf.data, om_model_buf.length); model_desc->SetModelBuffer(
std::vector<std::shared_ptr<hiai::AiModelDescription>> model_descs; reinterpret_cast<const void*>(model_buffer->data()),
model_descs.push_back(model_desc); model_buffer->size());
if (!*model_comp &&
model_client->CheckModelCompatibility(*model_desc, *model_comp) !=
hiai::AI_SUCCESS) {
*model_comp = false;
VLOG(3) << "[NPU] model is NOT compatiblitiable, setting model_comp to "
<< *model_comp;
} else {
*model_comp = true;
VLOG(3) << "[NPU] model is compatiblitiable, setting model_comp to "
<< *model_comp;
}
// Rebuild and write the data of the compatible model to the model buffer
if (!*model_comp) {
std::shared_ptr<hiai::AiModelBuilder> model_builder =
std::make_shared<hiai::AiModelBuilder>(model_client);
hiai::MemBuffer* org_model_buffer = model_builder->InputMemBufferCreate(
reinterpret_cast<void*>(model_buffer->data()), model_buffer->size());
if (org_model_buffer) {
std::vector<hiai::MemBuffer*> org_model_buffers;
org_model_buffers.push_back(org_model_buffer);
hiai::MemBuffer* new_model_buffer = model_builder->OutputMemBufferCreate(
framework_type(), org_model_buffers);
// VLOG(3) << "[NPU] new model buffer memeory size is " <<
// new_model_buffer->GetMemBufferSize();
if (new_model_buffer) {
uint32_t new_model_size = 0;
if (model_builder->BuildModel(org_model_buffers,
new_model_buffer,
new_model_size) == hiai::AI_SUCCESS) {
// need to change to new_model_size as GetMemBufferSize is not
// correct.
model_buffer->resize(new_model_size);
memcpy(reinterpret_cast<void*>(model_buffer->data()),
new_model_buffer->GetMemBufferData(),
new_model_size);
// Reset the model buffer
model_desc->SetModelBuffer(
reinterpret_cast<const void*>(model_buffer->data()),
model_buffer->size());
VLOG(3) << "[NPU] Rebuild the compatible model done.";
} else {
LOG(WARNING) << "[NPU] Rebuild the compatible model failed!";
}
model_builder->MemBufferDestroy(new_model_buffer);
} else {
LOG(WARNING) << "[NPU] OutputMemBufferCreate failed!";
}
model_builder->MemBufferDestroy(org_model_buffer);
} else {
LOG(WARNING) << "[NPU] InputMemBufferCreate failed!";
}
}
// Load the compatible model
std::vector<std::shared_ptr<hiai::AiModelDescription>> model_descs{
model_desc};
if (model_client->Load(model_descs) != hiai::AI_SUCCESS) { if (model_client->Load(model_descs) != hiai::AI_SUCCESS) {
LOG(WARNING) << "[NPU] AiModelMngerClient load model failed!"; LOG(WARNING) << "[NPU] AiModelMngerClient load model failed!";
ir_build.ReleaseModelBuff(om_model_buf);
return nullptr; return nullptr;
} }
ir_build.ReleaseModelBuff(om_model_buf); VLOG(3) << "[NPU] Load model done.";
VLOG(3) << "[NPU] Build done";
return model_client; return model_client;
} }
bool Device::Build(std::vector<ge::Operator>& input_nodes, // NOLINT
std::vector<ge::Operator>& output_nodes, // NOLINT
std::vector<char>* model_buffer) {
// Convert the HiAI IR graph to the HiAI om model
ge::Graph ir_graph("graph");
ir_graph.SetInputs(input_nodes).SetOutputs(output_nodes);
ge::Model om_model("model", "model");
om_model.SetGraph(ir_graph);
// Build the HiAI om model, serialize and output it to the om buffer
domi::HiaiIrBuild ir_build;
domi::ModelBufferData om_buffer;
if (!ir_build.CreateModelBuff(om_model, om_buffer)) {
LOG(WARNING) << "[NPU] CreateModelBuff failed!";
return false;
}
if (!ir_build.BuildIRModel(om_model, om_buffer)) {
LOG(WARNING) << "[NPU] BuildIRModel failed!";
ir_build.ReleaseModelBuff(om_buffer);
return false;
}
model_buffer->resize(om_buffer.length);
memcpy(reinterpret_cast<void*>(model_buffer->data()),
reinterpret_cast<void*>(om_buffer.data),
om_buffer.length);
ir_build.ReleaseModelBuff(om_buffer);
VLOG(3) << "[NPU] Build model done.";
return true;
}
} // namespace npu } // namespace npu
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -38,13 +38,18 @@ class Device { ...@@ -38,13 +38,18 @@ class Device {
int model_type() { return model_type_; } int model_type() { return model_type_; }
int device_type() { return device_type_; } int device_type() { return device_type_; }
// Load the HiAI om model from buffer, rebuild the model if it's incompatible
// with the current device, then create a HiAI model manager client(from HiAI
// Server) to run inference
std::shared_ptr<hiai::AiModelMngerClient> Load(
const std::string& model_name,
std::vector<char>* model_buffer,
bool* model_comp);
// Build the HiAI IR graph to om model, return HiAI model manager client to // Build the HiAI IR graph to om model, return HiAI model manager client to
// load om model and run inference. // load om model and run inference.
std::shared_ptr<hiai::AiModelMngerClient> Build( bool Build(std::vector<ge::Operator>& input_nodes, // NOLINT
const std::string model_name, // NOLINT std::vector<ge::Operator>& output_nodes, // NOLINT
std::vector<ge::Operator>& input_nodes, // NOLINT std::vector<char>* model_buffer);
std::vector<ge::Operator>& output_nodes // NOLINT
); // NOLINT
private: private:
int freq_level_{3}; int freq_level_{3};
......
...@@ -17,6 +17,10 @@ ...@@ -17,6 +17,10 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
#ifdef LITE_WITH_NPU
std::string Context<TargetType::kNPU>::subgraph_model_cache_dir_{""}; // NOLINT
#endif
#ifdef LITE_WITH_XPU #ifdef LITE_WITH_XPU
thread_local xdnn::Context* Context<TargetType::kXPU>::_tls_raw_ctx{nullptr}; thread_local xdnn::Context* Context<TargetType::kXPU>::_tls_raw_ctx{nullptr};
int Context<TargetType::kXPU>::_workspace_l3_size_per_thread{0}; int Context<TargetType::kXPU>::_workspace_l3_size_per_thread{0};
......
...@@ -85,6 +85,16 @@ class Context<TargetType::kNPU> { ...@@ -85,6 +85,16 @@ class Context<TargetType::kNPU> {
NPUContext& operator=(const NPUContext& ctx) {} NPUContext& operator=(const NPUContext& ctx) {}
std::string name() const { return "NPUContext"; } std::string name() const { return "NPUContext"; }
static void SetSubgraphModelCacheDir(std::string subgraph_model_cache_dir) {
subgraph_model_cache_dir_ = subgraph_model_cache_dir;
}
static std::string SubgraphModelCacheDir() {
return subgraph_model_cache_dir_;
}
private:
static std::string subgraph_model_cache_dir_;
}; };
#endif #endif
......
...@@ -426,73 +426,51 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, ...@@ -426,73 +426,51 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
subgraph_op_desc.SetAttr<int32_t>("sub_block", sub_block_idx); subgraph_op_desc.SetAttr<int32_t>("sub_block", sub_block_idx);
// Extract input and output nodes from the target subgraph // Extract input and output nodes from the target subgraph
std::unordered_set<Node *> input_var_nodes; std::unordered_set<Node *> idata_var_nodes;
std::unordered_set<Node *> weight_var_nodes; std::unordered_set<Node *> weight_var_nodes;
std::unordered_set<Node *> output_var_nodes; std::unordered_set<Node *> odata_var_nodes;
std::unordered_set<Node *> local_var_nodes; std::unordered_set<Node *> local_var_nodes;
std::unordered_set<Node *> unused_var_nodes; std::unordered_set<Node *> unused_var_nodes;
ExtractInputsOutputs(subgraph_nodes, ExtractInputsOutputs(subgraph_nodes,
&input_var_nodes, &idata_var_nodes,
&weight_var_nodes, &weight_var_nodes,
&output_var_nodes, &odata_var_nodes,
&local_var_nodes, &local_var_nodes,
&unused_var_nodes); &unused_var_nodes);
// A simplified model without the original weight/local/unused nodes on the
// subgraph ops will be saved only if 'SUBGRAPH_DISABLE_ONLINE_MODE' is set to
// true and Predictor->Run(...), Predictor->Save(...) is called.
std::unordered_set<Node *> input_var_nodes(idata_var_nodes.begin(),
idata_var_nodes.end());
std::unordered_set<Node *> output_var_nodes(odata_var_nodes.begin(),
odata_var_nodes.end());
if (!GetBoolFromEnv(SUBGRAPH_DISABLE_ONLINE_MODE)) {
input_var_nodes.insert(weight_var_nodes.begin(), weight_var_nodes.end());
output_var_nodes.insert(local_var_nodes.begin(), local_var_nodes.end());
output_var_nodes.insert(unused_var_nodes.begin(), unused_var_nodes.end());
}
// Set input and output name mapping which stores the real inputs and // Set input and output name mapping which stores the real inputs and
// outputs // outputs
std::vector<std::string> input_var_names; std::vector<std::string> idata_var_names;
std::vector<std::string> output_var_names; std::vector<std::string> odata_var_names;
for (auto &var_node : input_var_nodes) { for (auto &var_node : idata_var_nodes) {
input_var_names.push_back(var_node->AsArg().name); idata_var_names.push_back(var_node->AsArg().name);
} }
for (auto &var_node : output_var_nodes) { for (auto &var_node : odata_var_nodes) {
output_var_names.push_back(var_node->AsArg().name); odata_var_names.push_back(var_node->AsArg().name);
} }
subgraph_op_desc.SetAttr<std::vector<std::string>>("input_data_names", subgraph_op_desc.SetAttr<std::vector<std::string>>("input_data_names",
input_var_names); idata_var_names);
subgraph_op_desc.SetAttr<std::vector<std::string>>("output_data_names", subgraph_op_desc.SetAttr<std::vector<std::string>>("output_data_names",
output_var_names); odata_var_names);
// Set input/output scale values of input/output var nodes for
// type_precision_cast_pass.
std::vector<float> input_data_scales;
std::vector<float> output_data_scales;
for (auto &var_node : input_var_nodes) {
auto any_op_node = var_node->outlinks.front();
CHECK(any_op_node->IsStmt());
auto &any_inst = any_op_node->AsStmt();
if (any_inst.op_info()->HasAttr("input_scale")) {
input_data_scales.push_back(
any_inst.op_info()->GetAttr<float>("input_scale"));
}
}
for (auto &var_node : output_var_nodes) {
auto any_op_node = var_node->inlinks.front();
CHECK(any_op_node->IsStmt());
auto &any_inst = any_op_node->AsStmt();
if (any_inst.op_info()->HasAttr("output_scale")) {
output_data_scales.push_back(
any_inst.op_info()->GetAttr<float>("output_scale"));
}
}
if (input_data_scales.size() > 0) {
subgraph_op_desc.SetAttr<std::vector<float>>("input_data_scales",
input_data_scales);
}
if (output_data_scales.size() > 0) {
subgraph_op_desc.SetAttr<std::vector<float>>("output_data_scales",
output_data_scales);
}
// Set all of the inputs and outputs to the target subgraph op // Set all of the inputs and outputs to the target subgraph op
// To prevent vars are removed in RuntimeProgram::UpdateVarsOfProgram() // To prevent vars are removed in RuntimeProgram::UpdateVarsOfProgram()
for (auto &var_node : weight_var_nodes) { std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names;
for (auto &var_node : input_var_nodes) {
input_var_names.push_back(var_node->AsArg().name); input_var_names.push_back(var_node->AsArg().name);
} }
for (auto &var_node : local_var_nodes) { for (auto &var_node : output_var_nodes) {
output_var_names.push_back(var_node->AsArg().name);
}
for (auto &var_node : unused_var_nodes) {
output_var_names.push_back(var_node->AsArg().name); output_var_names.push_back(var_node->AsArg().name);
} }
subgraph_op_desc.SetInput("Inputs", input_var_names); subgraph_op_desc.SetInput("Inputs", input_var_names);
...@@ -509,26 +487,13 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, ...@@ -509,26 +487,13 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
for (auto &var_node : input_var_nodes) { for (auto &var_node : input_var_nodes) {
IR_NODE_LINK_TO(var_node, subgraph_op_node); IR_NODE_LINK_TO(var_node, subgraph_op_node);
} }
for (auto &var_node : weight_var_nodes) {
IR_NODE_LINK_TO(var_node, subgraph_op_node);
}
for (auto &var_node : output_var_nodes) { for (auto &var_node : output_var_nodes) {
IR_OP_VAR_LINK(subgraph_op_node, var_node); IR_OP_VAR_LINK(subgraph_op_node, var_node);
} }
for (auto &var_node : local_var_nodes) {
IR_OP_VAR_LINK(subgraph_op_node, var_node);
}
for (auto &var_node : unused_var_nodes) {
IR_OP_VAR_LINK(subgraph_op_node, var_node);
}
// Remove subgraph nodes and unused var nodes // Remove subgraph nodes and unused var nodes
auto nodes2rm = GetNodes2RM(subgraph_nodes, auto nodes2rm =
{input_var_nodes, GetNodes2RM(subgraph_nodes, {input_var_nodes, output_var_nodes});
weight_var_nodes,
output_var_nodes,
local_var_nodes,
unused_var_nodes});
GraphSafeRemoveNodes(graph, nodes2rm); GraphSafeRemoveNodes(graph, nodes2rm);
} }
...@@ -603,7 +568,17 @@ std::unordered_set<const Node *> GetNodes2RM( ...@@ -603,7 +568,17 @@ std::unordered_set<const Node *> GetNodes2RM(
std::unordered_set<const Node *> nodes2rm(op_nodes.begin(), op_nodes.end()); std::unordered_set<const Node *> nodes2rm(op_nodes.begin(), op_nodes.end());
for (auto &op_node : op_nodes) { for (auto &op_node : op_nodes) {
for (auto &var_node : op_node->inlinks) { for (auto &var_node : op_node->inlinks) {
if (!nodes2rm.count(var_node)) { bool skip = false;
// skip the var node which is used by any other ops that doesn't belong to
// the subgraph ops.
for (auto &out_op_node : var_node->outlinks) {
if (std::find(op_nodes.begin(), op_nodes.end(), out_op_node) !=
op_nodes.end()) {
skip = true;
break;
}
}
if (!skip && !nodes2rm.count(var_node)) {
nodes2rm.insert(var_node); nodes2rm.insert(var_node);
} }
} }
......
...@@ -25,6 +25,7 @@ DEFINE_string(optimized_model_dir, "", "path of optimized naive buffer model"); ...@@ -25,6 +25,7 @@ DEFINE_string(optimized_model_dir, "", "path of optimized naive buffer model");
DEFINE_string(input_tensor_shape, "1,3,224,224", "shape of input tensors"); DEFINE_string(input_tensor_shape, "1,3,224,224", "shape of input tensors");
DEFINE_string(input_tensor_type, "float32", "data type of input tensors"); DEFINE_string(input_tensor_type, "float32", "data type of input tensors");
DEFINE_string(output_tensor_type, "float32", "data type of output tensors"); DEFINE_string(output_tensor_type, "float32", "data type of output tensors");
DEFINE_string(subgraph_model_cache_dir, "", "dir of subgraph model cache");
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -132,6 +133,7 @@ std::shared_ptr<lite_api::PaddlePredictor> TestModel( ...@@ -132,6 +133,7 @@ std::shared_ptr<lite_api::PaddlePredictor> TestModel(
mobile_config.set_model_from_file(optimized_model_dir + ".nb"); mobile_config.set_model_from_file(optimized_model_dir + ".nb");
mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH); mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH);
mobile_config.set_threads(1); mobile_config.set_threads(1);
mobile_config.set_subgraph_model_cache_dir(FLAGS_subgraph_model_cache_dir);
predictor = lite_api::CreatePaddlePredictor(mobile_config); predictor = lite_api::CreatePaddlePredictor(mobile_config);
FillInputTensors(predictor, input_tensor_shape, input_tensor_type, 1); FillInputTensors(predictor, input_tensor_shape, input_tensor_type, 1);
// Run optimized model // Run optimized model
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "lite/kernels/npu/bridges/engine.h" #include "lite/kernels/npu/bridges/engine.h"
#include <sys/time.h> #include <sys/time.h>
#include <time.h> #include <time.h>
#include <algorithm>
#include <utility> #include <utility>
#include "lite/kernels/npu/bridges/registry.h" #include "lite/kernels/npu/bridges/registry.h"
...@@ -22,11 +23,50 @@ namespace paddle { ...@@ -22,11 +23,50 @@ namespace paddle {
namespace lite { namespace lite {
namespace subgraph { namespace subgraph {
int Engine::BuildDeviceProgram() { return FAILED; } Engine::Engine(KernelContext *ctx,
int block_idx,
cpp::BlockDesc *block_desc,
const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names,
lite::Scope *scope)
: ctx_(ctx), block_idx_(block_idx), block_desc_(block_desc), scope_(scope) {
input_names_ = input_names;
output_names_ = output_names;
// Sort the name of input and output tensors, it's convenient for us to get
// the info of input and output tensors in the same order from the device
// program, because the result of subgraph division may be different but right
// at each call of the subgraph pass.
std::stable_sort(input_names_.begin(), input_names_.end());
std::stable_sort(output_names_.begin(), output_names_.end());
}
bool Engine::Run() {
if (is_first_epoch_) {
PrepareWorkspaceForDeviceProgram();
is_first_epoch_ = false;
}
if (InputShapeChanged()) {
BuildDeviceProgram();
}
return LaunchDeviceProgram();
}
int Engine::LaunchDeviceProgram() { return 0; } bool Engine::PrepareWorkspaceForOriginProgram() {
origin_idims_.resize(input_names_.size());
origin_itensors_.resize(input_names_.size());
for (int i = 0; i < input_names_.size(); i++) {
origin_itensors_[i] = scope_->FindMutableTensor(input_names_[i]);
CHECK(origin_itensors_[i]);
}
origin_otensors_.resize(output_names_.size());
for (int i = 0; i < output_names_.size(); i++) {
origin_otensors_[i] = scope_->FindMutableTensor(output_names_[i]);
CHECK(origin_otensors_[i]);
}
return true;
}
int Engine::BuildOriginProgram() { bool Engine::BuildOriginProgram() {
// TODO(hong19860320) The block_desc need to be divided into subgraphs during // TODO(hong19860320) The block_desc need to be divided into subgraphs during
// the exection time. But only see them as a subgraph now. // the exection time. But only see them as a subgraph now.
origin_program_.clear(); origin_program_.clear();
...@@ -34,11 +74,14 @@ int Engine::BuildOriginProgram() { ...@@ -34,11 +74,14 @@ int Engine::BuildOriginProgram() {
auto op_desc = block_desc_->GetOp<cpp::OpDesc>(op_idx); auto op_desc = block_desc_->GetOp<cpp::OpDesc>(op_idx);
CHECK(op_desc); CHECK(op_desc);
std::string op_type = op_desc->Type(); std::string op_type = op_desc->Type();
// Create op and pick up the best kernel
auto op = LiteOpRegistry::Global().Create(op_desc->Type()); auto op = LiteOpRegistry::Global().Create(op_desc->Type());
CHECK(op) << "no Op found for " << op_type;
op->Attach(*op_desc, scope_); op->Attach(*op_desc, scope_);
std::unique_ptr<KernelBase> picked_kernel; std::unique_ptr<KernelBase> picked_kernel;
if (op_desc->HasAttr(kKernelTypeAttr)) { if (op_desc->HasAttr(kKernelTypeAttr)) {
// Create op and pick up kernel according to the kKernelTypeAttr attribute // Create op and pick up the best kernel according to the
// kKernelTypeAttr attribute
auto kernel_type = op_desc->GetAttr<std::string>(kKernelTypeAttr); auto kernel_type = op_desc->GetAttr<std::string>(kKernelTypeAttr);
std::string alias; std::string alias;
Place place; Place place;
...@@ -48,12 +91,14 @@ int Engine::BuildOriginProgram() { ...@@ -48,12 +91,14 @@ int Engine::BuildOriginProgram() {
auto kernels = op->CreateKernels({place}); auto kernels = op->CreateKernels({place});
CHECK_GT(kernels.size(), 0u) << "No kernels found for " << op_type; CHECK_GT(kernels.size(), 0u) << "No kernels found for " << op_type;
auto it = std::find_if( auto it = std::find_if(
kernels.begin(), kernels.end(), [&](std::unique_ptr<KernelBase>& it) { kernels.begin(), kernels.end(), [&](std::unique_ptr<KernelBase> &it) {
return it->alias() == alias; return it->alias() == alias;
}); });
CHECK(it != kernels.end()); CHECK(it != kernels.end());
picked_kernel = std::move(*it); picked_kernel = std::move(*it);
} else { } else {
// TODO(hong19860320) add kernel picking according to the type of input
// and output tensors
VLOG(3) << "The attr '" << kKernelTypeAttr VLOG(3) << "The attr '" << kKernelTypeAttr
<< "' not found, pick the first kernel for " << op_type; << "' not found, pick the first kernel for " << op_type;
std::vector<std::unique_ptr<KernelBase>> kernels; std::vector<std::unique_ptr<KernelBase>> kernels;
...@@ -74,49 +119,41 @@ int Engine::BuildOriginProgram() { ...@@ -74,49 +119,41 @@ int Engine::BuildOriginProgram() {
} }
origin_program_.emplace_back(std::move(op), std::move(picked_kernel)); origin_program_.emplace_back(std::move(op), std::move(picked_kernel));
} }
return 0; CHECK(!origin_program_.empty()) << "no instructions";
return true;
} }
int Engine::LaunchOriginProgram() { bool Engine::LaunchOriginProgram() {
for (auto& inst : origin_program_) { if (origin_program_.empty()) {
auto op_type = inst.op()->op_info()->Type(); BuildOriginProgram();
if (op_type == "feed" || op_type == "fetch") continue; }
inst.Run(); if (!origin_program_.empty()) {
for (auto &inst : origin_program_) {
auto op_type = inst.op()->op_info()->Type();
if (op_type == "feed" || op_type == "fetch") continue;
inst.Run();
}
return true;
} }
return 0; return false;
} }
int Engine::Build() { bool Engine::PrepareWorkspaceForDeviceProgram() {
// In order to attach all of the ops of the block desc, we need to build the return PrepareWorkspaceForOriginProgram();
// original program firstly.
BuildOriginProgram();
// Run InferShape() of all of ops, and convert Paddle ops to NPU/XPU IR graph
build_device_program_status_ = BuildDeviceProgram();
return build_device_program_status_;
} }
bool Engine::BuildDeviceProgram() { return BuildOriginProgram(); }
bool Engine::LaunchDeviceProgram() { return LaunchOriginProgram(); }
bool Engine::InputShapeChanged() { bool Engine::InputShapeChanged() {
bool changed = false;
for (size_t i = 0; i < origin_itensors_.size(); i++) { for (size_t i = 0; i < origin_itensors_.size(); i++) {
if (origin_itensors_[i]->dims() != origin_idims_[i]) { auto origin_idim = origin_itensors_[i]->dims().Vectorize();
return true; changed |= origin_idim != origin_idims_[i];
} origin_idims_[i] = origin_idim;
}
return false;
}
int Engine::Launch() {
// Rebuild device program when the shapes of input tensors have been changed.
if (CHECK_SUCCESS(build_device_program_status_) &&
CHECK_REBUILD_WHEN_SHAPE_CHANGED(build_device_program_status_) &&
InputShapeChanged()) {
Build();
}
if (CHECK_FAILED(build_device_program_status_)) {
LaunchOriginProgram();
} else {
LaunchDeviceProgram();
} }
return 0; return changed;
} }
} // namespace subgraph } // namespace subgraph
......
...@@ -33,42 +33,33 @@ class Engine { ...@@ -33,42 +33,33 @@ class Engine {
cpp::BlockDesc *block_desc, cpp::BlockDesc *block_desc,
const std::vector<std::string> &input_names, const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names, const std::vector<std::string> &output_names,
lite::Scope *scope) lite::Scope *scope);
: ctx_(ctx),
block_idx_(block_idx),
block_desc_(block_desc),
input_names_(input_names),
output_names_(output_names),
scope_(scope) {}
virtual ~Engine() = default; virtual ~Engine() = default;
virtual int Build(); virtual bool Run();
virtual int Launch();
private: private:
Engine(const Engine &) = delete; Engine(const Engine &) = delete;
protected: protected:
virtual int BuildDeviceProgram(); virtual bool PrepareWorkspaceForOriginProgram();
virtual int LaunchDeviceProgram(); virtual bool BuildOriginProgram();
virtual bool LaunchOriginProgram();
virtual int BuildOriginProgram(); virtual bool PrepareWorkspaceForDeviceProgram();
virtual int LaunchOriginProgram(); virtual bool BuildDeviceProgram();
virtual bool LaunchDeviceProgram();
virtual bool InputShapeChanged(); virtual bool InputShapeChanged();
KernelContext *ctx_{nullptr}; KernelContext *ctx_{nullptr};
int block_idx_; int block_idx_{-1};
cpp::BlockDesc *block_desc_; cpp::BlockDesc *block_desc_{nullptr};
std::vector<std::string> input_names_; std::vector<std::string> input_names_;
std::vector<std::string> output_names_; std::vector<std::string> output_names_;
Scope *scope_{nullptr}; Scope *scope_{nullptr};
// SUCCESS: device program build successed. FAILED: device program build bool is_first_epoch_{true};
// failed. REBUILD_WHEN_SHAPE_CHANGED: device program build successed but need std::vector<std::vector<int64_t>> origin_idims_;
// to rebuild when input shape changed.
int build_device_program_status_{0};
std::vector<DDim> origin_idims_;
std::vector<DDim> origin_odims_;
std::vector<Tensor *> origin_itensors_; std::vector<Tensor *> origin_itensors_;
std::vector<Tensor *> origin_otensors_; std::vector<Tensor *> origin_otensors_;
std::vector<Instruction> origin_program_; std::vector<Instruction> origin_program_;
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "graph/op/all_ops.h" #include "graph/compatible/all_ops.h"
#include "lite/core/op_lite.h" #include "lite/core/op_lite.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
......
...@@ -94,10 +94,10 @@ int MatMulConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -94,10 +94,10 @@ int MatMulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} else { } else {
matmul_node = graph->Add<ge::op::BatchMatMul>(out_name); matmul_node = graph->Add<ge::op::BatchMatMul>(out_name);
auto matmul_op = matmul_node->data<ge::op::BatchMatMul>(); auto matmul_op = matmul_node->data<ge::op::BatchMatMul>();
matmul_op->set_input_x(*x_node->data()); matmul_op->set_input_x1(*x_node->data());
matmul_op->set_input_y(*y_node->data()); matmul_op->set_input_x2(*y_node->data());
matmul_op->set_attr_adj_x(transpose_x); matmul_op->set_attr_adj_x1(transpose_x);
matmul_op->set_attr_adj_y(transpose_y); matmul_op->set_attr_adj_x2(transpose_y);
} }
if (fabs(alpha - 1.f) > 1e-6f) { if (fabs(alpha - 1.f) > 1e-6f) {
......
...@@ -20,11 +20,11 @@ ...@@ -20,11 +20,11 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "graph/buffer.h" #include "graph/buffer.h"
#include "graph/compatible/operator_reg.h"
#include "graph/graph.h" #include "graph/graph.h"
#include "graph/model.h" #include "graph/model.h"
#include "graph/op/all_ops.h" #include "graph/op/all_ops.h"
#include "graph/operator.h" #include "graph/operator.h"
#include "graph/operator_reg.h"
#include "lite/core/op_lite.h" #include "lite/core/op_lite.h"
#include "lite/utils/macros.h" #include "lite/utils/macros.h"
...@@ -97,25 +97,26 @@ REG_OP(Pad) ...@@ -97,25 +97,26 @@ REG_OP(Pad)
/* /*
* Multiplies slices of two tensors in batches. * Multiplies slices of two tensors in batches.
* <Input> * <Input>
* x : The input tensor * x1 : The input tensor
* y : The input tensor * x2 : The input tensor
* <Output> * <Output>
* z : The output tensor * y : The output tensor
* <Attr> * <Attr>
* adj_x : adj_x is true, the input tensor x is transposed, otherwise * adj_x1 : adj_x1 is true, the input tensor x1 is transposed,
* it will not be transposed. Default is false (The current version only * otherwise it will not be transposed.
* supports false). * Default is false (The current version only supports false).
* adj_y : adj_y is true, the input tensor y is transposed, otherwise * adj_x2 : adj_x2 is true, the input tensor x2 is transposed,
* it will not be transposed. Default is false. * otherwise it will not be transposed.
* Default is false.
* <Added in HiAI version> * <Added in HiAI version>
* 100.320.010.010 * 100.320.010.010
*/ */
REG_OP(BatchMatMul) REG_OP(BatchMatMul)
.INPUT(x, TensorType({DT_FLOAT})) .INPUT(x1, TensorType({DT_FLOAT}))
.INPUT(y, TensorType({DT_FLOAT})) .INPUT(x2, TensorType({DT_FLOAT}))
.OUTPUT(z, TensorType({DT_FLOAT})) .OUTPUT(y, TensorType({DT_FLOAT}))
.ATTR(adj_x, AttrValue::BOOL{false}) .ATTR(adj_x1, AttrValue::BOOL{false})
.ATTR(adj_y, AttrValue::BOOL{false}) .ATTR(adj_x2, AttrValue::BOOL{false})
.OP_END() .OP_END()
} // namespace ge } // namespace ge
......
...@@ -28,40 +28,65 @@ namespace lite { ...@@ -28,40 +28,65 @@ namespace lite {
namespace kernels { namespace kernels {
namespace npu { namespace npu {
class DeviceProgram {
public:
DeviceProgram() {}
~DeviceProgram() {}
std::string GenerateModelName(
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::vector<int64_t>>& origin_idims);
bool LoadFromCacheFile(const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::vector<int64_t>>& origin_idims,
const std::string& model_cache_dir);
bool BuildGraphAndCacheToFile(
const std::vector<Instruction>& origin_program,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::vector<int64_t>>& origin_idims,
const std::vector<Tensor*>& origin_otensors,
const std::string& model_cache_dir);
bool ShareBufferWithOriginTensors(
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
std::vector<Tensor*>* origin_itensors,
std::vector<Tensor*>* origin_otensors,
std::vector<std::shared_ptr<hiai::AiTensor>>* device_itensors,
std::vector<std::shared_ptr<hiai::AiTensor>>* device_otensors);
bool ZeroCopyRun(
std::vector<std::shared_ptr<hiai::AiTensor>>* device_itensors,
std::vector<std::shared_ptr<hiai::AiTensor>>* device_otensors);
public:
std::string model_name_{""};
std::shared_ptr<hiai::AiModelMngerClient> model_client_{nullptr};
std::vector<std::vector<int64_t>> origin_odims_;
std::vector<PrecisionType> origin_otypes_;
std::vector<hiai::TensorDimension> device_idims_{};
std::vector<hiai::TensorDimension> device_odims_{};
};
class SubgraphEngine : public subgraph::Engine { class SubgraphEngine : public subgraph::Engine {
public: public:
SubgraphEngine(KernelContext *ctx, SubgraphEngine(KernelContext* ctx,
int block_idx, int block_idx,
cpp::BlockDesc *block_desc, cpp::BlockDesc* block_desc,
const std::vector<std::string> &input_names, const std::vector<std::string>& input_names,
const std::vector<std::string> &output_names, const std::vector<std::string>& output_names,
Scope *scope) Scope* scope)
: subgraph::Engine( : subgraph::Engine(
ctx, block_idx, block_desc, input_names, output_names, scope) {} ctx, block_idx, block_desc, input_names, output_names, scope) {}
struct device_program_t {
explicit device_program_t(std::shared_ptr<hiai::AiModelMngerClient> _client)
: client(_client) {}
std::shared_ptr<hiai::AiModelMngerClient> client{nullptr};
std::vector<DDim> origin_idims{};
std::vector<DDim> origin_odims{};
std::vector<hiai::TensorDimension> device_idims{};
std::vector<hiai::TensorDimension> device_odims{};
};
protected: protected:
int BuildDeviceProgram() override; bool PrepareWorkspaceForDeviceProgram() override;
int LaunchDeviceProgram() override; bool BuildDeviceProgram() override;
bool InputShapeChanged() override; bool LaunchDeviceProgram() override;
std::string model_name_{"model.om"};
std::vector<std::vector<int64_t>> inputs_shape_{};
std::map<std::vector<std::vector<int64_t>>, std::shared_ptr<device_program_t>>
device_program_map_{};
std::vector<std::string> device_inames_{};
std::vector<std::string> device_onames_{};
std::vector<std::shared_ptr<hiai::AiTensor>> device_itensors_{}; std::vector<std::shared_ptr<hiai::AiTensor>> device_itensors_{};
std::vector<std::shared_ptr<hiai::AiTensor>> device_otensors_{}; std::vector<std::shared_ptr<hiai::AiTensor>> device_otensors_{};
std::map<std::vector<std::vector<int64_t>>, std::shared_ptr<DeviceProgram>>
device_programs_;
}; };
class SubgraphCompute : public KernelLite<TARGET(kNPU), PRECISION(kAny)> { class SubgraphCompute : public KernelLite<TARGET(kNPU), PRECISION(kAny)> {
......
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#define SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE \ #define SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE \
"SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE" "SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE"
#define SUBGRAPH_DISABLE_ONLINE_MODE "SUBGRAPH_DISABLE_ONLINE_MODE"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
......
...@@ -119,5 +119,40 @@ static std::vector<std::string> ListDir(const std::string& path, ...@@ -119,5 +119,40 @@ static std::vector<std::string> ListDir(const std::string& path,
return paths; return paths;
} }
static bool ReadFile(const std::string& filename, std::vector<char>* contents) {
FILE* fp = fopen(filename.c_str(), "rb");
if (!fp) return false;
fseek(fp, 0, SEEK_END);
size_t size = ftell(fp);
fseek(fp, 0, SEEK_SET);
contents->clear();
contents->resize(size);
size_t offset = 0;
char* ptr = reinterpret_cast<char*>(&(contents->at(0)));
while (offset < size) {
size_t already_read = fread(ptr, 1, size - offset, fp);
offset += already_read;
ptr += already_read;
}
fclose(fp);
return true;
}
static bool WriteFile(const std::string& filename,
const std::vector<char>& contents) {
FILE* fp = fopen(filename.c_str(), "wb");
if (!fp) return false;
size_t size = contents.size();
size_t offset = 0;
const char* ptr = reinterpret_cast<const char*>(&(contents.at(0)));
while (offset < size) {
size_t already_written = fwrite(ptr, 1, size - offset, fp);
offset += already_written;
ptr += already_written;
}
fclose(fp);
return true;
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace paddle {
namespace lite {
std::string MD5(std::string message) {
const uint32_t shiftAmounts[] = {
7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22,
5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20,
4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23,
6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21};
const uint32_t partsOfSines[] = {
0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee, 0xf57c0faf, 0x4787c62a,
0xa8304613, 0xfd469501, 0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be,
0x6b901122, 0xfd987193, 0xa679438e, 0x49b40821, 0xf61e2562, 0xc040b340,
0x265e5a51, 0xe9b6c7aa, 0xd62f105d, 0x02441453, 0xd8a1e681, 0xe7d3fbc8,
0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed, 0xa9e3e905, 0xfcefa3f8,
0x676f02d9, 0x8d2a4c8a, 0xfffa3942, 0x8771f681, 0x6d9d6122, 0xfde5380c,
0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70, 0x289b7ec6, 0xeaa127fa,
0xd4ef3085, 0x04881d05, 0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665,
0xf4292244, 0x432aff97, 0xab9423a7, 0xfc93a039, 0x655b59c3, 0x8f0ccc92,
0xffeff47d, 0x85845dd1, 0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1,
0xf7537e82, 0xbd3af235, 0x2ad7d2bb, 0xeb86d391};
uint32_t state[4];
state[0] = 0x67452301;
state[1] = 0xefcdab89;
state[2] = 0x98badcfe;
state[3] = 0x10325476;
// Pad with zeros
int size = ((((message.length() + 8) / 64) + 1) * 64) - 8;
uint8_t *buf = reinterpret_cast<uint8_t *>(calloc(size + 64, 1));
memcpy(buf, message.c_str(), message.length());
buf[message.length()] = 128;
uint32_t bits = 8 * message.length();
memcpy(buf + size, &bits, 4);
// Process at each 512-bit(64 bytes) chunk
#define LEFTROTATE(x, c) (((x) << (c)) | ((x) >> (32 - (c))))
for (int offset = 0; offset < size; offset += 64) {
uint32_t A = state[0];
uint32_t B = state[1];
uint32_t C = state[2];
uint32_t D = state[3];
uint32_t *W = reinterpret_cast<uint32_t *>(buf + offset);
for (uint32_t i = 0; i < 64; i++) {
uint32_t F, g;
if (i < 16) {
F = (B & C) | ((~B) & D);
g = i;
} else if (i < 32) {
F = (D & B) | ((~D) & C);
g = (5 * i + 1) % 16;
} else if (i < 48) {
F = B ^ C ^ D;
g = (3 * i + 5) % 16;
} else {
F = C ^ (B | (~D));
g = (7 * i) % 16;
}
uint32_t T = D;
D = C;
C = B;
B = B + LEFTROTATE((A + F + partsOfSines[i] + W[g]), shiftAmounts[i]);
A = T;
}
state[0] += A;
state[1] += B;
state[2] += C;
state[3] += D;
}
#undef LEFTROTATE
free(buf);
// Convert digest to string
std::string res;
res.reserve(16 << 1);
const uint8_t *digest = reinterpret_cast<uint8_t *>(state);
char hex[3];
for (size_t i = 0; i < 16; i++) {
snprintf(hex, sizeof(hex), "%02x", digest[i]);
res.append(hex);
}
return res;
}
} // namespace lite
} // namespace paddle
...@@ -60,6 +60,38 @@ static std::string to_string(const T& v) { ...@@ -60,6 +60,38 @@ static std::string to_string(const T& v) {
return ss.str(); return ss.str();
} }
static std::string to_string(int index) {
const int BUFFER_LENGTH = 15;
char buffer[BUFFER_LENGTH];
snprintf(buffer, sizeof(buffer), "%d", index);
return std::string(buffer);
}
template <typename T = std::string>
static T parse_string(const std::string& v) {
return v;
}
template <>
int32_t parse_string<int32_t>(const std::string& v) {
return std::stoi(v);
}
template <>
int64_t parse_string<int64_t>(const std::string& v) {
return std::stoll(v);
}
template <>
float parse_string<float>(const std::string& v) {
return std::stof(v);
}
template <>
double parse_string<double>(const std::string& v) {
return std::stod(v);
}
template <typename T> template <typename T>
std::string Join(const std::vector<T>& vec, const std::string& delim) { std::string Join(const std::vector<T>& vec, const std::string& delim) {
if (vec.empty()) return ""; if (vec.empty()) return "";
...@@ -84,19 +116,20 @@ static std::string Repr(const std::vector<std::string>& v) { ...@@ -84,19 +116,20 @@ static std::string Repr(const std::vector<std::string>& v) {
return "{" + Join(tmp, ",") + "}"; return "{" + Join(tmp, ",") + "}";
} }
static std::vector<std::string> Split(const std::string& original, template <class T = std::string>
const std::string& separator) { static std::vector<T> Split(const std::string& original,
std::vector<std::string> results; const std::string& separator) {
std::vector<T> results;
std::string::size_type pos1, pos2; std::string::size_type pos1, pos2;
pos2 = original.find(separator); pos2 = original.find(separator);
pos1 = 0; pos1 = 0;
while (std::string::npos != pos2) { while (std::string::npos != pos2) {
results.push_back(original.substr(pos1, pos2 - pos1)); results.push_back(parse_string<T>(original.substr(pos1, pos2 - pos1)));
pos1 = pos2 + separator.size(); pos1 = pos2 + separator.size();
pos2 = original.find(separator, pos1); pos2 = original.find(separator, pos1);
} }
if (pos1 != original.length()) { if (pos1 != original.length()) {
results.push_back(original.substr(pos1)); results.push_back(parse_string<T>(original.substr(pos1)));
} }
return results; return results;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册