// kernel_context.h
//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <iterator>
#include <utility>
#include <vector>

#include "paddle/phi/core/attribute.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/tensor_base.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/utils/optional.h"
#include "paddle/utils/small_vector.h"

namespace phi {
30 31

/**
32
 * Note: KernelContext doesn't manage the life of DeviceContext and Tensor
33 34 35 36 37 38 39
 *
 * Note: KernelContext does not couple the concept of framework,
 *       its constructor can only take the members it needs as parameters,
 *       not Scope, RuntimeContext, etc. as parameters
 */
class KernelContext {
 public:
40 41 42 43
  KernelContext() = default;
  explicit KernelContext(DeviceContext* dev_ctx) : dev_ctx_(dev_ctx) {}

  void SetDeviceContext(DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
44 45 46

  template <typename CtxType>
  const CtxType& GetDeviceContext() const {
47
    return static_cast<const CtxType&>(*dev_ctx_);
48 49
  }

50
  void EmplaceBackInput(const TensorBase* input);
51

52
  void EmplaceBackInputWithoutSetRange(const TensorBase* input);
53

C
Chen Weihang 已提交
54
  void EmplaceBackInputs(paddle::small_vector<const TensorBase*> inputs);
55

56
  void EmplaceBackInputsWithoutSetRange(
C
Chen Weihang 已提交
57
      paddle::small_vector<const TensorBase*> inputs);
58

59
  void EmplaceBackOutput(TensorBase* output);
60

61
  void EmplaceBackOutputWithoutSetRange(TensorBase* output);
62

C
Chen Weihang 已提交
63
  void EmplaceBackOutputs(paddle::small_vector<TensorBase*> outputs);
64

65
  void EmplaceBackOutputsWithoutSetRange(
C
Chen Weihang 已提交
66
      paddle::small_vector<TensorBase*> outputs);
67

68
  void EmplaceBackAttr(Attribute attr);
69 70 71 72 73

  const std::pair<int, int>& InputRangeAt(size_t idx) const;

  const std::pair<int, int>& OutputRangeAt(size_t idx) const;

74
  void AssignInputRange(std::pair<int, int>&& range, size_t idx);
75

76
  void AssignOutputRange(std::pair<int, int>&& range, size_t idx);
77 78 79 80 81 82

  template <typename TensorType>
  const TensorType& InputAt(size_t idx) const {
    return static_cast<const TensorType&>(*(inputs_.at(idx)));
  }

83
  template <typename TensorType>
84 85 86 87 88
  paddle::optional<TensorType> OptionalInputAt(size_t idx) const {
    const auto* input = inputs_.at(idx);
    return input ? paddle::make_optional<TensorType>(
                       *(static_cast<const TensorType*>(input)))
                 : paddle::none;
89 90
  }

91
  template <typename TensorType>
92 93
  std::vector<const TensorType*> InputsBetween(size_t start, size_t end) {
    std::vector<const TensorType*> v;
94
    for (size_t i = start; i < end; ++i) {
95 96
      auto* t = static_cast<const TensorType*>(inputs_.at(i));
      v.emplace_back(t);
97 98 99 100
    }
    return v;
  }

101
  template <typename TensorType>
102
  paddle::optional<std::vector<const TensorType*>> OptionalInputsBetween(
103 104 105 106 107 108 109 110 111
      size_t start, size_t end) {
    const auto& first = inputs_.at(start);

    if (first) {
      std::vector<const TensorType*> v;
      for (size_t i = start; i < end; ++i) {
        auto* t = static_cast<const TensorType*>(inputs_.at(i));
        v.emplace_back(t);
      }
112
      return paddle::optional<std::vector<const TensorType*>>(std::move(v));
113
    }
114
    return paddle::none;
115 116
  }

117 118
  template <typename TensorType>
  TensorType* MutableOutputAt(size_t idx) {
119
    return static_cast<TensorType*>(outputs_.at(idx));
120 121
  }

122 123
  TensorBase* MutableOutputAt(size_t idx) { return outputs_.at(idx); }

124 125 126
  template <typename TensorType>
  std::vector<TensorType*> MutableOutputBetween(size_t start, size_t end) {
    std::vector<TensorType*> v;
127
    bool is_empty_vector = true;
128
    for (size_t i = start; i < end; ++i) {
129
      v.emplace_back(static_cast<TensorType*>(outputs_.at(i)));
130 131 132 133 134 135
      if (outputs_.at(i) != nullptr) {
        is_empty_vector = false;
      }
    }
    if (is_empty_vector) {
      v.clear();
136 137 138 139
    }
    return v;
  }

140
  template <typename AttrType>
141
  const AttrType& AttrAt(size_t idx) const;
142

143 144 145 146
  size_t InputsSize() const { return inputs_.size(); }
  size_t OutputsSize() const { return outputs_.size(); }
  size_t AttrsSize() const { return attrs_.size(); }

147 148 149 150 151 152 153
  void ClearInputOutput() {
    inputs_.clear();
    input_range_.clear();
    outputs_.clear();
    output_range_.clear();
  }

154
 private:
155
  DeviceContext* dev_ctx_;
156

C
Chen Weihang 已提交
157 158 159
  paddle::small_vector<const TensorBase*> inputs_;
  paddle::small_vector<TensorBase*> outputs_;
  paddle::small_vector<Attribute, kAttrSmallVectorSize> attrs_;
160

C
Chen Weihang 已提交
161 162
  paddle::small_vector<std::pair<int, int>, kInputSmallVectorSize> input_range_;
  paddle::small_vector<std::pair<int, int>, kOutputSmallVectorSize>
163
      output_range_;
164 165
};

166
}  // namespace phi