diff --git a/lite/kernels/npu/bridges/engine.h b/lite/kernels/npu/bridges/engine.h index db39063417d7023d697639236043a66c442ca8fa..61a4e12cf3ad6e3eab608a585f165fde9dec081d 100644 --- a/lite/kernels/npu/bridges/engine.h +++ b/lite/kernels/npu/bridges/engine.h @@ -28,12 +28,14 @@ namespace subgraph { class Engine { public: - Engine(int block_idx, + Engine(KernelContext *ctx, + int block_idx, cpp::BlockDesc *block_desc, const std::vector &input_names, const std::vector &output_names, lite::Scope *scope) - : block_idx_(block_idx), + : ctx_(ctx), + block_idx_(block_idx), block_desc_(block_desc), input_names_(input_names), output_names_(output_names), @@ -55,6 +57,7 @@ class Engine { virtual bool InputShapeChanged(); + KernelContext *ctx_{nullptr}; int block_idx_; cpp::BlockDesc *block_desc_; std::vector input_names_; diff --git a/lite/kernels/npu/subgraph_compute.cc b/lite/kernels/npu/subgraph_compute.cc index c6cbea46fafc5bd6e3d7431be23fbea8bf1c93fa..d9b191950668660ae2b76b70ac2b5c12aece92c0 100644 --- a/lite/kernels/npu/subgraph_compute.cc +++ b/lite/kernels/npu/subgraph_compute.cc @@ -207,7 +207,8 @@ int SubgraphEngine::LaunchDeviceProgram() { void SubgraphCompute::PrepareForRun() { auto& param = this->Param(); - engine_.reset(new SubgraphEngine(param.sub_block_idx, + engine_.reset(new SubgraphEngine(ctx_.get(), + param.sub_block_idx, param.sub_block_desc, param.input_data_names, param.output_data_names, diff --git a/lite/kernels/npu/subgraph_compute.h b/lite/kernels/npu/subgraph_compute.h index dd0bf82bc9e743287d2ad4cb81db9a5fdd57c276..27b4a36cfeadf6cca328fb9c980d53c9c5e79095 100644 --- a/lite/kernels/npu/subgraph_compute.h +++ b/lite/kernels/npu/subgraph_compute.h @@ -29,13 +29,14 @@ namespace npu { class SubgraphEngine : public subgraph::Engine { public: - SubgraphEngine(int block_idx, + SubgraphEngine(KernelContext *ctx, + int block_idx, cpp::BlockDesc *block_desc, const std::vector &input_names, const std::vector &output_names, Scope *scope) : subgraph::Engine( - block_idx, block_desc, input_names, output_names, scope) {} + ctx, block_idx, block_desc, input_names, output_names, scope) {} protected: int BuildDeviceProgram() override; diff --git a/lite/kernels/xpu/subgraph_compute.cc b/lite/kernels/xpu/subgraph_compute.cc index 0a7a4d2aa5431d04c19f531ca118ac422417cbba..07a74b045477bcdff0d60913f20e79ff8497705b 100644 --- a/lite/kernels/xpu/subgraph_compute.cc +++ b/lite/kernels/xpu/subgraph_compute.cc @@ -197,7 +197,8 @@ int SubgraphEngine::LaunchDeviceProgram() { void SubgraphCompute::PrepareForRun() { auto& param = this->Param(); - engine_.reset(new SubgraphEngine(param.sub_block_idx, + engine_.reset(new SubgraphEngine(ctx_.get(), + param.sub_block_idx, param.sub_block_desc, param.input_data_names, param.output_data_names, diff --git a/lite/kernels/xpu/subgraph_compute.h b/lite/kernels/xpu/subgraph_compute.h index 2196eb3621d1acb6fb6c76426118d150a8228214..c21a1b7b054fd642f330ee95bff972f581e65c6b 100644 --- a/lite/kernels/xpu/subgraph_compute.h +++ b/lite/kernels/xpu/subgraph_compute.h @@ -29,13 +29,14 @@ namespace xpu { class SubgraphEngine : public subgraph::Engine { public: - SubgraphEngine(int block_idx, + SubgraphEngine(KernelContext *ctx, + int block_idx, cpp::BlockDesc *block_desc, const std::vector &input_names, const std::vector &output_names, Scope *scope) : subgraph::Engine( - block_idx, block_desc, input_names, output_names, scope) {} + ctx, block_idx, block_desc, input_names, output_names, scope) {} protected: int BuildDeviceProgram() override;