未验证 提交 1e10b471 编写于 作者: H hong19860320 提交者: GitHub

[LITE][NPU][XPU] Add kernel context to NPU/XPU subgraph engine (#2686)

上级 ad1dfbf2
......@@ -28,12 +28,14 @@ namespace subgraph {
class Engine {
public:
Engine(int block_idx,
Engine(KernelContext *ctx,
int block_idx,
cpp::BlockDesc *block_desc,
const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names,
lite::Scope *scope)
: block_idx_(block_idx),
: ctx_(ctx),
block_idx_(block_idx),
block_desc_(block_desc),
input_names_(input_names),
output_names_(output_names),
......@@ -55,6 +57,7 @@ class Engine {
virtual bool InputShapeChanged();
KernelContext *ctx_{nullptr};
int block_idx_;
cpp::BlockDesc *block_desc_;
std::vector<std::string> input_names_;
......
......@@ -207,7 +207,8 @@ int SubgraphEngine::LaunchDeviceProgram() {
void SubgraphCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
engine_.reset(new SubgraphEngine(param.sub_block_idx,
engine_.reset(new SubgraphEngine(ctx_.get(),
param.sub_block_idx,
param.sub_block_desc,
param.input_data_names,
param.output_data_names,
......
......@@ -29,13 +29,14 @@ namespace npu {
class SubgraphEngine : public subgraph::Engine {
public:
SubgraphEngine(int block_idx,
SubgraphEngine(KernelContext *ctx,
int block_idx,
cpp::BlockDesc *block_desc,
const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names,
Scope *scope)
: subgraph::Engine(
block_idx, block_desc, input_names, output_names, scope) {}
ctx, block_idx, block_desc, input_names, output_names, scope) {}
protected:
int BuildDeviceProgram() override;
......
......@@ -197,7 +197,8 @@ int SubgraphEngine::LaunchDeviceProgram() {
void SubgraphCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
engine_.reset(new SubgraphEngine(param.sub_block_idx,
engine_.reset(new SubgraphEngine(ctx_.get(),
param.sub_block_idx,
param.sub_block_desc,
param.input_data_names,
param.output_data_names,
......
......@@ -29,13 +29,14 @@ namespace xpu {
class SubgraphEngine : public subgraph::Engine {
public:
SubgraphEngine(int block_idx,
SubgraphEngine(KernelContext *ctx,
int block_idx,
cpp::BlockDesc *block_desc,
const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names,
Scope *scope)
: subgraph::Engine(
block_idx, block_desc, input_names, output_names, scope) {}
ctx, block_idx, block_desc, input_names, output_names, scope) {}
protected:
int BuildDeviceProgram() override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册