// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/lite/core/context.h" #include "paddle/fluid/lite/core/target_wrapper.h" #include "paddle/fluid/lite/core/types.h" #include "paddle/fluid/lite/operators/op_params.h" #include "paddle/fluid/lite/utils/all.h" namespace paddle { namespace lite { // An base with virtual functions to unify all the kernel implementation on // different targets. class KernelBase { public: virtual void Run() = 0; template void SetContext(std::unique_ptr>&& ctx) { context_.set>>(std::move(ctx)); } template void SetParam(T param) { param_.set(param); } template Param& param() const { return param_.get(); } void Torch() {} virtual TargetType target() const = 0; virtual PrecisionType precision() const = 0; virtual ~KernelBase() = default; protected: core::any_context_t context_; mutable operators::param_t param_; }; // Light-weight kernel implementation. // The OpKernel is designed to implement the specific algorithm on a target // device. template class OpKernel : public KernelBase { public: // Set runtime context. void SetContext(std::unique_ptr&& ctx) { ctx_ = ctx; } // Run the kernel. virtual void Run() { CHECK(false) << "Not Implemented"; } TargetType target() const override { return Target; } PrecisionType precision() const override { return Precision; } OpKernel() = default; virtual ~OpKernel() = default; protected: std::unique_ptr ctx_; }; } // namespace lite } // namespace paddle