未验证 提交 116c5118 编写于 作者: Y Yan Chunwei 提交者: GitHub

add PrepareForRun to kernel (#17879)

上级 8c00498e
...@@ -41,8 +41,24 @@ class KernelBase { ...@@ -41,8 +41,24 @@ class KernelBase {
const std::map<std::string, const Type*>& input_types, const std::map<std::string, const Type*>& input_types,
const std::string& out_arg)>; const std::string& out_arg)>;
protected:
/// Run some initialization before `Run`, it will invoke after `SetParam` and
/// `SetContext`, that is both the param_ and context_ are valid.
virtual void PrepareForRun() {}
/// Run the kernel. Before Run, both the param_ and context_ should be valid.
virtual void Run() = 0; virtual void Run() = 0;
public:
void Launch() {
if (is_first_epoch_) {
PrepareForRun();
is_first_epoch_ = false;
}
Run();
}
void SetContext(std::unique_ptr<KernelContext>&& ctx) { void SetContext(std::unique_ptr<KernelContext>&& ctx) {
ctx_ = std::move(ctx); ctx_ = std::move(ctx);
} }
...@@ -141,6 +157,7 @@ class KernelBase { ...@@ -141,6 +157,7 @@ class KernelBase {
// The extra identity to help defficiate a specific kernel, op_type_ + alias_ // The extra identity to help defficiate a specific kernel, op_type_ + alias_
// is the unique ID for the kernel. // is the unique ID for the kernel.
std::string alias_{}; std::string alias_{};
bool is_first_epoch_{true};
}; };
// Light-weight kernel implementation. // Light-weight kernel implementation.
......
...@@ -62,7 +62,7 @@ bool OpLite::Run() { ...@@ -62,7 +62,7 @@ bool OpLite::Run() {
CHECK(kernel_); CHECK(kernel_);
SyncInputEvents(); SyncInputEvents();
kernel_->Run(); kernel_->Launch();
RecordOutputEvents(); RecordOutputEvents();
return true; return true;
......
...@@ -129,7 +129,7 @@ struct Instruct { ...@@ -129,7 +129,7 @@ struct Instruct {
CHECK(op_->CheckShape()); CHECK(op_->CheckShape());
} }
op_->InferShape(); op_->InferShape();
kernel_->Run(); kernel_->Launch();
} }
friend std::ostream& operator<<(std::ostream& os, const Instruct& other) { friend std::ostream& operator<<(std::ostream& os, const Instruct& other) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册