未验证 提交 1e10b471 编写于 作者: H hong19860320 提交者: GitHub

[LITE][NPU][XPU] Add kernel context to NPU/XPU subgraph engine (#2686)

上级 ad1dfbf2
...@@ -28,12 +28,14 @@ namespace subgraph { ...@@ -28,12 +28,14 @@ namespace subgraph {
class Engine { class Engine {
public: public:
Engine(int block_idx, Engine(KernelContext *ctx,
int block_idx,
cpp::BlockDesc *block_desc, cpp::BlockDesc *block_desc,
const std::vector<std::string> &input_names, const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names, const std::vector<std::string> &output_names,
lite::Scope *scope) lite::Scope *scope)
: block_idx_(block_idx), : ctx_(ctx),
block_idx_(block_idx),
block_desc_(block_desc), block_desc_(block_desc),
input_names_(input_names), input_names_(input_names),
output_names_(output_names), output_names_(output_names),
...@@ -55,6 +57,7 @@ class Engine { ...@@ -55,6 +57,7 @@ class Engine {
virtual bool InputShapeChanged(); virtual bool InputShapeChanged();
KernelContext *ctx_{nullptr};
int block_idx_; int block_idx_;
cpp::BlockDesc *block_desc_; cpp::BlockDesc *block_desc_;
std::vector<std::string> input_names_; std::vector<std::string> input_names_;
......
...@@ -207,7 +207,8 @@ int SubgraphEngine::LaunchDeviceProgram() { ...@@ -207,7 +207,8 @@ int SubgraphEngine::LaunchDeviceProgram() {
void SubgraphCompute::PrepareForRun() { void SubgraphCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
engine_.reset(new SubgraphEngine(param.sub_block_idx, engine_.reset(new SubgraphEngine(ctx_.get(),
param.sub_block_idx,
param.sub_block_desc, param.sub_block_desc,
param.input_data_names, param.input_data_names,
param.output_data_names, param.output_data_names,
......
...@@ -29,13 +29,14 @@ namespace npu { ...@@ -29,13 +29,14 @@ namespace npu {
class SubgraphEngine : public subgraph::Engine { class SubgraphEngine : public subgraph::Engine {
public: public:
SubgraphEngine(int block_idx, SubgraphEngine(KernelContext *ctx,
int block_idx,
cpp::BlockDesc *block_desc, cpp::BlockDesc *block_desc,
const std::vector<std::string> &input_names, const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names, const std::vector<std::string> &output_names,
Scope *scope) Scope *scope)
: subgraph::Engine( : subgraph::Engine(
block_idx, block_desc, input_names, output_names, scope) {} ctx, block_idx, block_desc, input_names, output_names, scope) {}
protected: protected:
int BuildDeviceProgram() override; int BuildDeviceProgram() override;
......
...@@ -197,7 +197,8 @@ int SubgraphEngine::LaunchDeviceProgram() { ...@@ -197,7 +197,8 @@ int SubgraphEngine::LaunchDeviceProgram() {
void SubgraphCompute::PrepareForRun() { void SubgraphCompute::PrepareForRun() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
engine_.reset(new SubgraphEngine(param.sub_block_idx, engine_.reset(new SubgraphEngine(ctx_.get(),
param.sub_block_idx,
param.sub_block_desc, param.sub_block_desc,
param.input_data_names, param.input_data_names,
param.output_data_names, param.output_data_names,
......
...@@ -29,13 +29,14 @@ namespace xpu { ...@@ -29,13 +29,14 @@ namespace xpu {
class SubgraphEngine : public subgraph::Engine { class SubgraphEngine : public subgraph::Engine {
public: public:
SubgraphEngine(int block_idx, SubgraphEngine(KernelContext *ctx,
int block_idx,
cpp::BlockDesc *block_desc, cpp::BlockDesc *block_desc,
const std::vector<std::string> &input_names, const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names, const std::vector<std::string> &output_names,
Scope *scope) Scope *scope)
: subgraph::Engine( : subgraph::Engine(
block_idx, block_desc, input_names, output_names, scope) {} ctx, block_idx, block_desc, input_names, output_names, scope) {}
protected: protected:
int BuildDeviceProgram() override; int BuildDeviceProgram() override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册