Unverified Commit 6478c30e authored by HappyAngel, committed by GitHub

Merge pull request #143 from PaddlePaddle/develop

pull code
......@@ -49,18 +49,33 @@ class LITE_API Predictor {
program_desc_ = std::make_shared<cpp::ProgramDesc>();
}
// Create a predictor with the weight variable scope set.
///////////////////////////////////////////////////////////////////
// Function: Predictor
// Usage: Constructor of Predictor. Creates a predictor with the
// given weight variable scope.
///////////////////////////////////////////////////////////////////
explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope)
: scope_(root_scope) {}
///////////////////////////////////////////////////////////////////
// Function: Predictor
// Usage: Constructor of Predictor. This constructor can only be
// called by Predictor::Clone. It creates a predictor from an
// existing ProgramDesc, Scope and RuntimeProgram.
///////////////////////////////////////////////////////////////////
Predictor(const std::shared_ptr<cpp::ProgramDesc>& program_desc,
const std::shared_ptr<Scope>& root_scope,
const std::shared_ptr<Scope>& root,
const std::vector<Place>& valid_places,
const std::vector<std::string>& vars_to_clone = {})
: program_desc_(program_desc), scope_(root_scope) {
Program program(program_desc_, scope_, valid_places, vars_to_clone);
optimizer_ = Optimizer(std::move(program), valid_places);
exec_scope_ = optimizer_.exec_scope();
const std::vector<std::string>& var_names = {})
: program_desc_(program_desc), scope_(root) {
// step1. Create a Program to construct the exec_scope and ops
Program program(program_desc_, scope_, valid_places, var_names);
exec_scope_ = program.exec_scope();
valid_places_ = valid_places;
// step2. Create the RuntimeProgram.
program_.reset(
new RuntimeProgram(program_desc_, exec_scope_, kRootBlockIdx));
program_generated_ = true;
}
// Build from a model, with places set for hardware config.
......@@ -83,26 +98,58 @@ class LITE_API Predictor {
const std::vector<Place>& valid_places,
const std::vector<std::string>& passes = {});
std::shared_ptr<Predictor> Clone() const {
return std::make_shared<Predictor>(program_desc_, scope_, valid_places_);
//////////////////////////////////////////////////////////
// Function: Clone
// Usage: Create a Predictor from an existing one; the cloned
// predictor will share persistable variables in scope_
// with the original predictor.
//////////////////////////////////////////////////////////
std::shared_ptr<Predictor> Clone() {
// step 1. Generate runtime_program, update op_info and var_info in
// program_desc_
if (!program_generated_) {
GenRuntimeProgram();
}
program_->SaveToProgram(program_desc_);
// step 2. Create a predictor from the current program_desc_ and
// runtime_program.
auto predictor =
std::make_shared<Predictor>(program_desc_, scope_, valid_places_);
// step 3. Return the result.
return predictor;
}
std::shared_ptr<Predictor> Clone(
const std::vector<std::string>& vars_to_clone) const {
//////////////////////////////////////////////////////////
// Function: Clone(var_names)
// Usage: Create a Predictor from an existing one; the cloned
// predictor will share persistable variables with the
// original, except for the variables named in var_names,
// which will get private copies.
//////////////////////////////////////////////////////////
std::shared_ptr<Predictor> Clone(const std::vector<std::string>& var_names) {
CHECK(program_desc_) << "Both program and scope of current predicotr "
"should be not be nullptr in Clone mode.";
CHECK(scope_) << "Both program and scope of current predicotr should be "
"not be nullptr in Clone mode.";
// step 1. Generate runtime_program, update op_info and var_info in
// program_desc_
if (!program_generated_) {
GenRuntimeProgram();
}
program_->SaveToProgram(program_desc_);
// step 2. Create a predictor from the current program_desc_ and
// runtime_program.
auto predictor = std::make_shared<Predictor>(
program_desc_, scope_, valid_places_, vars_to_clone);
for (auto var_name : vars_to_clone) {
program_desc_, scope_, valid_places_, var_names);
// step 3. Copy the specified persistable variables into the private scope.
for (auto var_name : var_names) {
predictor->exec_scope_->LocalVar(var_name);
auto* tensor = predictor->scope_->Var(var_name)->GetMutable<Tensor>();
auto* tensor =
predictor->scope_->Var(var_name)->GetMutable<lite::Tensor>();
auto* sub_tensor =
predictor->exec_scope_->Var(var_name)->GetMutable<Tensor>();
sub_tensor->CopyDataFrom(*tensor);
}
// step 4. Return the result.
return predictor;
}
......@@ -161,7 +208,7 @@ class LITE_API Predictor {
std::shared_ptr<cpp::ProgramDesc> program_desc_;
std::shared_ptr<Scope> scope_;
Scope* exec_scope_;
std::unique_ptr<RuntimeProgram> program_;
std::shared_ptr<RuntimeProgram> program_;
bool program_generated_{false};
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
......
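Note: the sharing contract of the two Clone overloads above can be illustrated with a small standalone sketch. ToyScope and CloneScope below are illustrative stand-ins, not Paddle-Lite's API: a clone reads shared persistable variables from the parent store unless a private copy was taken at clone time.

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct ToyScope {
  std::shared_ptr<std::map<std::string, float>> parent;  // shared weights
  std::map<std::string, float> local;                    // private copies
  float Get(const std::string& name) const {
    auto it = local.find(name);
    return it != local.end() ? it->second : parent->at(name);
  }
};

// Mirrors Clone(var_names): share the parent store, but snapshot the
// variables listed in var_names into the clone's private map.
ToyScope CloneScope(const ToyScope& src,
                    const std::vector<std::string>& var_names) {
  ToyScope dst{src.parent, {}};
  for (const auto& name : var_names) dst.local[name] = src.parent->at(name);
  return dst;
}

int main() {
  auto weights = std::make_shared<std::map<std::string, float>>();
  (*weights)["fc_w"] = 1.0f;
  ToyScope a{weights, {}};
  ToyScope b = CloneScope(a, {"fc_w"});  // "fc_w" becomes private to b
  (*weights)["fc_w"] = 2.0f;             // mutate the shared copy
  std::cout << a.Get("fc_w") << " " << b.Get("fc_w") << "\n";  // prints: 2 1
  return 0;
}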
......@@ -33,6 +33,7 @@ void LightPredictor::Build(const std::string& lite_model_file,
DequantizeWeight();
BuildRuntimeProgram(program_desc_);
PrepareFeedFetch();
program_desc_.reset();
}
void LightPredictor::Build(const std::string& model_dir,
......
......@@ -97,7 +97,7 @@ void TestModel(const std::vector<Place>& valid_places,
if (first_target == TARGET(kOpenCL) || first_target == TARGET(kNPU)) {
ASSERT_EQ(out->dims().production(), 1000);
double eps = first_target == TARGET(kOpenCL) ? 0.15 : 0.1;
double eps = first_target == TARGET(kOpenCL) ? 0.25 : 0.1;
for (int i = 0; i < ref.size(); ++i) {
for (int j = 0; j < ref[i].size(); ++j) {
auto result = pdata[j * step + (out->dims()[1] * i)];
......
......@@ -47,21 +47,18 @@ struct Program {
Program(const std::shared_ptr<cpp::ProgramDesc>& program_desc,
const std::shared_ptr<Scope>& root_scope,
const std::vector<Place>& valid_places,
const std::vector<std::string>& vars_to_clone = {})
: scope_(root_scope),
valid_places_(valid_places),
program_desc_(program_desc) {
const std::vector<std::string>& var_names = {})
: scope_(root_scope), valid_places_(valid_places) {
CHECK(scope_) << "scope should be init first";
VLOG(4) << "prepare work";
PrepareWorkspace(program_desc_, vars_to_clone);
PrepareWorkspace(program_desc, var_names);
VLOG(4) << "build desc";
Build(program_desc_);
Build(program_desc);
VLOG(4) << "build desc finished";
}
std::unique_ptr<Program> Clone() const {
return std::unique_ptr<Program>(
new Program(program_desc_, scope_, valid_places_));
return std::unique_ptr<Program>(new Program(scope_));
}
const std::list<std::string>& weights() const { return weights_; }
......@@ -83,8 +80,6 @@ struct Program {
Scope* exec_scope() { return exec_scope_; }
Scope* scope() { return scope_.get(); }
cpp::ProgramDesc* program_desc() { return program_desc_.get(); }
const std::map<std::string, const Type*>& var_type_map() const {
return var_type_map_;
}
......@@ -106,7 +101,6 @@ struct Program {
std::vector<Place> valid_places_;
// Runtime scope.
Scope* exec_scope_{};
std::shared_ptr<cpp::ProgramDesc> program_desc_;
};
struct Instruction {
......
......@@ -88,7 +88,7 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto i_data = param.input->data<float>();
auto o_data = param.output->mutable_data<float>();
auto w_data = flag_gemm_ ? param.w->data<float>() : weights_.data<float>();
auto w_data = param.w->data<float>();
const float* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
b_data = bias_.data<float>();
......@@ -149,8 +149,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto i_data = param.input->data<int8_t>();
auto o_data = param.output->mutable_data<float>();
auto w_data =
flag_trans_weights_ ? weights_.data<int8_t>() : param.w->data<int8_t>();
auto w_data = param.w->data<int8_t>();
const float* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
b_data = bias_.data<float>();
......@@ -208,8 +207,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
auto i_data = param.input->data<int8_t>();
auto o_data = param.output->mutable_data<int8_t>();
auto w_data =
flag_trans_weights_ ? weights_.data<int8_t>() : param.w->data<int8_t>();
auto w_data = param.w->data<int8_t>();
const float* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
b_data = bias_.data<float>();
......
......@@ -104,9 +104,11 @@ class FcCompute : public KernelLite<TARGET(kARM), PType> {
CHECK_EQ(k_, static_cast<int>(w_dims[0]));
flag_gemm_ = check_fc_use_gemm<PType, OutType>(
m_, param.weight_scale, param.bias != nullptr);
if (!flag_trans_weights_ && !flag_gemm_) {
flag_trans_weights_ = true;
fc_trans_weights<PType>(*param.w, &weights_);
if (flag_trans_weights_ == flag_gemm_) {
flag_trans_weights_ = !flag_trans_weights_;
Tensor tmp_tensor;
fc_trans_weights<PType>(*param.w, &tmp_tensor);
param.w->CopyDataFrom(tmp_tensor);
}
}
......@@ -117,7 +119,6 @@ class FcCompute : public KernelLite<TARGET(kARM), PType> {
private:
DDim last_shape_;
Tensor weights_;
Tensor bias_;
bool flag_trans_weights_{false};
bool flag_trans_bias_{false};
......
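Note: the fc_compute change above drops the kernel-private weights_ buffer and transposes param.w in place instead. The new branch "if (flag_trans_weights_ == flag_gemm_)" restores the invariant that the weight buffer is stored transposed exactly when the non-GEMM path is selected, re-transposing in place whenever a shape change flips flag_gemm_. A minimal sketch of that toggle, using a toy matrix type rather than lite::Tensor:

#include <cassert>
#include <vector>

struct ToyMat {
  int rows, cols;
  std::vector<float> data;  // row-major
};

ToyMat Transposed(const ToyMat& m) {
  ToyMat t{m.cols, m.rows, std::vector<float>(m.data.size())};
  for (int r = 0; r < m.rows; ++r)
    for (int c = 0; c < m.cols; ++c)
      t.data[c * m.rows + r] = m.data[r * m.cols + c];
  return t;
}

// Same shape of logic as the patch: the invariant is
// transposed == !use_gemm; when they disagree, flip and re-transpose.
void SyncWeights(bool use_gemm, bool* transposed, ToyMat* w) {
  if (*transposed == use_gemm) {
    *transposed = !*transposed;
    ToyMat tmp = Transposed(*w);  // mirrors the tmp_tensor + CopyDataFrom step
    *w = tmp;
  }
}

int main() {
  ToyMat w{2, 3, {1, 2, 3, 4, 5, 6}};
  bool transposed = false;
  SyncWeights(false, &transposed, &w);  // GEMV path: store transposed
  assert(transposed && w.rows == 3);
  SyncWeights(true, &transposed, &w);   // back to GEMM: restore layout
  assert(!transposed && w.rows == 2);
  return 0;
}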
......@@ -25,135 +25,209 @@ namespace lite {
namespace kernels {
namespace arm {
bool IsShuffleChannel(const std::vector<int> &axis) {
bool is_shuffle_channel = true;
if (axis.size() > 2 && axis[0] == 0 && axis[1] == 2 && axis[2] == 1) {
for (int i = 3; i < axis.size(); ++i) {
if (axis[i] != i) {
is_shuffle_channel = false;
break;
}
template <typename Dtype>
void trans_basic(const int count,
const Dtype* din,
const int* permute_order,
const int* old_steps,
const int* new_steps,
const int num_axes,
Dtype* dout) {
for (int i = 0; i < count; ++i) {
int old_idx = 0;
int idx = i;
for (int j = 0; j < num_axes; ++j) {
int order = permute_order[j];
old_idx += (idx / new_steps[j]) * old_steps[order];
idx %= new_steps[j];
}
} else {
return false;
dout[i] = din[old_idx];
}
return is_shuffle_channel;
}
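Note: a standalone illustration of the index remapping in trans_basic may help: each flat output index is decomposed by the output strides (new_steps), and every coordinate is multiplied back by the input stride of its source axis (old_steps[order]). The strides are computed locally here for clarity; in the patch they come from get_stride in PrepareForRun. The shape and order below are arbitrary examples.

#include <cassert>
#include <vector>

int main() {
  const int dims_in[3] = {2, 3, 4};  // example input shape
  const int order[3] = {0, 2, 1};    // output axis j reads input axis order[j]
  int dims_out[3], old_steps[3], new_steps[3];
  for (int j = 0; j < 3; ++j) dims_out[j] = dims_in[order[j]];
  old_steps[2] = new_steps[2] = 1;
  for (int j = 1; j >= 0; --j) {
    old_steps[j] = old_steps[j + 1] * dims_in[j + 1];
    new_steps[j] = new_steps[j + 1] * dims_out[j + 1];
  }
  std::vector<float> din(24), dout(24);
  for (int i = 0; i < 24; ++i) din[i] = static_cast<float>(i);
  for (int i = 0; i < 24; ++i) {  // the same loop trans_basic runs
    int old_idx = 0, idx = i;
    for (int j = 0; j < 3; ++j) {
      old_idx += (idx / new_steps[j]) * old_steps[order[j]];
      idx %= new_steps[j];
    }
    dout[i] = din[old_idx];
  }
  // out(n, w, h) == in(n, h, w): e.g. out(0, 1, 2) == in(0, 2, 1)
  assert(dout[1 * 3 + 2] == din[2 * 4 + 1]);
  return 0;
}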
template <typename Dtype>
void ShuffleChannelCompute(const std::vector<int> &axis,
const lite::Tensor *input,
lite::Tensor *output) {
const Dtype *input_ptr = input->data<Dtype>();
Dtype *output_ptr = output->mutable_data<Dtype>();
// input and output's shape dimension must >= 2 && <= 6.
const DDim &in_dim = input->dims();
const DDim &out_dim = output->dims();
size_t offset = 1;
for (int i = 3; i < axis.size(); ++i) {
offset *= in_dim[i];
}
void transpose_mat(const Dtype* din,
Dtype* dout,
const int num,
const int width,
const int height);
void transpose_mat(const float* din,
float* dout,
const int num,
const int width,
const int height) {
int nw = width >> 2;
int nh = height >> 2;
int size_in = width * height;
#pragma omp parallel for collapse(3)
for (int batch = 0; batch < out_dim[0]; ++batch) {
for (int c1 = 0; c1 < out_dim[1]; ++c1) {
for (int c2 = 0; c2 < out_dim[2]; ++c2) {
size_t out_offset =
((batch * out_dim[1] + c1) * out_dim[2] + c2) * offset;
size_t in_offset = ((batch * in_dim[1] + c2) * in_dim[2] + c1) * offset;
memcpy(output_ptr + out_offset,
input_ptr + in_offset,
offset * sizeof(Dtype));
for (int i = 0; i < num; ++i) {
float* ptr_out = dout + i * size_in;
const float* ptr_in = din + i * size_in;
#pragma omp parallel for
for (int h = 0; h < nh; h++) {
const float* ptr_din_row = ptr_in + h * 4 * width;
for (int w = 0; w < nw; w++) {
float* data_out_ptr = ptr_out + w * 4 * height + h * 4;
const float* din0 = ptr_din_row;
const float* din1 = din0 + width;
const float* din2 = din1 + width;
const float* din3 = din2 + width;
float* dout0 = data_out_ptr;
float* dout1 = dout0 + height;
float* dout2 = dout1 + height;
float* dout3 = dout2 + height;
#ifdef __aarch64__
float32x4_t vr0 = vld1q_f32(din0);
float32x4_t vr1 = vld1q_f32(din1);
float32x4_t vr2 = vld1q_f32(din2);
float32x4_t vr3 = vld1q_f32(din3);
float32x4_t re0 = vtrn1q_f32(vr0, vr1);
float32x4_t re1 = vtrn2q_f32(vr0, vr1);
float32x4_t re2 = vtrn1q_f32(vr2, vr3);
float32x4_t re3 = vtrn2q_f32(vr2, vr3);
vst1_f32(dout0, vget_low_f32(re0));
dout0 += 2;
vst1_f32(dout0, vget_low_f32(re2));
vst1_f32(dout1, vget_low_f32(re1));
dout1 += 2;
vst1_f32(dout1, vget_low_f32(re3));
vst1_f32(dout2, vget_high_f32(re0));
dout2 += 2;
vst1_f32(dout2, vget_high_f32(re2));
vst1_f32(dout3, vget_high_f32(re1));
dout3 += 2;
vst1_f32(dout3, vget_high_f32(re3));
#else
asm("vld1.32 {d0, d1}, [%[in0]] \n"
"vld1.32 {d2, d3}, [%[in1]] \n"
"vld1.32 {d4, d5}, [%[in2]] \n"
"vld1.32 {d6, d7}, [%[in3]] \n"
"vtrn.32 q0, q1 \n"
"vtrn.32 q2, q3 \n"
"vswp d1, d4 \n"
"vswp d3, d6 \n"
"vst1.32 {d0, d1}, [%[out0]] \n"
"vst1.32 {d2, d3}, [%[out1]] \n"
"vst1.32 {d4, d5}, [%[out2]] \n"
"vst1.32 {d6, d7}, [%[out3]] \n"
:
: [out0] "r"(dout0),
[out1] "r"(dout1),
[out2] "r"(dout2),
[out3] "r"(dout3),
[in0] "r"(din0),
[in1] "r"(din1),
[in2] "r"(din2),
[in3] "r"(din3)
: "q0", "q1", "q2", "q3");
#endif
ptr_din_row += 4;
}
}
// remainder: elements outside the 4x4 blocks
for (int h = 0; h < height; h++) {
for (int w = nw * 4; w < width; w++) {
const float* data_in_ptr = ptr_in + h * width + w;
float* data_out_ptr = ptr_out + w * height + h;
*data_out_ptr = *data_in_ptr;
}
}
for (int w = 0; w < width; w++) {
for (int h = nh * 4; h < height; h++) {
const float* data_in_ptr = ptr_in + h * width + w;
float* data_out_ptr = ptr_out + w * height + h;
*data_out_ptr = *data_in_ptr;
}
}
}
}
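Note: transpose_mat above performs num independent width x height matrix transposes, moving 4x4 tiles with vtrn/vswp (vtrn1q/vtrn2q on aarch64) and falling back to scalar loops for the tail rows and columns. A scalar reference, offered only as a hypothetical cross-check for the NEON path, not part of the patch:

void transpose_mat_ref(const float* din, float* dout,
                       const int num, const int width, const int height) {
  for (int n = 0; n < num; ++n) {
    const float* in = din + n * width * height;
    float* out = dout + n * width * height;
    for (int h = 0; h < height; ++h) {
      for (int w = 0; w < width; ++w) {
        // same mapping as the NEON kernel's remainder loops
        out[w * height + h] = in[h * width + w];
      }
    }
  }
}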
template <typename Dtype>
void TransposeCompute_(const std::vector<int> &axis,
const lite::Tensor *input,
lite::Tensor *output) {
// const Dtype *input_ptr = input->data<Dtype>();
const Dtype *input_ptr = input->data<float>();
Dtype *output_ptr = output->mutable_data<Dtype>();
// input and output's shape dimension must >= 2 && <= 6.
const DDim &in_dim = input->dims();
const DDim &out_dim = output->dims();
// precompute inverted output dim and strides
size_t rout_dim[6], strides[6];
int permute = axis.size(); // permute must >=2 && <= 6.
for (int i = 0; i < permute; ++i) {
int k = permute - 1 - i;
strides[k] = 1;
for (int j = axis[i] + 1; j < permute; ++j) {
strides[k] *= in_dim[j];
std::vector<int> get_stride(const paddle::lite::DDimLite& dims) {
std::vector<int> data_stride{0};
for (int i = 0; i < dims.size(); ++i) {
data_stride.push_back(dims.count(i, dims.size()));
}
return data_stride;
}
void TransposeCompute::PrepareForRun() {
auto& param = Param<operators::TransposeParam>();
auto* input = param.x;
auto* output = param.output;
int _num_axes = input->dims().size();
CHECK(_num_axes == param.axis.size())
<< "axis size is not match to input dims";
need_trans = false;
for (int i = 0; i < _num_axes; ++i) {
if (param.axis[i] != i) {
need_trans = true;
break;
}
rout_dim[k] = out_dim[i];
}
// unroll the first 2 dimensions
int reamin_dim = 1;
for (int i = 2; i < out_dim.size(); ++i) {
reamin_dim *= out_dim[i];
if (!need_trans) {
return;
}
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < out_dim[0]; ++batch) {
for (int j = 0; j < out_dim[1]; ++j) {
size_t offset = batch * strides[permute - 1] + j * strides[permute - 2];
Dtype *out_ptr = output_ptr + (batch * out_dim[1] + j) * reamin_dim;
int indics[4] = {0, 0, 0, 0};
for (int k = 0; k < reamin_dim; ++k) {
out_ptr[k] = input_ptr[offset];
indics[0] += 1;
offset += strides[0];
for (int p = 0; p < permute - 3; ++p) {
if (indics[p] == rout_dim[p]) {
indics[p + 1] += 1;
indics[p] = 0;
offset += strides[p + 1];
offset -= rout_dim[p] * strides[p];
} else {
break;
}
}
}
std::vector<int> axis_diff;
int j = 0;
for (int i = 0; i < _num_axes; ++i) {
if (param.axis[j] != i) {
axis_diff.push_back(j);
} else {
j++;
}
}
}
for (int i = 0; i < axis_diff.size(); i++) {
}
if (input->dims().count(axis_diff[0], _num_axes) == 1) {
need_trans = false;
return;
}
// Transpose
void TransposeCompute::Run() {
auto &param = Param<operators::TransposeParam>();
auto *input = param.x;
auto *output = param.output;
const std::vector<int> axis = param.axis;
if (axis_diff.size() == 1) {
trans_mat = true;
_trans_num = input->dims().count(0, std::max(axis_diff[0], 0));
_trans_w = input->dims().count(axis_diff[0] + 1, _num_axes);
_trans_h = input->dims()[axis_diff[0]];
bool shuffle_channel = IsShuffleChannel(axis);
if (shuffle_channel) {
ShuffleChannelCompute<float>(axis, input, output);
} else {
TransposeCompute_<float>(axis, input, output);
trans_mat = false;
_new_steps = get_stride(output->dims());
_old_steps = get_stride(input->dims());
}
return;
}
// Transpose2
void Transpose2Compute::Run() {
auto &param = Param<operators::TransposeParam>();
auto *input = param.x;
auto *output = param.output;
// Transpose
void TransposeCompute::Run() {
auto& param = Param<operators::TransposeParam>();
auto* input = param.x;
auto* output = param.output;
const std::vector<int> axis = param.axis;
bool shuffle_channel = IsShuffleChannel(axis);
if (shuffle_channel) {
ShuffleChannelCompute<float>(axis, input, output);
//! only copy the data
if (!need_trans) {
output->CopyDataFrom(*input);
return;
}
const float* din = static_cast<const float*>(input->data<float>());
float* dout = static_cast<float*>(output->mutable_data<float>());
//! transpose the data
if (trans_mat) {
transpose_mat(din, dout, _trans_num, _trans_w, _trans_h);
} else {
TransposeCompute_<float>(axis, input, output);
trans_basic(output->numel(),
din,
param.axis.data(),
_old_steps.data(),
_new_steps.data(),
input->dims().size(),
dout);
}
return;
}
} // namespace arm
......
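Note: a worked example of the axis_diff fast path in PrepareForRun, with a hypothetical NCHW shape. For NCHW -> NHWC the axis order is {0, 2, 3, 1}; only one source axis (C, index 1) is displaced, so axis_diff = {1} and the permutation collapses into N independent matrix transposes handled by transpose_mat:

#include <cassert>

int main() {
  const int dims[4] = {8, 16, 32, 32};  // hypothetical N, C, H, W
  const int axis_diff0 = 1;             // the single displaced axis (C)
  int trans_num = dims[0];         // mirrors dims().count(0, axis_diff[0]): N
  int trans_h = dims[axis_diff0];  // mirrors dims()[axis_diff[0]]: C
  int trans_w = dims[2] * dims[3];  // mirrors count(axis_diff[0] + 1, 4): H*W
  assert(trans_num == 8 && trans_h == 16 && trans_w == 1024);
  // Run() then calls transpose_mat(din, dout, trans_num, trans_w, trans_h),
  // i.e. N transposes of a C x (H*W) matrix, which is exactly NCHW -> NHWC.
  return 0;
}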
......@@ -14,6 +14,7 @@
#pragma once
#include <algorithm>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/operators/transpose_op.h"
......@@ -26,19 +27,24 @@ namespace arm {
class TransposeCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::TransposeParam;
void PrepareForRun() override;
void Run() override;
virtual ~TransposeCompute() = default;
private:
bool need_trans = false;
bool trans_mat = false;
int _trans_num;
int _trans_w;
int _trans_h;
std::vector<int> _new_steps;
std::vector<int> _old_steps;
};
// Transpose2
class Transpose2Compute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
class Transpose2Compute : public TransposeCompute {
public:
using param_t = operators::TransposeParam;
void Run() override;
virtual ~Transpose2Compute() = default;
};
......
......@@ -107,6 +107,7 @@ TEST(transpose_arm, compute_shape_nchw) {
// run transpose_compute
transpose.SetParam(param);
transpose.PrepareForRun();
transpose.Run();
// run transpose_compute_ref
......@@ -173,6 +174,7 @@ TEST(transpose2_arm, compute_shape_nchw) {
// run transpose_compute
transpose2.SetParam(param);
transpose2.PrepareForRun();
transpose2.Run();
// run transpose_compute_ref
......@@ -183,8 +185,8 @@ TEST(transpose2_arm, compute_shape_nchw) {
auto* output_ref_data = output_ref.data<float>();
for (int i = 0;
i < input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
i += 4) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
i += 1) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 0);
}
}
......
......@@ -34,6 +34,9 @@ bool SubgraphEngine::BuildDeviceProgram() {
const auto& bridges = subgraph::Registry::Instance();
graph.CreateCompilerHandle();
auto& ctx = this->ctx_->template As<BMContext>();
for (size_t i = 0; i < input_names_.size(); i++) {
graph.AddNode(input_names_[i]);
}
if (!origin_program_) {
BuildOriginProgram();
}
......@@ -60,7 +63,7 @@ bool SubgraphEngine::BuildDeviceProgram() {
std::string net_name = "bmnet_f32bmodel";
auto unique_net_name = lite::subgraph::bm::UniqueName(net_name);
__bmcompile_opt(
graph.GetCompilerHandle(), const_cast<char*>(unique_net_name.c_str()), 2);
graph.GetCompilerHandle(), const_cast<char*>(unique_net_name.c_str()), 1);
void* bmodel_data = nullptr;
unsigned int data_size = 0;
bm_hd_ = static_cast<bm_handle_t>(ctx.GetHandle());
......
......@@ -104,7 +104,7 @@ bool DeviceProgram::LoadFromCacheFile(
}
bool DeviceProgram::BuildGraphAndCacheToFile(
const std::vector<Instruction>& origin_program,
RuntimeProgram* origin_program,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::vector<int64_t>>& origin_idims,
......@@ -118,10 +118,14 @@ bool DeviceProgram::BuildGraphAndCacheToFile(
// Convert all of ops and their input vars and weights to HiAI IR nodes,
// then added them into the IR graph
int status = 0;
CHECK(!origin_program.empty()) << "no instructions";
subgraph::huawei_ascend_npu::Graph graph;
const auto& bridges = subgraph::Registry::Instance();
for (auto& inst : origin_program) {
CHECK(origin_program)
<< "[HUAWEI_ASCEND_NPU] The origin program is not initialized!";
CHECK_GT(origin_program->instructions(kRootBlockIdx).size(), 0)
<< "[HUAWEI_ASCEND_NPU] No instructions found in the origin program!";
const auto& insts = origin_program->instructions(kRootBlockIdx);
for (auto& inst : insts) {
auto op = const_cast<OpLite*>(inst.op());
CHECK(op);
op->CheckShape();
......@@ -140,7 +144,8 @@ bool DeviceProgram::BuildGraphAndCacheToFile(
// Collect the input and output nodes of the IR graph
std::vector<ge::Operator> device_inodes;
for (size_t i = 0; i < input_names.size(); i++) {
CHECK(graph.Has(input_names[i]) && graph.Get(input_names[i])->is_data());
CHECK(graph.Has(input_names[i]));
CHECK(graph.Get(input_names[i])->is_data());
device_inodes.push_back(*graph.Get(input_names[i])->data());
}
std::vector<ge::Operator> device_onodes;
......@@ -379,7 +384,8 @@ bool SubgraphEngine::BuildDeviceProgram() {
ctx_->As<HuaweiAscendNPUContext>().SubgraphModelCacheDir();
auto device_id = ctx_->As<HuaweiAscendNPUContext>().HuaweiAscendDeviceID();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Get model cached dir: " << model_cache_dir;
VLOG(3) << "[HUAWEI_ASCEND_NPU] Get huawei ascend npu device id: "
<< device_id;
// Check and load if the cached model and configuration file exists
if (model_cache_dir.empty() ||
!device_program->LoadFromCacheFile(input_names_,
......@@ -390,11 +396,14 @@ bool SubgraphEngine::BuildDeviceProgram() {
// Build the model online, including converting the paddle ops to the HiAI
// IR nodes, building the HiAI IR graph to the om model, then load it as a
// new HiAI model manager client for inference.
if (origin_program_.empty()) {
if (!origin_program_) {
BuildOriginProgram();
}
CHECK(!origin_program_.empty()) << "no instructions";
if (!device_program->BuildGraphAndCacheToFile(origin_program_,
CHECK(origin_program_)
<< "[HUAWEI_ASCEND_NPU] The origin program is not initialized!";
CHECK_GT(origin_program_->instructions().size(), 0)
<< "[HUAWEI_ASCEND_NPU] No instructions found in the origin program!";
if (!device_program->BuildGraphAndCacheToFile(origin_program_.get(),
input_names_,
output_names_,
origin_idims_,
......@@ -443,11 +452,11 @@ bool SubgraphEngine::LaunchDeviceProgram() {
void SubgraphCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
engine_.reset(new SubgraphEngine(ctx_.get(),
param.sub_block_idx,
param.sub_block_desc,
param.block_idx,
param.program_desc,
param.exec_scope,
param.input_data_names,
param.output_data_names,
param.scope));
param.output_data_names));
CHECK(engine_);
}
......
......@@ -46,7 +46,7 @@ class DeviceProgram {
const std::string& model_cache_dir,
const int device_id);
bool BuildGraphAndCacheToFile(
const std::vector<Instruction>& origin_program,
RuntimeProgram* origin_program,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::vector<int64_t>>& origin_idims,
......@@ -80,12 +80,16 @@ class SubgraphEngine : public subgraph::Engine {
public:
SubgraphEngine(KernelContext* ctx,
int block_idx,
cpp::BlockDesc* block_desc,
const std::shared_ptr<const cpp::ProgramDesc>& program_desc,
Scope* exec_scope,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
Scope* scope)
: subgraph::Engine(
ctx, block_idx, block_desc, input_names, output_names, scope) {}
const std::vector<std::string>& output_names)
: subgraph::Engine(ctx,
block_idx,
program_desc,
exec_scope,
input_names,
output_names) {}
protected:
bool PrepareWorkspaceForDeviceProgram() override;
......