// kernel_context.h
//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <iterator>
#include <memory>
#include <utility>
#include <vector>

#include "paddle/pten/core/compat_utils.h"
#include "paddle/pten/core/tensor_base.h"
#include "paddle/utils/any.h"
#include "paddle/utils/small_vector.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"

namespace pten {

using DeviceContext = paddle::platform::DeviceContext;
using DataType = paddle::experimental::DataType;
using DataLayout = paddle::experimental::DataLayout;

/**
 * Note: KernelContext doesn't manage the life if DeviceContext and Tensor
 *
 * Note: KernelContext does not couple the concept of framework,
 *       its constructor can only take the members it needs as parameters,
 *       not Scope, RuntimeContext, etc. as parameters
 */
class KernelContext {
 public:
44 45 46 47
  KernelContext() = default;
  explicit KernelContext(DeviceContext* dev_ctx) : dev_ctx_(dev_ctx) {}

  void SetDeviceContext(DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
48 49 50

  template <typename CtxType>
  const CtxType& GetDeviceContext() const {
51
    return static_cast<const CtxType&>(*dev_ctx_);
52 53 54
  }

  void EmplaceBackInput(std::shared_ptr<TensorBase> input) {
55
    int index = inputs_.size();
56 57 58 59 60
    inputs_.emplace_back(std::move(input));
    // Record the start and end index of the input
    input_range_.emplace_back(std::pair<int, int>(index, index + 1));
  }

61 62 63 64
  void EmplaceBackInputWithoutSetRange(std::shared_ptr<TensorBase> input) {
    inputs_.emplace_back(std::move(input));
  }

65
  void EmplaceBackInputs(
66
      paddle::SmallVector<std::shared_ptr<TensorBase>> inputs) {
67
    int index = inputs_.size();
68 69 70
    // Record the start and end index of the input
    input_range_.emplace_back(
        std::pair<int, int>(index, index + inputs.size()));
71 72 73
    inputs_.insert(inputs_.end(),
                   std::make_move_iterator(inputs.begin()),
                   std::make_move_iterator(inputs.end()));
74 75 76
  }

  void EmplaceBackOutput(std::shared_ptr<TensorBase> output) {
77
    int index = outputs_.size();
78 79 80 81 82
    outputs_.emplace_back(std::move(output));
    // Record the start and end index of the input
    output_range_.emplace_back(std::pair<int, int>(index, index + 1));
  }

83 84 85 86
  void EmplaceBackOutputWithoutSetRange(std::shared_ptr<TensorBase> output) {
    outputs_.emplace_back(std::move(output));
  }

87
  void EmplaceBackOutputs(
88
      paddle::SmallVector<std::shared_ptr<TensorBase>> outputs) {
89
    int index = outputs_.size();
90 91 92
    // Record the start and end index of the input
    output_range_.emplace_back(
        std::pair<int, int>(index, index + outputs.size()));
93 94 95
    outputs_.insert(outputs_.end(),
                    std::make_move_iterator(outputs.begin()),
                    std::make_move_iterator(outputs.end()));
96 97 98 99 100 101 102 103 104 105 106
  }

  void EmplaceBackAttr(paddle::any attr) {
    attrs_.emplace_back(std::move(attr));
  }

  template <typename TensorType>
  const TensorType& InputAt(size_t idx) const {
    return static_cast<const TensorType&>(*(inputs_.at(idx)));
  }

107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
  template <typename TensorType>
  std::vector<TensorType> InputBetween(size_t start, size_t end) const {
    std::vector<TensorType> v;
    for (size_t i = start; i < end; ++i) {
      auto t = std::dynamic_pointer_cast<TensorType>(inputs_.at(i));
      v.emplace_back(std::move(*t.get()));
    }

    return v;
  }

  const std::pair<int, int>& InputRangeAt(size_t idx) const {
    return input_range_.at(idx);
  }

  const std::pair<int, int>& OutputRangeAt(size_t idx) const {
    return output_range_.at(idx);
  }

126 127 128 129 130 131 132 133 134 135 136 137 138
  std::pair<int, int>& MutableInputRangeAt(size_t idx) {
    return input_range_[idx];
  }

  std::pair<int, int>& MutableOutputRangeAt(size_t idx) {
    return output_range_[idx];
  }

  template <typename TensorType>
  TensorType* MutableInputAt(size_t idx) {
    return static_cast<TensorType*>(inputs_.at(idx).get());
  }

139 140 141 142 143
  template <typename TensorType>
  TensorType* MutableOutputAt(size_t idx) {
    return static_cast<TensorType*>(outputs_.at(idx).get());
  }

144 145 146 147 148 149 150 151 152 153
  template <typename TensorType>
  std::vector<TensorType*> MutableOutputBetween(size_t start, size_t end) {
    std::vector<TensorType*> v;
    for (size_t i = start; i < end; ++i) {
      v.emplace_back(static_cast<TensorType*>(outputs_.at(i).get()));
    }

    return v;
  }

154 155 156 157 158 159 160 161 162 163
  template <typename AttrType>
  AttrType AttrAt(size_t idx) const {
    try {
      return paddle::any_cast<AttrType>(attrs_.at(idx));
    } catch (paddle::bad_any_cast&) {
      PADDLE_THROW(paddle::platform::errors::InvalidArgument(
          "Attribute cast error in Op Kernel Context."));
    }
  }

164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
  // Temporary method: For compatible with fluid Tensor and improve performance
  // Only deal with DenseTensor now
  void ClearData() {
    for (auto& in : inputs_) {
      CompatibleDenseTensorUtils::ClearStorage(
          static_cast<DenseTensor*>(in.get()));
    }
    for (auto& out : outputs_) {
      CompatibleDenseTensorUtils::ClearStorage(
          static_cast<DenseTensor*>(out.get()));
    }
    attrs_.clear();
  }

  size_t InputsSize() const { return inputs_.size(); }
  size_t OutputsSize() const { return outputs_.size(); }
  size_t AttrsSize() const { return attrs_.size(); }

182 183
 private:
  // DeviceContext base class
184
  DeviceContext* dev_ctx_;
185 186 187 188 189 190 191 192 193 194 195 196 197

  // TODO(chenweihang): Tensor -> Tensor*, Tensor should by managed `scope`
  // Note: can't use API Tensor here, the inference don't use this API Tensor
  paddle::SmallVector<std::shared_ptr<TensorBase>> inputs_;
  paddle::SmallVector<std::shared_ptr<TensorBase>> outputs_;
  paddle::SmallVector<paddle::any> attrs_;

  // Only contains input like list[Tensor] need `range`
  paddle::SmallVector<std::pair<int, int>> input_range_;
  paddle::SmallVector<std::pair<int, int>> output_range_;
};

}  // namespace pten