Unverified commit 7a731b7f authored by hong19860320, committed by GitHub

[NPU] Fix and refine the supporting of multi NPU models (#2037)

* [NPU] Fix the bug of loading multi NPU models
test=develop

* [NPU] Use lite tensor to store NPU model, fix the management of multi NPU models, support loading NPU model from memory and reduce the modification of framework
test=develop

* [NPU] Remove redundant header files for NPU bridges,
test=develop

* [NPU] fix NPU deps
test=develop

* [NPU] refine the compiling script for NPU
test=develop

* [NPU] remove redundant subdirectory in lite/CMakeLists.txt
test=develop

* [NPU] Fix and refine NPU test case
test=develop

* [NPU] Revert the modification of other non-NPU modules
test=develop

* [NPU] Remove NPU bridges if target is tiny publish
test=develop
Parent 421c6305
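The core of this change is that the compiled HiAI om model no longer lives in a separate .om file managed by a per-model client registry; it is compiled into a persistable lite tensor by bridge::BuildModel and loaded back with npu::LoadModel when the graph kernel is prepared. Below is a minimal sketch of how the two new helpers fit together; the function name BuildAndLoadExample and the variable name "graph0_weights" are illustrative only, and the sketch assumes the HiAI DDK headers and the helpers added in this commit are available.

#include <memory>
#include <string>
#include <vector>
#include "lite/backends/npu/bridge/utils.h"  // declares paddle::lite::npu::bridge::BuildModel
#include "lite/backends/npu/runtime.h"       // declares paddle::lite::npu::LoadModel
#include "lite/core/scope.h"
#include "lite/core/tensor.h"

// Compile a HiAI IR graph, keep the om model bytes in a persistable weight
// tensor, then create the hiai model manager client straight from that tensor.
void BuildAndLoadExample(std::vector<ge::Operator>& graph_inputs,   // NOLINT
                         std::vector<ge::Operator>& graph_outputs,  // NOLINT
                         paddle::lite::Scope* scope) {
  using paddle::lite::Tensor;
  // The om model data is stored in a persistable tensor so that the model
  // parser saves it into the param files like any other weight.
  auto* weight = scope->Var("graph0_weights")->GetMutable<Tensor>();
  weight->set_persistable(true);
  weight->set_precision(PRECISION(kInt8));
  CHECK(paddle::lite::npu::bridge::BuildModel(graph_inputs, graph_outputs, weight));
  // At runtime the kernel loads the model directly from the tensor; no .om
  // file has to be written to or copied from disk.
  std::shared_ptr<hiai::AiModelMngerClient> model_client;
  std::string model_name;
  CHECK(paddle::lite::npu::LoadModel(*weight, &model_client, &model_name));
}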
......@@ -75,7 +75,7 @@ lite_cc_library(light_api SRCS light_api.cc
CUDA_DEPS ${cuda_kernels}
X86_DEPS ${x86_kernels}
ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels} ${npu_bridges} npu_pass
NPU_DEPS ${npu_kernels}
CL_DEPS ${opencl_kenrels}
FPGA_DEPS ${fpga_kenrels})
......
......@@ -18,9 +18,6 @@
#include <utility>
#include <vector>
#include "lite/utils/io.h"
#ifdef LITE_WITH_NPU
#include "lite/backends/npu/npu_helper.h"
#endif
namespace paddle {
namespace lite {
......@@ -42,16 +39,6 @@ void Predictor::SaveModel(const std::string &dir,
default:
LOG(FATAL) << "Unknown model type";
}
#ifdef LITE_WITH_NPU
for (auto name : npu::DeviceInfo::Global().AllClientNames()) {
// the npu offline model is saved in current dir
// so just copy to dst dir
CHECK_EQ(
system(string_format("cp -r %s %s", name.c_str(), dir.c_str()).c_str()),
0)
<< "Failed copy NPU model to " << dir;
}
#endif
}
lite::Tensor *Predictor::GetInput(size_t offset) {
......
......@@ -2,5 +2,8 @@ if(NOT LITE_WITH_NPU)
return()
endif()
lite_cc_library(npu_helper SRCS npu_helper.cc DEPS ${npu_ddk_libs})
add_subdirectory(bridge)
lite_cc_library(npu_runtime SRCS runtime.cc DEPS npu_ddk_hiai)
if(NOT LITE_ON_TINY_PUBLISH)
add_subdirectory(bridge)
endif()
lite_cc_library(npu_bridge_registry SRCS registry.cc DEPS ${npu_ddk_libs})
lite_cc_library(npu_bridge_utils SRCS utils.cc DEPS ${npu_ddk_libs} tensor op mir_node scope)
lite_cc_library(npu_bridge_utils SRCS utils.cc DEPS ${npu_ddk_libs} npu_runtime tensor op scope)
set(npu_bridge_deps npu_bridge_registry npu_bridge_utils op)
......@@ -12,7 +11,7 @@ lite_cc_library(npu_bridge_scale_op SRCS scale_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_softmax_op SRCS softmax_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_pool_op SRCS pool_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_batch_norm_op SRCS batch_norm_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_elementwise_op SRCS elementwise_ops.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_elementwise_ops SRCS elementwise_ops.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_reshape_op SRCS reshape_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_conv_transpose_op SRCS conv_transpose_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_interpolate_op SRCS interpolate_op.cc DEPS ${npu_bridge_deps})
......@@ -33,7 +32,7 @@ set(npu_bridges
npu_bridge_softmax_op
npu_bridge_pool_op
npu_bridge_batch_norm_op
npu_bridge_elementwise_op
npu_bridge_elementwise_ops
npu_bridge_reshape_op
npu_bridge_conv_transpose_op
npu_bridge_interpolate_op
......@@ -44,24 +43,24 @@ set(npu_bridges
npu_bridge_pad2d_op
CACHE INTERNAL "npu_bridges")
lite_cc_library(npu_test_helper SRCS test_helper.cc DEPS npu_helper ${npu_ddk_libs} ${npu_bridges} ${npu_kernels} ${ops})
set(npu_bridge_test_deps ${npu_ddk_libs} ${npu_bridges} ${npu_kernels} ${ops})
lite_cc_test(test_npu_bridge_fc_op SRCS fc_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_conv_op SRCS conv_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_mul_op SRCS mul_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_act_op SRCS act_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_scale_op SRCS scale_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_softmax_op SRCS softmax_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_pool_op SRCS pool_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_batch_norm_op SRCS batch_norm_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_elementwise_op SRCS elementwise_ops_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_reshape_op SRCS reshape_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_conv_transpose_op SRCS conv_transpose_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_interpolate_op SRCS interpolate_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_transpose_op SRCS transpose_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_split_op SRCS split_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_concat_op SRCS concat_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_shuffle_channel_op SRCS shuffle_channel_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_pad2d_op SRCS pad2d_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_fc_op SRCS fc_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_conv_op SRCS conv_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_mul_op SRCS mul_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_act_op SRCS act_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_scale_op SRCS scale_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_softmax_op SRCS softmax_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_pool_op SRCS pool_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_batch_norm_op SRCS batch_norm_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_elementwise_ops SRCS elementwise_ops_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_reshape_op SRCS reshape_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_conv_transpose_op SRCS conv_transpose_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_interpolate_op SRCS interpolate_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_transpose_op SRCS transpose_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_split_op SRCS split_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_concat_op SRCS concat_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_shuffle_channel_op SRCS shuffle_channel_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_pad2d_op SRCS pad2d_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
message(STATUS "+++++ npu_bridges: ${npu_bridges}")
......@@ -20,7 +20,6 @@
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/bridge/registry.h"
#include "lite/backends/npu/bridge/utils.h"
#include "lite/operators/relu_op.h"
namespace paddle {
namespace lite {
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/batch_norm_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/concat_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
......@@ -21,7 +20,6 @@
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/bridge/registry.h"
#include "lite/backends/npu/bridge/utils.h"
#include "lite/backends/npu/npu_helper.h"
namespace paddle {
namespace lite {
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/conv_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/conv_transpose_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/elementwise_ops.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/fc_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
......
......@@ -67,7 +67,7 @@ void fc_ref(const std::shared_ptr<operators::FcOpLite> op) {
}
}
void test_fc(const std::vector<int64_t>& x_shape,
void test_fc(const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& w_shape,
int in_num_col_dims,
bool has_bias) {
......@@ -78,30 +78,25 @@ void test_fc(const std::vector<int64_t>& x_shape,
CHECK(bridges.HasType("fc"));
Scope scope;
std::string x_var_name("Input");
std::string input_var_name("Input");
std::string w_var_name("W");
std::string bias_var_name("Bias");
std::string out_var_name("Out");
std::string out_ref_var_name("out_ref");
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* input = scope.Var(input_var_name)->GetMutable<Tensor>();
auto* w = scope.Var(w_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize(x_shape);
input->Resize({bs, ic, ih, iw});
// get w shape
auto in_mat_dims = input->dims().Flatten2D(in_num_col_dims);
std::vector<int64_t> w_shape = {in_mat_dims[1], out_num_classes};
input->Resize(input_shape);
w->Resize(w_shape);
FillTensor<float, int>(x);
FillTensor<float, int>(input);
FillTensor<float, int>(w);
// create fc op
cpp::OpDesc fc_op_desc;
fc_op_desc.SetType("fc");
fc_op_desc.SetInput("Input", {x_var_name});
fc_op_desc.SetInput("Input", {input_var_name});
fc_op_desc.SetInput("W", {w_var_name});
fc_op_desc.SetOutput("Out", {out_var_name});
fc_op_desc.SetAttr("in_num_col_dims", static_cast<int>(in_num_col_dims));
......@@ -113,7 +108,7 @@ void test_fc(const std::vector<int64_t>& x_shape,
}
auto fc_op = CreateOp<operators::FcOpLite>(fc_op_desc, &scope);
LauchOp(fc_op, {x_var_name}, {out_var_name});
LauchOp(fc_op, {input_var_name}, {out_var_name});
out_ref->CopyDataFrom(*out);
// compare results
......@@ -123,10 +118,6 @@ void test_fc(const std::vector<int64_t>& x_shape,
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
// model release
npu::OpList::Global().clear();
npu::DeviceInfo::Global().Clear();
}
TEST(NPUBridges, fc) {
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/mul_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
......@@ -21,7 +20,6 @@
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/bridge/registry.h"
#include "lite/backends/npu/bridge/utils.h"
#include "lite/backends/npu/npu_helper.h"
namespace paddle {
namespace lite {
......
......@@ -69,15 +69,6 @@ void test_mul(const std::vector<int64_t>& x_shape,
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize(x_shape);
// get y shape
auto x_mat_dims = x->dims().Flatten2D(x_num_col_dims);
std::vector<int64_t> y_shape;
for (int i = 0; i < y_num_col_dims - 1; i++) {
y_shape.push_back(1);
}
y_shape.push_back(x_mat_dims[1]);
y_shape.push_back(o);
y->Resize(y_shape);
FillTensor<float, int>(x);
......@@ -104,10 +95,6 @@ void test_mul(const std::vector<int64_t>& x_shape,
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
// model release
npu::OpList::Global().clear();
npu::DeviceInfo::Global().Clear();
}
TEST(NPUBridges, mul) {
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/pool_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/scale_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/shuffle_channel_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/softmax_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/split_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
......@@ -21,7 +20,6 @@
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/bridge/registry.h"
#include "lite/backends/npu/bridge/utils.h"
#include "lite/backends/npu/npu_helper.h"
namespace paddle {
namespace lite {
......
......@@ -43,7 +43,7 @@ void LauchOp(const std::shared_ptr<lite::OpLite> op,
ge::Shape(input->dims().Vectorize()), ge::FORMAT_NCHW, ge::DT_FLOAT);
auto input_node = std::make_shared<ge::op::Data>(input_var_name);
input_node->update_input_desc_x(input_desc);
npu::OpList::Global().add(input_node);
OpList::Global().add(input_node);
inputs_map[input_var_name] = input_node;
}
auto outputs_map = supported_lists.at(op_type)(op, inputs_map);
......@@ -58,15 +58,20 @@ void LauchOp(const std::shared_ptr<lite::OpLite> op,
for (auto output_var_name : output_var_names) {
graph_outputs.push_back(*outputs_map[output_var_name]);
}
std::string model_name(UniqueName("test_" + op_type) + ".om");
CHECK(npu::BuildNPUClient(graph_inputs, graph_outputs, model_name));
std::string weight_var_name = "weight";
auto weight = scope->Var(weight_var_name)->GetMutable<Tensor>();
weight->set_persistable(true);
weight->set_precision(PRECISION(kInt8));
CHECK(BuildModel(graph_inputs, graph_outputs, weight));
CHECK_GT(weight->numel(), 0);
CHECK_NE(weight->data<uint8_t>(), 0);
// create graph op and set inputs and outputs
cpp::OpDesc graph_op_desc;
graph_op_desc.SetType("graph_op");
graph_op_desc.SetInput("Inputs", input_var_names);
graph_op_desc.SetInput("Weight", {weight_var_name});
graph_op_desc.SetOutput("Outputs", output_var_names);
graph_op_desc.SetAttr("model_name", model_name);
auto graph_op =
std::make_shared<operators::GraphOpLite>(graph_op_desc.Type());
......@@ -88,8 +93,7 @@ void LauchOp(const std::shared_ptr<lite::OpLite> op,
graph_kernel->Launch();
// release all resources of the generated model
npu::OpList::Global().clear();
npu::DeviceInfo::Global().Clear();
OpList::Global().clear();
}
} // namespace bridge
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/transpose_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
......
......@@ -13,19 +13,52 @@
// limitations under the License.
#include "lite/backends/npu/bridge/utils.h"
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <utility>
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h" // for ge::op::Data
#include "ai_ddk_lib/include/graph/tensor.h" // for ge::TensorUtils
#include "lite/core/op_lite.h"
#include "ai_ddk_lib/include/hiai_ir_build.h"
#include "lite/backends/npu/runtime.h"
namespace paddle {
namespace lite {
namespace npu {
namespace bridge {
// Build the HiAI IR graph into an om model, and store the om model data in a lite tensor
bool BuildModel(std::vector<ge::Operator>& inputs, // NOLINT
std::vector<ge::Operator>& outputs, // NOLINT
lite::Tensor* model_data) {
LOG(INFO) << "[NPU] Build model.";
CHECK_GT(inputs.size(), 0);
CHECK_GT(outputs.size(), 0);
CHECK_NE(model_data, 0);
// build IR graph to om model
ge::Graph ir_graph("graph");
ir_graph.SetInputs(inputs).SetOutputs(outputs);
ge::Model om_model("model", "model");
om_model.SetGraph(ir_graph);
domi::HiaiIrBuild ir_build;
domi::ModelBufferData om_model_buf;
if (!ir_build.CreateModelBuff(om_model, om_model_buf)) {
LOG(WARNING) << "[NPU] CreateModelBuff failed!";
return false;
}
if (!ir_build.BuildIRModel(om_model, om_model_buf)) {
LOG(WARNING) << "[NPU] BuildIRModel failed!";
return false;
}
// store om model into tensor
model_data->Resize({om_model_buf.length});
memcpy(model_data->mutable_data<int8_t>(),
om_model_buf.data,
om_model_buf.length);
ir_build.ReleaseModelBuff(om_model_buf);
return true;
}
std::string UniqueName(const std::string& prefix) {
static std::mutex counter_mtx;
static std::unordered_map<std::string, int> counter_map;
......
......@@ -19,7 +19,6 @@
#include <unordered_map>
#include <vector>
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/core/mir/node.h"
#include "lite/core/op_lite.h"
#include "lite/core/target_wrapper.h"
#include "lite/core/tensor.h"
......@@ -29,6 +28,24 @@ namespace lite {
namespace npu {
namespace bridge {
class OpList {
public:
static OpList& Global() {
static thread_local OpList x;
return x;
}
void clear() { lists_.clear(); }
void add(std::shared_ptr<ge::Operator> p) { lists_.push_back(p); }
private:
std::vector<std::shared_ptr<ge::Operator>> lists_;
};
// Build the HiAI IR graph into an om model, and store the om model data in a lite tensor
bool BuildModel(std::vector<ge::Operator>& inputs, // NOLINT
std::vector<ge::Operator>& outputs, // NOLINT
lite::Tensor* model_data);
std::string UniqueName(const std::string& prefix);
ge::DataType PrecisionConverter(PrecisionType itype);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/npu/npu_helper.h"
#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ai_ddk_lib/include/HiAiModelManagerService.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/hiai_ir_build.h"
namespace paddle {
namespace lite {
namespace npu {
bool SaveNPUModel(const void* om_model_data,
const size_t om_model_size,
const std::string& om_file_path) {
std::FILE* fp;
fp = std::fopen(om_file_path.c_str(), "wb");
if (fp == NULL) {
LOG(WARNING) << "[NPU] " << om_file_path << " open failed!";
return false;
}
size_t write_size = std::fwrite(om_model_data, 1, om_model_size, fp);
if (write_size != om_model_size) {
std::fclose(fp);
LOG(WARNING) << "[NPU] Write NPU model failed: " << om_file_path;
return false;
}
std::fclose(fp);
return true;
}
bool BuildNPUClient(const void* om_model_data,
const size_t om_model_size,
const std::string& name) {
std::unique_ptr<hiai::AiModelMngerClient> client(
new hiai::AiModelMngerClient);
int ret = client->Init(nullptr);
if (ret != hiai::AI_SUCCESS) {
LOG(WARNING) << "[NPU] Failed building NPU client " << name
<< ", ret: " << ret;
throw std::runtime_error("");
return false;
}
auto desc = std::make_shared<hiai::AiModelDescription>(
name,
DeviceInfo::Global().freq_level(),
DeviceInfo::Global().framework_type(),
DeviceInfo::Global().model_type(),
DeviceInfo::Global().device_type());
desc->SetModelBuffer(om_model_data, om_model_size);
std::vector<std::shared_ptr<hiai::AiModelDescription>> model_desc;
model_desc.push_back(desc);
if (client->Load(model_desc) != hiai::AI_SUCCESS) {
LOG(WARNING) << "[NPU] Model Load Failed: " << desc->GetName();
throw std::runtime_error("");
return false;
}
DeviceInfo::Global().Insert(name, std::move(client));
return true;
}
// If built from inputs and outputs, the npu offline model will be saved
bool BuildNPUClient(std::vector<ge::Operator>& inputs, // NOLINT
std::vector<ge::Operator>& outputs, // NOLINT
const std::string& name) {
LOG(INFO) << "[NPU] Building Client";
ge::Graph npu_subgraph("npu_subgraph" + name);
npu_subgraph.SetInputs(inputs).SetOutputs(outputs);
ge::Model npu_model("model", "npu_model" + name);
npu_model.SetGraph(npu_subgraph);
// compile IR graph and output om model to memory
domi::HiaiIrBuild ir_build;
domi::ModelBufferData om_model_buffer;
if (!ir_build.CreateModelBuff(npu_model, om_model_buffer)) {
LOG(WARNING) << "[NPU] Failed CreateModelBuff: " << npu_model.GetName();
return false;
}
if (!ir_build.BuildIRModel(npu_model, om_model_buffer)) {
LOG(WARNING) << "[NPU] Failed BuildIRModel: " << npu_model.GetName();
return false;
}
if (BuildNPUClient(om_model_buffer.data, om_model_buffer.length, name)) {
// save npu offline model
if (!SaveNPUModel(om_model_buffer.data, om_model_buffer.length, name)) {
LOG(WARNING) << "[NPU] Save model " << name << " failed.";
}
ir_build.ReleaseModelBuff(om_model_buffer);
return true;
}
return false;
}
// If built from a file path, the npu offline model will not be saved
bool BuildNPUClient(const std::string& om_model_file_path,
const std::string& name) {
// load om model from file
std::ifstream file(om_model_file_path, std::ios::binary);
CHECK(file.is_open()) << "[NPU] Unable to open om model file: "
<< om_model_file_path;
const auto fbegin = file.tellg();
file.seekg(0, std::ios::end);
const auto fend = file.tellg();
size_t om_model_size = fend - fbegin;
VLOG(5) << "[NPU] om model file size: " << om_model_size;
file.seekg(0, std::ios::beg);
std::vector<char> om_model_data(om_model_size);
file.read(om_model_data.data(), om_model_size);
return BuildNPUClient(
reinterpret_cast<void*>(om_model_data.data()), om_model_size, name);
}
} // namespace npu
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/npu/runtime.h"
#include <string>
#include <vector>
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace npu {
// Create a hiai model manager client to load the om model from a lite tensor,
// and return the client and a unique model name
bool LoadModel(const lite::Tensor &model_data,
std::shared_ptr<hiai::AiModelMngerClient> *model_client,
std::string *model_name) {
LOG(INFO) << "[NPU] Load model.";
auto model_data_ptr = model_data.data<int8_t>();
auto model_data_size = model_data.numel() * sizeof(int8_t);
if (model_data_ptr == nullptr || model_data_size == 0) {
return false;
}
*model_client = std::make_shared<hiai::AiModelMngerClient>();
int ret = (*model_client)->Init(nullptr);
if (ret != hiai::AI_SUCCESS) {
LOG(WARNING) << "[NPU] AiModelMngerClient init failed(" << ret << ")!";
return false;
}
*model_name = "model.om";
auto model_desc = std::make_shared<hiai::AiModelDescription>(
*model_name,
DeviceInfo::Global().freq_level(),
DeviceInfo::Global().framework_type(),
DeviceInfo::Global().model_type(),
DeviceInfo::Global().device_type());
model_desc->SetModelBuffer(model_data_ptr, model_data_size);
std::vector<std::shared_ptr<hiai::AiModelDescription>> model_descs;
model_descs.push_back(model_desc);
if ((*model_client)->Load(model_descs) != hiai::AI_SUCCESS) {
LOG(WARNING) << "[NPU] AiModelMngerClient load model failed!";
return false;
}
return true;
}
} // namespace npu
} // namespace lite
} // namespace paddle
......@@ -15,13 +15,8 @@
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "ai_ddk_lib/include/HiAiModelManagerService.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/utils/cp_logging.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
......@@ -29,36 +24,11 @@ namespace npu {
class DeviceInfo {
public:
static DeviceInfo& Global() {
static DeviceInfo &Global() {
static DeviceInfo x;
return x;
}
DeviceInfo() {}
void Insert(const std::string& name,
std::unique_ptr<hiai::AiModelMngerClient> client) {
if (clients_.find(name) != clients_.end()) {
LOG(WARNING) << "[NPU] Already insert " << name;
return;
}
clients_.emplace(std::make_pair(name, std::move(client)));
}
void Clear() { clients_.clear(); }
hiai::AiModelMngerClient* client(const std::string& model_name) const {
if (clients_.find(model_name) != clients_.end()) {
return clients_.at(model_name).get();
} else {
return nullptr;
}
}
std::vector<std::string> AllClientNames() {
std::vector<std::string> names;
for (auto& i : clients_) {
names.push_back(i.first);
}
return names;
}
int freq_level() { return freq_level_; }
int framework_type() { return framework_type_; }
......@@ -70,41 +40,11 @@ class DeviceInfo {
int framework_type_{0};
int model_type_{0};
int device_type_{0};
// TODO(TJ): find better place
std::unordered_map<std::string, std::unique_ptr<hiai::AiModelMngerClient>>
clients_;
};
class OpList {
public:
static OpList& Global() {
static thread_local OpList x;
return x;
}
void clear() { lists_.clear(); }
void add(std::shared_ptr<ge::Operator> p) { lists_.push_back(p); }
private:
std::vector<std::shared_ptr<ge::Operator>> lists_;
};
bool SaveNPUModel(const void* om_model_data,
const size_t om_model_size,
const std::string& om_file_path);
// If built from inputs and outputs, the npu offline model will be saved
bool BuildNPUClient(std::vector<ge::Operator>& inputs, // NOLINT
std::vector<ge::Operator>& outputs, // NOLINT
const std::string& name);
// If built from a file path, the npu offline model will not be saved
bool BuildNPUClient(const std::string& om_model_file_path,
const std::string& name);
bool BuildNPUClient(const void* om_model_data,
const size_t om_model_size,
const std::string& name);
bool LoadModel(const lite::Tensor &model_data,
std::shared_ptr<hiai::AiModelMngerClient> *model_client,
std::string *model_name);
} // namespace npu
} // namespace lite
} // namespace paddle
......@@ -33,7 +33,7 @@ lite_cc_library(scope SRCS scope.cc DEPS tensor)
lite_cc_library(device_info SRCS device_info.cc DEPS tensor)
if (LITE_WITH_ARM)
lite_cc_library(context SRCS context.cc DEPS tensor any device_info CL_DEPS cl_context gflags NPU_DEPS ${npu_ddk_libs})
lite_cc_library(context SRCS context.cc DEPS tensor any device_info CL_DEPS cl_context gflags NPU_DEPS npu_runtime)
else()
lite_cc_library(context SRCS context.cc DEPS tensor any device_info eigen3 CL_DEPS cl_context gflags)
endif()
......
......@@ -26,7 +26,7 @@
#include "lite/backends/opencl/cl_runtime.h"
#endif
#ifdef LITE_WITH_NPU
#include "lite/backends/npu/npu_helper.h"
#include "lite/backends/npu/runtime.h"
#endif
#include <map>
......@@ -81,9 +81,6 @@ class Context<TargetType::kNPU> {
NPUContext& operator=(const NPUContext& ctx) {}
std::string name() const { return "NPUContext"; }
hiai::AiModelMngerClient* client(const std::string& model_name) const {
return npu::DeviceInfo::Global().client(model_name);
}
};
#endif
......
......@@ -16,10 +16,10 @@ set(subgraph_passes subgraph_pass)
if(LITE_WITH_NPU)
lite_cc_library(npu_pass SRCS generate_npu_program_pass.cc
DEPS mir_pass types context ${mir_fusers} ${npu_bridges} npu_helper ${npu_ddk_libs} graph_op subgraph_pass)
DEPS mir_pass types context ${mir_fusers} ${npu_bridges} ${npu_ddk_libs} graph_op subgraph_pass)
list(APPEND subgraph_passes npu_pass)
lite_cc_test(test_npu_pass SRCS generate_npu_program_pass_test.cc
DEPS npu_pass cxx_api mir_passes gflags
DEPS npu_pass mir_passes paddle_api_full paddle_api_light gflags
ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1
--optimized_model=${LITE_MODEL_DIR}/lite_npu_model_opt SERIAL)
if (WITH_TESTING)
......
......@@ -30,7 +30,6 @@
#include "lite/backends/npu/bridge/paddle_use_npu_bridges.h"
#include "lite/backends/npu/bridge/registry.h"
#include "lite/backends/npu/bridge/utils.h"
#include "lite/backends/npu/npu_helper.h"
namespace paddle {
namespace lite {
......@@ -125,13 +124,20 @@ std::string GenerateNPUProgramPass::BuildNPUGraph(
outputs.push_back(*converted_vars.at(argname));
}
std::string model_name("hiai_npu_client_" + std::to_string(sub_id) + ".om");
if (!npu::BuildNPUClient(inputs, outputs, model_name)) {
std::string weight_var_name = "graph" + std::to_string(sub_id) + "_weights";
auto any_op = (*op_nodes.begin())->AsStmt().op();
auto weight = any_op->scope()->Var(weight_var_name)->GetMutable<Tensor>();
weight->set_persistable(true);
weight->set_precision(PRECISION(kInt8));
// Compile the IR graph into an NPU model and store the model data in a weight
// tensor with persistable=true, so that the model parser can recognize it and
// save it to the param files
if (!lite::npu::bridge::BuildModel(inputs, outputs, weight)) {
LOG(WARNING) << "Build NPU failed subgraph " << sub_id;
throw std::runtime_error("Build NPU failed subgraph.");
}
LOG(INFO) << "[NPU] Build NPU Client success subgraph " << sub_id;
return model_name;
return weight_var_name;
}
void GenerateNPUProgramPass::GenNPUSubgraph(
......@@ -145,12 +151,12 @@ void GenerateNPUProgramPass::GenNPUSubgraph(
FindInputOutputVars(
op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars);
auto model_name =
auto weight_var_name =
BuildNPUGraph(op_nodes, in_data_vars, out_data_vars, sub_id);
auto any_op = (*op_nodes.begin())->AsStmt().op();
InsertNewNode(graph,
model_name,
weight_var_name,
any_op->scope(),
any_op->valid_places(),
in_data_vars,
......
......@@ -21,7 +21,7 @@
#include <unordered_set>
#include <vector>
#include "lite/backends/npu/bridge/registry.h"
#include "lite/backends/npu/npu_helper.h"
#include "lite/backends/npu/bridge/utils.h"
#include "lite/core/mir/pass.h"
#include "lite/core/mir/subgraph/subgraph_program_pass.h"
......
......@@ -12,102 +12,164 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/subgraph/subgraph_program_pass.h"
#include "lite/core/op_registry.h"
#include "lite/core/program.h"
#include "lite/core/tensor.h"
#include "lite/api/cxx_api.h"
#include <cmath>
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/utils/cp_logging.h"
#include "lite/model_parser/pb/program_desc.h"
DEFINE_string(optimized_model, "", "optimized_model");
DEFINE_int32(batch_size, 1, "batch size");
DEFINE_int32(im_channel, 3, "im_channel");
DEFINE_string(model_file, "", "model file path of combined protobuf model");
DEFINE_string(params_file, "", "params file path of combined protobuf model");
DEFINE_string(optimized_model_dir, "", "path of optimized naive buffer model");
DEFINE_string(input_tensor_shape, "1,3,224,224", "shapes of input tensors");
DEFINE_int32(output_tensor_num, 1, "number of output tensors");
namespace paddle {
namespace lite {
void TestModel(lite::Predictor* predictor,
const std::vector<Place>& valid_places,
const std::string& model_dir) {
predictor->Build(model_dir,
model_dir + "/model",
model_dir + "/params",
Place{TARGET(kARM), PRECISION(kFloat)},
valid_places);
auto* input_tensor = predictor->GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>(
{FLAGS_batch_size, FLAGS_im_channel, FLAGS_im_height, FLAGS_im_width})));
auto* data = input_tensor->mutable_data<float>();
auto item_size = input_tensor->dims().production();
for (int i = 0; i < item_size; i++) {
data[i] = 1;
std::vector<std::vector<int64_t>> ParseShape(std::string txt) {
std::vector<std::vector<int64_t>> shape;
while (!txt.empty()) {
size_t idx = txt.find_first_of(":");
std::string dims = txt.substr(0, idx);
std::vector<int64_t> s;
while (!dims.empty()) {
size_t idx = dims.find_first_of(",");
int d = atoi(dims.substr(0, idx).c_str());
VLOG(3) << d;
s.push_back(d);
if (idx == std::string::npos) {
break;
} else {
dims = dims.substr(idx + 1);
}
}
shape.push_back(s);
if (idx == std::string::npos) {
break;
} else {
txt = txt.substr(idx + 1);
}
}
return shape;
}
predictor->Run();
if (model_dir != FLAGS_optimized_model &&
std::find(valid_places.begin(),
valid_places.end(),
Place{TARGET(kNPU), PRECISION(kFloat)}) != valid_places.end()) {
predictor->SaveModel(FLAGS_optimized_model);
int64_t ShapeProduction(std::vector<int64_t> shape) {
int64_t s = 1;
for (int64_t dim : shape) {
s *= dim;
}
return s;
}
void CompareOutData(const lite::Predictor& tgt, const lite::Predictor& ref) {
auto* tgt_otensor = tgt.GetOutput(0);
auto* ref_otensor = ref.GetOutput(0);
const auto* tgt_pdata = tgt_otensor->data<float>();
const auto* ref_pdata = ref_otensor->data<float>();
EXPECT_EQ(tgt_otensor->dims().production(), ref_otensor->dims().production());
for (size_t i = 0; i < tgt_otensor->dims().production(); ++i) {
auto diff = std::fabs(tgt_pdata[i] - ref_pdata[i]) /
(std::fabs(ref_pdata[i]) + 1e-6);
VLOG(3) << diff;
EXPECT_LT(diff, 0.1);
void FillInputTensor(
const std::shared_ptr<lite_api::PaddlePredictor>& predictor,
const std::vector<std::vector<int64_t>>& input_tensor_shape,
const float value) {
for (int i = 0; i < input_tensor_shape.size(); i++) {
auto input_tensor = predictor->GetInput(i);
input_tensor->Resize(input_tensor_shape[i]);
auto input_tensor_data = input_tensor->mutable_data<float>();
auto input_tensor_size = ShapeProduction(input_tensor->shape());
for (int j = 0; j < input_tensor_size; j++) {
input_tensor_data[j] = value;
}
}
}
TEST(NPUSubgraph, compare) {
DeviceInfo::Init();
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, 1);
lite::Predictor predictor_arm, predictor_npu, predictor_npu_savedmodel;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)}});
TestModel(&predictor_arm, valid_places, FLAGS_model_dir);
valid_places.push_back(Place{TARGET(kNPU), PRECISION(kFloat)});
TestModel(&predictor_npu, valid_places, FLAGS_model_dir);
CompareOutData(predictor_npu, predictor_arm);
LOG(INFO) << " ================ NPU speed ================== ";
for (int i = 0; i < FLAGS_repeats; ++i) {
auto start = GetCurrentUS();
predictor_npu.Run();
LOG(INFO) << i << ", " << GetCurrentUS() - start << "us";
void CompareOutputTensor(
const std::shared_ptr<lite_api::PaddlePredictor>& tar_predictor,
const std::shared_ptr<lite_api::PaddlePredictor>& ref_predictor,
const int output_tensor_num) {
for (int i = 0; i < output_tensor_num; i++) {
auto tar_output_tensor = tar_predictor->GetOutput(i);
auto ref_output_tensor = ref_predictor->GetOutput(i);
auto tar_output_tensor_data = tar_output_tensor->data<float>();
auto ref_output_tensor_data = ref_output_tensor->data<float>();
auto tar_output_tensor_size = ShapeProduction(tar_output_tensor->shape());
auto ref_output_tensor_size = ShapeProduction(ref_output_tensor->shape());
EXPECT_EQ(tar_output_tensor_size, ref_output_tensor_size);
for (size_t j = 0; j < ref_output_tensor_size; j++) {
auto diff =
std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]) /
(std::fabs(ref_output_tensor_data[j]) + 1e-6);
VLOG(3) << diff;
EXPECT_LT(diff, 0.1);
}
}
}
LOG(INFO) << " =================== ARM CPU speed =================== ";
for (int i = 0; i < FLAGS_repeats; ++i) {
std::shared_ptr<lite_api::PaddlePredictor> TestModel(
const std::string& model_dir,
const std::string& model_file,
const std::string& params_file,
const lite_api::Place& preferred_place,
const std::vector<lite_api::Place>& valid_places,
const std::vector<std::vector<int64_t>>& input_tensor_shape,
const std::string& optimized_model_dir) {
// generate optimized model
lite_api::CxxConfig cxx_config;
cxx_config.set_model_dir(model_dir);
cxx_config.set_model_file(model_file);
cxx_config.set_param_file(params_file);
cxx_config.set_preferred_place(preferred_place);
cxx_config.set_valid_places(valid_places);
auto predictor = lite_api::CreatePaddlePredictor(cxx_config);
FillInputTensor(predictor, input_tensor_shape, 1);
predictor->SaveOptimizedModel(optimized_model_dir,
lite_api::LiteModelType::kNaiveBuffer);
// load optimized model
lite_api::MobileConfig mobile_config;
mobile_config.set_model_dir(optimized_model_dir);
mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH);
mobile_config.set_threads(1);
predictor = lite_api::CreatePaddlePredictor(mobile_config);
FillInputTensor(predictor, input_tensor_shape, 1);
// run optimized model
for (int i = 0; i < FLAGS_warmup; i++) {
predictor->Run();
}
for (int i = 0; i < FLAGS_repeats; i++) {
auto start = GetCurrentUS();
predictor_arm.Run();
predictor->Run();
LOG(INFO) << i << ", " << GetCurrentUS() - start << "us";
}
return predictor;
}
TestModel(&predictor_npu_savedmodel, valid_places, FLAGS_optimized_model);
CompareOutData(predictor_npu_savedmodel, predictor_arm);
TEST(NPUSubgraph, compare) {
// parse the input tensor shapes; supported formats: "1,3,224,224" or
// "1,3,224,224:1,80"
std::vector<std::vector<int64_t>> input_tensor_shape =
ParseShape(FLAGS_input_tensor_shape);
// generate and run optimized CPU model
LOG(INFO) << " ================ CPU ================== ";
auto cpu_predictor =
TestModel(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_params_file,
lite_api::Place{TARGET(kARM), PRECISION(kFloat)},
{lite_api::Place{TARGET(kHost), PRECISION(kFloat)},
lite_api::Place{TARGET(kARM), PRECISION(kFloat)}},
input_tensor_shape,
FLAGS_optimized_model_dir + "/CPU");
// generate and run optimized NPU model
LOG(INFO) << " ================ NPU ================== ";
auto npu_predictor =
TestModel(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_params_file,
lite_api::Place{TARGET(kARM), PRECISION(kFloat)},
{lite_api::Place{TARGET(kHost), PRECISION(kFloat)},
lite_api::Place{TARGET(kARM), PRECISION(kFloat)},
lite_api::Place{TARGET(kNPU), PRECISION(kFloat)}},
input_tensor_shape,
FLAGS_optimized_model_dir + "/NPU");
// verify results
CompareOutputTensor(npu_predictor, cpu_predictor, FLAGS_output_tensor_num);
}
} // namespace lite
......
......@@ -43,20 +43,20 @@ SubgraphProgramPass::ClassifySubgraph(const std::unique_ptr<SSAGraph>& graph) {
}
cpp::OpDesc SubgraphProgramPass::GenGraphOpDesc(
const std::string& model_name,
const std::string& weight_var_name,
const std::vector<std::string>& in_var_names,
const std::vector<std::string>& out_var_names) {
cpp::OpDesc op_desc;
op_desc.SetType("graph_op");
op_desc.SetInput("Inputs", in_var_names);
op_desc.SetInput("Weight", {weight_var_name});
op_desc.SetOutput("Outputs", out_var_names);
op_desc.SetAttr("model_name", model_name);
return op_desc;
}
void SubgraphProgramPass::InsertNewNode(
const std::unique_ptr<SSAGraph>& graph,
const std::string& model_name,
const std::string& weight_var_name,
Scope* scope,
const std::vector<Place>& valid_places,
std::unordered_set<Node*> in_data_vars,
......@@ -72,7 +72,7 @@ void SubgraphProgramPass::InsertNewNode(
out_var_names.push_back(i->AsArg().name);
}
auto op_desc = GenGraphOpDesc(model_name, in_var_names, out_var_names);
auto op_desc = GenGraphOpDesc(weight_var_name, in_var_names, out_var_names);
auto graph_op = LiteOpRegistry::Global().Create("graph_op");
graph_op->Attach(op_desc, scope);
......@@ -91,6 +91,12 @@ void SubgraphProgramPass::InsertNewNode(
IR_OP_VAR_LINK(new_op_node, out_var);
}
// add a weight node to store the pre-compiled NPU model
auto new_weight_node = graph->NewArgumentNode(weight_var_name);
new_weight_node->AsArg().is_weight = true;
new_weight_node->AsArg().is_persist = true;
DirectedLink(new_weight_node, new_op_node);
// assign context
auto& inst = new_op_node->AsStmt();
inst.picked_kernel().SetContext(
......
......@@ -60,13 +60,13 @@ class SubgraphProgramPass : public ProgramPass {
const std::unique_ptr<SSAGraph>& graph);
// generate the graph op desc
cpp::OpDesc GenGraphOpDesc(const std::string& model_name,
cpp::OpDesc GenGraphOpDesc(const std::string& weight_var_name,
const std::vector<std::string>& in_var_names,
const std::vector<std::string>& out_var_names);
// insert a new graph op node
void InsertNewNode(const std::unique_ptr<SSAGraph>& graph,
const std::string& model_name,
const std::string& weight_var_name,
Scope* scope,
const std::vector<Place>& valid_places,
std::unordered_set<Node*> in_data_vars,
......
......@@ -106,7 +106,7 @@ class Optimizer {
"runtime_context_assign_pass",
"argument_type_display_pass", //
#ifndef LITE_WITH_OPENCL
#if !defined(LITE_WITH_OPENCL) && !defined(LITE_WITH_NPU)
// TODO(ysh329): cause CL_INVALID_MEM_OBJECT when setArg in kernel
"memory_optimize_pass",
#endif
......
......@@ -5,5 +5,5 @@ endif()
message(STATUS "compile with lite NPU kernels")
add_kernel(graph_compute_npu NPU basic SRCS graph_compute.cc DEPS ${lite_kernel_deps} ${npu_ddk_libs})
add_kernel(graph_compute_npu NPU basic SRCS graph_compute.cc DEPS ${lite_kernel_deps} npu_runtime)
# lite_cc_test(test_graph_compute_npu SRCS graph_compute_test.cc DEPS graph_compute_npu)
......@@ -30,10 +30,16 @@ void GraphCompute::PrepareForRun() {
auto& ctx = this->ctx_->template As<NPUContext>();
auto& param = this->Param<param_t>();
exec_ = ctx.client(param.model_name);
CHECK(exec_);
CHECK(param.weight);
CHECK(lite::npu::LoadModel(*param.weight, &model_client_, &model_name_));
// TODO(hong19860320): find a good way to free the model data.
// No interface exists to free a tensor's data, so resize the dims to {1} and
// change the target to force reallocation of a small block of memory.
param.weight->Resize({1});
param.weight->mutable_data<int8_t>(TargetType::kARM);
CHECK(model_client_);
int ret =
exec_->GetModelIOTensorDim(param.model_name, npu_idims_, npu_odims_);
model_client_->GetModelIOTensorDim(model_name_, npu_idims_, npu_odims_);
CHECK_EQ(ret, hiai::AI_SUCCESS) << "[NPU] Get dims failed.";
npu_itensors_.resize(npu_idims_.size());
......@@ -108,7 +114,7 @@ void GraphCompute::Run() {
sizeof(float) * static_cast<size_t>(itensor->dims().production()));
}
std::string key = "model_name"; // Note: the key apparently must be "model_name"
npu_context_.AddPara(key, param.model_name);
model_context_.AddPara(key, model_name_);
auto GetCurrentUS = []() -> double {
struct timeval time;
......@@ -117,9 +123,9 @@ void GraphCompute::Run() {
};
int istamp;
auto start_time = GetCurrentUS();
CHECK_EQ(
hiai::AI_SUCCESS,
exec_->Process(npu_context_, npu_itensors_, npu_otensors_, 1000, istamp));
CHECK_EQ(hiai::AI_SUCCESS,
model_client_->Process(
model_context_, npu_itensors_, npu_otensors_, 1000, istamp));
LOG(INFO) << "[NPU] Process cost " << GetCurrentUS() - start_time << " us";
for (size_t i = 0; i < param.outputs.size(); ++i) {
......@@ -147,5 +153,6 @@ REGISTER_LITE_KERNEL(graph_op,
paddle::lite::kernels::npu::GraphCompute,
def)
.BindInput("Inputs", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Weight", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Outputs", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
......@@ -15,6 +15,7 @@
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "ai_ddk_lib/include/HiAiModelManagerService.h"
#include "lite/core/kernel.h"
......@@ -39,15 +40,15 @@ class GraphCompute : public KernelLite<TARGET(kNPU), PRECISION(kFloat)> {
bool input_dims_changed() const;
private:
hiai::AiModelMngerClient* exec_;
std::shared_ptr<hiai::AiModelMngerClient> model_client_;
std::string model_name_;
hiai::AiContext model_context_;
std::vector<hiai::TensorDimension> npu_idims_;
std::vector<hiai::TensorDimension> npu_odims_;
std::vector<std::shared_ptr<hiai::AiTensor>> npu_itensors_;
std::vector<std::shared_ptr<hiai::AiTensor>> npu_otensors_;
// TODO(TJ): find better place
hiai::AiContext npu_context_;
};
} // namespace npu
......
......@@ -28,7 +28,6 @@ lite_cc_library(model_parser SRCS model_parser.cc DEPS
target_wrapper_host
compatible_pb
memory
CUDA_DEPS target_wrapper_cuda
NPU_DEPS npu_helper)
CUDA_DEPS target_wrapper_cuda)
lite_cc_test(test_compatible_pb SRCS compatible_pb_test.cc DEPS compatible_pb)
......@@ -31,10 +31,6 @@
#endif
#include "lite/utils/io.h"
#ifdef LITE_WITH_NPU
#include "lite/backends/npu/npu_helper.h"
#endif
namespace paddle {
namespace lite {
......@@ -266,25 +262,6 @@ void LoadModelPb(const std::string &model_dir,
}
}
#ifdef LITE_WITH_NPU
auto main_block = pb_proto_prog.blocks(0);
for (auto &op : main_block.ops()) {
LOG(INFO) << "op type:" << op.type();
if (op.type() != "graph_op") {
continue;
}
auto xs = op.attrs();
auto it = std::find_if(
xs.begin(), xs.end(), [&](const framework::proto::OpDesc_Attr &x) {
return x.name() == "model_name";
});
CHECK(it != xs.end());
auto model_name = it->s();
std::string file_path = model_dir + "/" + model_name;
CHECK(npu::BuildNPUClient(file_path, model_name))
<< "NPU model load failed!";
}
#endif
VLOG(4) << "Load protobuf model in '" << model_dir << "'' successfully";
}
......@@ -737,21 +714,6 @@ void LoadModelNaive(const std::string &model_dir,
}
}
#ifdef LITE_WITH_NPU
auto &prog = *cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < main_block_desc.OpsSize(); ++i) {
auto &op = *main_block_desc.GetOp<cpp::OpDesc>(i);
if (op.Type() != "graph_op") {
continue;
}
auto model_name = op.GetAttr<std::string>("model_name");
std::string file_path = model_dir + "/" + model_name;
CHECK(npu::BuildNPUClient(file_path, model_name))
<< "NPU model load failed!";
}
#endif
VLOG(4) << "Load naive buffer model in '" << model_dir << "' successfully";
}
......@@ -783,10 +745,6 @@ void LoadModelNaiveFromMemory(const std::string &model_buffer,
std::string combined_params_path = param_buffer;
LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog, true);
#ifdef LITE_WITH_NPU
LOG(FATAL) << "load from memory is not supported by NPU";
#endif
VLOG(4) << "Load model from naive buffer memory successfully";
}
......
......@@ -29,6 +29,7 @@ bool GraphOpLite::InferShape() const { return CheckShape(); /* enrich me */ }
bool GraphOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
auto inputs = op_desc.Input("Inputs");
auto weight = op_desc.Input("Weight");
auto outputs = op_desc.Output("Outputs");
for (auto var : inputs) {
......@@ -36,12 +37,14 @@ bool GraphOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.inputs.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
param_.weight = scope->FindVar(weight.front())->GetMutable<lite::Tensor>();
CHECK(param_.weight);
for (auto var : outputs) {
CHECK(scope->FindVar(var));
param_.outputs.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
param_.model_name = op_desc.GetAttr<std::string>("model_name");
return true;
}
......
......@@ -70,8 +70,8 @@ struct CalibParam {
struct GraphParam {
std::vector<const lite::Tensor*> inputs{};
lite::Tensor* weight{};
std::vector<lite::Tensor*> outputs{};
std::string model_name{"model"};
};
/// -------------------------- NN operators ------------------------------------
......
#!/bin/bash
set -ex
# global variables with default value
ARM_OS="android" # android only yet
ARM_ABI="armv8" # armv8, armv7
ARM_LANG="gcc" # gcc only yet
ANDROID_STL="c++_shared" # c++_shared, c++_static
DDK_ROOT="$(pwd)/ai_ddk_lib/" # HIAI SDK from https://developer.huawei.com/consumer/cn/hiai/
TARGET_NAME="test_npu_pass" # default target
BUILD_EXTRA=OFF # ON(with sequence ops)/OFF
WITH_JAVA=ON # ON(build jar and jni so)/OFF
WITH_TESTING=ON # ON/OFF
ON_TINY_PUBLISH=OFF # ON(tiny publish)/OFF(full publish)
function print_usage {
echo -e "\nUSAGE:"
echo
echo "----------------------------------------"
echo -e "--arm_os=<os> android only yet."
echo -e "--arm_abi=<abi> armv8, armv7 yet."
echo -e "--android_stl=<shared> shared or static"
echo -e "--arm_lang=<gcc> "
echo -e "--ddk_root=<hiai_ddk_root> "
echo -e "--test_name=<test_name>"
echo -e "--android_stl=<shared> c++_shared or c++_static"
echo -e "--arm_lang=<gcc>"
echo -e "--ddk_root=<hiai_ddk_root>"
echo -e "--target_name=<target_name>"
echo "----------------------------------------"
echo
}
......@@ -47,80 +59,54 @@ function prepare_thirdparty {
fi
}
function cmake_npu {
prepare_workspace
# $1: ARM_TARGET_OS in "android" , "armlinux"
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
# $4: ANDROID_STL_TYPE in "c++_shared" "c++_static"
# $5: DDK_ROOT path
function build_npu {
cur_dir=$(pwd)
prepare_thirdparty
local stl_dir
local publish_dir
# the "c++" symbol is not recognized by the bundled script, so map it to "cxx"
if [[ "${ANDROID_STL}" == "c++_shared" ]]; then
stl_dir="cxx_shared"
fi
if [[ "${ANDROID_STL}" == "c++_static" ]]; then
stl_dir="cxx_static"
fi
if [[ "${ON_TINY_PUBLISH}" == "ON" ]]; then
WITH_TESTING=OFF
publish_dir="tiny_publish"
else
publish_dir="full_publish"
fi
build_dir=$cur_dir/build.lite.npu.${ARM_OS}.${ARM_ABI}.${ARM_LANG}.${stl_dir}.${publish_dir}
mkdir -p $build_dir
cd $build_dir
# NPU libs require Android API level 24 or above
prepare_workspace
cmake .. \
-DWITH_GPU=OFF \
-DWITH_MKL=OFF \
-DWITH_LITE=ON \
-DLITE_WITH_CUDA=OFF \
-DLITE_WITH_X86=OFF \
-DLITE_BUILD_EXTRA=ON \
-DLITE_BUILD_EXTRA=${BUILD_EXTRA} \
-DLITE_WITH_ARM=ON \
-DWITH_ARM_DOTPROD=ON \
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \
-DWITH_TESTING=ON \
-DLITE_WITH_JAVA=ON \
-DWITH_TESTING=${WITH_TESTING} \
-DLITE_WITH_JAVA=${WITH_JAVA} \
-DLITE_WITH_NPU=ON \
-DLITE_ON_TINY_PUBLISH=${ON_TINY_PUBLISH} \
-DANDROID_API_LEVEL=24 \
-DARM_TARGET_OS=$1 \
-DARM_TARGET_ARCH_ABI=$2 \
-DARM_TARGET_LANG=$3 \
-DANDROID_STL_TYPE=$4 \
-DNPU_DDK_ROOT=$5
}
function build_npu {
# os, abi, lang, stl, ddk_root, test_name
cur_dir=$(pwd)
local os=android
local abi=armv8
local lang=gcc
local stl="c++_shared"
local ddk_root="${cur_dir}/ai_ddk_lib/"
local test_name=test_npu_pass
prepare_thirdparty
if [ "x${ARM_OS}" != "x" ]; then
os=$ARM_OS
fi
if [[ "x${ARM_ABI}" != "x" ]]; then
abi=$ARM_ABI
fi
if [[ "x${ARM_LANG}" != "x" ]]; then
lang=$ARM_LANG
fi
if [[ "x${ANDROID_STL}" != "x" ]]; then
stl=$ANDROID_STL
fi
if [[ "x${DDK_ROOT}" != "x" ]]; then
ddk_root=$DDK_ROOT
fi
if [[ $# -ge 1 ]]; then
test_name=$1
fi
# the c++ symbol is not recognized by the bundled script
if [[ "${stl}" == "c++_shared" ]]; then
stl_dir="cxx_shared"
fi
if [[ "${stl}" == "c++_static" ]]; then
stl_dir="cxx_static"
fi
build_dir=$cur_dir/build.lite.npu.${os}.${abi}.${lang}.${stl_dir}
mkdir -p $build_dir
cd $build_dir
-DARM_TARGET_OS=${ARM_OS} \
-DARM_TARGET_ARCH_ABI=${ARM_ABI} \
-DARM_TARGET_LANG=${ARM_LANG} \
-DANDROID_STL_TYPE=${ANDROID_STL} \
-DNPU_DDK_ROOT=${DDK_ROOT}
cmake_npu ${os} ${abi} ${lang} ${stl} ${ddk_root}
make $test_name -j8
make $TARGET_NAME -j2
cd -
echo "Done"
......@@ -130,12 +116,8 @@ function main {
# Parse command line.
for i in "$@"; do
case $i in
--tests=*)
TESTS_FILE="${i#*=}"
shift
;;
--test_name=*)
TEST_NAME="${i#*=}"
--target_name=*)
TARGET_NAME="${i#*=}"
shift
;;
--arm_os=*)
......@@ -154,16 +136,27 @@ function main {
ANDROID_STL="${i#*=}"
shift
;;
--build_extra=*)
BUILD_EXTRA="${i#*=}"
shift
;;
--ddk_root=*)
DDK_ROOT="${i#*=}"
shift
;;
build)
build_npu $TEST_NAME
build_npu
shift
;;
full_publish)
build_npu publish_inference
TARGET_NAME=publish_inference
build_npu
shift
;;
tiny_publish)
ON_TINY_PUBLISH=ON
TARGET_NAME=publish_inference
build_npu
shift
;;
*)
......
......@@ -32,12 +32,24 @@ ostream& ostream::operator<<(const char* obj) {
return *this;
}
template <>
ostream& ostream::operator<<(const char& obj) {
_data = _data + obj;
return *this;
}
template <>
ostream& ostream::operator<<(const std::string& obj) {
_data = _data + obj;
return *this;
}
template <>
ostream& ostream::operator<<(const int16_t& obj) {
ADD_DATA_AS_STRING(_data, obj);
return *this;
}
template <>
ostream& ostream::operator<<(const int& obj) {
ADD_DATA_AS_STRING(_data, obj);
......