From 1e10b471894ee7b48149e257eacb66e433445008 Mon Sep 17 00:00:00 2001 From: hong19860320 <9973393+hong19860320@users.noreply.github.com> Date: Fri, 27 Dec 2019 10:19:57 +0800 Subject: [PATCH] [LITE][NPU][XPU] Add kernel context to NPU/XPU subgraph engine (#2686) --- lite/kernels/npu/bridges/engine.h | 7 +++++-- lite/kernels/npu/subgraph_compute.cc | 3 ++- lite/kernels/npu/subgraph_compute.h | 5 +++-- lite/kernels/xpu/subgraph_compute.cc | 3 ++- lite/kernels/xpu/subgraph_compute.h | 5 +++-- 5 files changed, 15 insertions(+), 8 deletions(-) diff --git a/lite/kernels/npu/bridges/engine.h b/lite/kernels/npu/bridges/engine.h index db39063417..61a4e12cf3 100644 --- a/lite/kernels/npu/bridges/engine.h +++ b/lite/kernels/npu/bridges/engine.h @@ -28,12 +28,14 @@ namespace subgraph { class Engine { public: - Engine(int block_idx, + Engine(KernelContext *ctx, + int block_idx, cpp::BlockDesc *block_desc, const std::vector &input_names, const std::vector &output_names, lite::Scope *scope) - : block_idx_(block_idx), + : ctx_(ctx), + block_idx_(block_idx), block_desc_(block_desc), input_names_(input_names), output_names_(output_names), @@ -55,6 +57,7 @@ class Engine { virtual bool InputShapeChanged(); + KernelContext *ctx_{nullptr}; int block_idx_; cpp::BlockDesc *block_desc_; std::vector input_names_; diff --git a/lite/kernels/npu/subgraph_compute.cc b/lite/kernels/npu/subgraph_compute.cc index c6cbea46fa..d9b1919506 100644 --- a/lite/kernels/npu/subgraph_compute.cc +++ b/lite/kernels/npu/subgraph_compute.cc @@ -207,7 +207,8 @@ int SubgraphEngine::LaunchDeviceProgram() { void SubgraphCompute::PrepareForRun() { auto& param = this->Param(); - engine_.reset(new SubgraphEngine(param.sub_block_idx, + engine_.reset(new SubgraphEngine(ctx_.get(), + param.sub_block_idx, param.sub_block_desc, param.input_data_names, param.output_data_names, diff --git a/lite/kernels/npu/subgraph_compute.h b/lite/kernels/npu/subgraph_compute.h index dd0bf82bc9..27b4a36cfe 100644 --- a/lite/kernels/npu/subgraph_compute.h +++ b/lite/kernels/npu/subgraph_compute.h @@ -29,13 +29,14 @@ namespace npu { class SubgraphEngine : public subgraph::Engine { public: - SubgraphEngine(int block_idx, + SubgraphEngine(KernelContext *ctx, + int block_idx, cpp::BlockDesc *block_desc, const std::vector &input_names, const std::vector &output_names, Scope *scope) : subgraph::Engine( - block_idx, block_desc, input_names, output_names, scope) {} + ctx, block_idx, block_desc, input_names, output_names, scope) {} protected: int BuildDeviceProgram() override; diff --git a/lite/kernels/xpu/subgraph_compute.cc b/lite/kernels/xpu/subgraph_compute.cc index 0a7a4d2aa5..07a74b0454 100644 --- a/lite/kernels/xpu/subgraph_compute.cc +++ b/lite/kernels/xpu/subgraph_compute.cc @@ -197,7 +197,8 @@ int SubgraphEngine::LaunchDeviceProgram() { void SubgraphCompute::PrepareForRun() { auto& param = this->Param(); - engine_.reset(new SubgraphEngine(param.sub_block_idx, + engine_.reset(new SubgraphEngine(ctx_.get(), + param.sub_block_idx, param.sub_block_desc, param.input_data_names, param.output_data_names, diff --git a/lite/kernels/xpu/subgraph_compute.h b/lite/kernels/xpu/subgraph_compute.h index 2196eb3621..c21a1b7b05 100644 --- a/lite/kernels/xpu/subgraph_compute.h +++ b/lite/kernels/xpu/subgraph_compute.h @@ -29,13 +29,14 @@ namespace xpu { class SubgraphEngine : public subgraph::Engine { public: - SubgraphEngine(int block_idx, + SubgraphEngine(KernelContext *ctx, + int block_idx, cpp::BlockDesc *block_desc, const std::vector &input_names, const std::vector &output_names, Scope *scope) : subgraph::Engine( - block_idx, block_desc, input_names, output_names, scope) {} + ctx, block_idx, block_desc, input_names, output_names, scope) {} protected: int BuildDeviceProgram() override; -- GitLab