From 22fae321ad9cf91eb103a447c7a7dbd72ffd1582 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Thu, 6 Jun 2019 13:28:56 +0800 Subject: [PATCH] add PrepareForRun to kernel (#17879) --- paddle/fluid/lite/core/kernel.h | 17 +++++++++++++++++ paddle/fluid/lite/core/op_lite.cc | 2 +- paddle/fluid/lite/core/program.h | 2 +- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h index 2eee83bd4..629da86bb 100644 --- a/paddle/fluid/lite/core/kernel.h +++ b/paddle/fluid/lite/core/kernel.h @@ -41,8 +41,24 @@ class KernelBase { const std::map& input_types, const std::string& out_arg)>; + protected: + /// Run some initialization before `Run`, it will invoke after `SetParam` and + /// `SetContext`, that is both the param_ and context_ are valid. + virtual void PrepareForRun() {} + + /// Run the kernel. Before Run, both the param_ and context_ should be valid. virtual void Run() = 0; + public: + void Launch() { + if (is_first_epoch_) { + PrepareForRun(); + is_first_epoch_ = false; + } + + Run(); + } + void SetContext(std::unique_ptr&& ctx) { ctx_ = std::move(ctx); } @@ -141,6 +157,7 @@ class KernelBase { // The extra identity to help defficiate a specific kernel, op_type_ + alias_ // is the unique ID for the kernel. std::string alias_{}; + bool is_first_epoch_{true}; }; // Light-weight kernel implementation. diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc index bd98b23bf..dc22e4fb4 100644 --- a/paddle/fluid/lite/core/op_lite.cc +++ b/paddle/fluid/lite/core/op_lite.cc @@ -62,7 +62,7 @@ bool OpLite::Run() { CHECK(kernel_); SyncInputEvents(); - kernel_->Run(); + kernel_->Launch(); RecordOutputEvents(); return true; diff --git a/paddle/fluid/lite/core/program.h b/paddle/fluid/lite/core/program.h index 1ebd6b437..234626a91 100644 --- a/paddle/fluid/lite/core/program.h +++ b/paddle/fluid/lite/core/program.h @@ -129,7 +129,7 @@ struct Instruct { CHECK(op_->CheckShape()); } op_->InferShape(); - kernel_->Run(); + kernel_->Launch(); } friend std::ostream& operator<<(std::ostream& os, const Instruct& other) { -- GitLab