// kernel_context.h
//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <iterator>
#include <utility>
#include <vector>

#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/tensor_base.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/utils/any.h"
#include "paddle/utils/optional.h"
#include "paddle/utils/small_vector.h"

namespace phi {
29 30

/**
31
 * Note: KernelContext doesn't manage the life of DeviceContext and Tensor
32 33 34 35 36 37 38
 *
 * Note: KernelContext does not couple the concept of framework,
 *       its constructor can only take the members it needs as parameters,
 *       not Scope, RuntimeContext, etc. as parameters
 */
class KernelContext {
 public:
39 40 41 42
  KernelContext() = default;
  explicit KernelContext(DeviceContext* dev_ctx) : dev_ctx_(dev_ctx) {}

  void SetDeviceContext(DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
43 44 45

  template <typename CtxType>
  const CtxType& GetDeviceContext() const {
46
    return static_cast<const CtxType&>(*dev_ctx_);
47 48
  }

49
  void EmplaceBackInput(const TensorBase* input);
50

51
  void EmplaceBackInputWithoutSetRange(const TensorBase* input);
52

53
  void EmplaceBackInputs(paddle::SmallVector<const TensorBase*> inputs);
54

55 56 57
  void EmplaceBackInputsWithoutSetRange(
      paddle::SmallVector<const TensorBase*> inputs);

58
  void EmplaceBackOutput(TensorBase* output);
59

60
  void EmplaceBackOutputWithoutSetRange(TensorBase* output);
61

62
  void EmplaceBackOutputs(paddle::SmallVector<TensorBase*> outputs);
63

64 65 66
  void EmplaceBackOutputsWithoutSetRange(
      paddle::SmallVector<TensorBase*> outputs);

67 68 69 70 71 72
  void EmplaceBackAttr(paddle::any attr);

  const std::pair<int, int>& InputRangeAt(size_t idx) const;

  const std::pair<int, int>& OutputRangeAt(size_t idx) const;

73
  void AssignInputRange(std::pair<int, int>&& range, size_t idx);
74

75
  void AssignOutputRange(std::pair<int, int>&& range, size_t idx);
76 77 78 79 80 81

  template <typename TensorType>
  const TensorType& InputAt(size_t idx) const {
    return static_cast<const TensorType&>(*(inputs_.at(idx)));
  }

82 83 84 85 86 87 88 89
  template <typename TensorType>
  paddle::optional<const TensorType&> OptionalInputAt(size_t idx) const {
    const auto& input = inputs_.at(idx);
    return input ? paddle::optional<const TensorType&>{static_cast<
                       const TensorType&>(*input)}
                 : paddle::optional<const TensorType&>{paddle::none};
  }

90
  template <typename TensorType>
91 92
  std::vector<const TensorType*> InputsBetween(size_t start, size_t end) {
    std::vector<const TensorType*> v;
93
    for (size_t i = start; i < end; ++i) {
94 95
      auto* t = static_cast<const TensorType*>(inputs_.at(i));
      v.emplace_back(t);
96 97 98 99
    }
    return v;
  }

100 101
  template <typename TensorType>
  TensorType* MutableOutputAt(size_t idx) {
102
    return static_cast<TensorType*>(outputs_.at(idx));
103 104
  }

105 106 107 108
  template <typename TensorType>
  std::vector<TensorType*> MutableOutputBetween(size_t start, size_t end) {
    std::vector<TensorType*> v;
    for (size_t i = start; i < end; ++i) {
109
      v.emplace_back(static_cast<TensorType*>(outputs_.at(i)));
110 111 112 113
    }
    return v;
  }

114 115 116 117 118
  template <typename AttrType>
  AttrType AttrAt(size_t idx) const {
    try {
      return paddle::any_cast<AttrType>(attrs_.at(idx));
    } catch (paddle::bad_any_cast&) {
119
      PADDLE_THROW(phi::errors::InvalidArgument(
120 121 122 123
          "Attribute cast error in Op Kernel Context."));
    }
  }

124 125 126 127
  size_t InputsSize() const { return inputs_.size(); }
  size_t OutputsSize() const { return outputs_.size(); }
  size_t AttrsSize() const { return attrs_.size(); }

128
 private:
129
  DeviceContext* dev_ctx_;
130

131 132
  paddle::SmallVector<const TensorBase*> inputs_;
  paddle::SmallVector<TensorBase*> outputs_;
133 134 135 136 137 138
  paddle::SmallVector<paddle::any> attrs_;

  paddle::SmallVector<std::pair<int, int>> input_range_;
  paddle::SmallVector<std::pair<int, int>> output_range_;
};

139
}  // namespace phi