// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "paddle/phi/core/device_context.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/utils/any.h" #include "paddle/utils/optional.h" #include "paddle/utils/small_vector.h" namespace phi { /** * Note: KernelContext doesn't manage the life of DeviceContext and Tensor * * Note: KernelContext does not couple the concept of framework, * its constructor can only take the members it needs as parameters, * not Scope, RuntimeContext, etc. as parameters */ class KernelContext { public: KernelContext() = default; explicit KernelContext(DeviceContext* dev_ctx) : dev_ctx_(dev_ctx) {} void SetDeviceContext(DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } template const CtxType& GetDeviceContext() const { return static_cast(*dev_ctx_); } void EmplaceBackInput(const TensorBase* input); void EmplaceBackInputWithoutSetRange(const TensorBase* input); void EmplaceBackInputs(paddle::SmallVector inputs); void EmplaceBackInputsWithoutSetRange( paddle::SmallVector inputs); void EmplaceBackOutput(TensorBase* output); void EmplaceBackOutputWithoutSetRange(TensorBase* output); void EmplaceBackOutputs(paddle::SmallVector outputs); void EmplaceBackOutputsWithoutSetRange( paddle::SmallVector outputs); void EmplaceBackAttr(paddle::any attr); const std::pair& InputRangeAt(size_t idx) const; const std::pair& OutputRangeAt(size_t idx) const; void AssignInputRange(std::pair&& range, size_t idx); void AssignOutputRange(std::pair&& range, size_t idx); template const TensorType& InputAt(size_t idx) const { return static_cast(*(inputs_.at(idx))); } template paddle::optional OptionalInputAt(size_t idx) const { const auto& input = inputs_.at(idx); return input ? paddle::optional{static_cast< const TensorType&>(*input)} : paddle::optional{paddle::none}; } template std::vector InputsBetween(size_t start, size_t end) { std::vector v; for (size_t i = start; i < end; ++i) { auto* t = static_cast(inputs_.at(i)); v.emplace_back(t); } return v; } template paddle::optional> OptionalInputsBetween( size_t start, size_t end) { const auto& first = inputs_.at(start); if (first) { std::vector v; for (size_t i = start; i < end; ++i) { auto* t = static_cast(inputs_.at(i)); v.emplace_back(t); } return paddle::optional>(v); } return paddle::optional>(paddle::none); } template TensorType* MutableOutputAt(size_t idx) { return static_cast(outputs_.at(idx)); } template std::vector MutableOutputBetween(size_t start, size_t end) { std::vector v; for (size_t i = start; i < end; ++i) { v.emplace_back(static_cast(outputs_.at(i))); } return v; } template AttrType AttrAt(size_t idx) const { try { return paddle::any_cast(attrs_.at(idx)); } catch (paddle::bad_any_cast&) { PADDLE_THROW(phi::errors::InvalidArgument( "Attribute cast error in Op Kernel Context.")); } } size_t InputsSize() const { return inputs_.size(); } size_t OutputsSize() const { return outputs_.size(); } size_t AttrsSize() const { return attrs_.size(); } private: DeviceContext* dev_ctx_; paddle::SmallVector inputs_; paddle::SmallVector outputs_; paddle::SmallVector attrs_; paddle::SmallVector> input_range_; paddle::SmallVector> output_range_; }; } // namespace phi