eager_tensor.h 5.1 KB
Newer Older
J
Jiabin Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
// framework deps
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
// pten deps
J
Jiabin Yang 已提交
21
#include "paddle/pten/api/include/tensor.h"
22
#include "paddle/pten/api/lib/api_declare.h"
J
Jiabin Yang 已提交
23
#include "paddle/pten/api/lib/utils/tensor_utils.h"
24
#include "paddle/pten/core/compat/convert_utils.h"
J
Jiabin Yang 已提交
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
/**
 * This class is used by Eager mode for now. Doing this in Eager mode is
 * painful; the better choice would be to use paddle::experimental::Tensor
 * directly. However, we have a bunch of nested kernel code that uses
 * paddle::framework::Variable in its inner logic, so we have to provide a
 * Variable in paddle::framework::ExecutionContext to support it. We should
 * remove this class as soon as we finish the latest Pten library, and use
 * paddle::experimental::Tensor instead.
 *
 * Note: Keep this class as clean as possible.
 * This class should only support methods declared in
 * paddle::experimental::Tensor, plus accessors for the underlying
 * paddle::framework::Variable; no other members are acceptable.
 * **/

namespace egr {
class EagerTensor final {
 public:
45 46 47
  /* Default constructor and name constructor should only be used for contruct
   * output and in fluid*/
  EagerTensor() = default;
48

49
  explicit EagerTensor(const std::string& name) : name_(name) {}
J
Jiabin Yang 已提交
50

51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
  explicit EagerTensor(const paddle::experimental::Tensor& tensor)
      : name_(tensor.name()) {
    if (tensor.defined()) {
      if (tensor.is_dense_tensor()) {
        auto* framework_tensor =
            var_.GetMutable<paddle::framework::LoDTensor>();
        // Contruct framework::Tensor from egr::EagerTensor
        auto tensor_dense =
            std::dynamic_pointer_cast<pten::DenseTensor>(tensor.impl());
        PADDLE_ENFORCE_EQ((tensor_dense.get() && tensor_dense), true,
                          paddle::platform::errors::Fatal(
                              "Failed to Trans Tensor to EagerVariable since "
                              "we got Tensor with type DenseTensor, and we got "
                              "EagerVariable with another type."));
        *framework_tensor = *tensor_dense;
J
Jiabin Yang 已提交
66 67
      } else {
        PADDLE_THROW(paddle::platform::errors::Fatal(
68 69
            "Unrecognized egr::EagerVariable type, only "
            "DenseTensor and SelectedRows is supported for now."));
J
Jiabin Yang 已提交
70
      }
71 72
    } else {
      VLOG(6) << "Build Empty EagerTensor with name " << name_;
J
Jiabin Yang 已提交
73 74
    }
  }
75 76 77 78

  /** Part 11: Construct paddle::framework::Variable with pten::Tensor **/
  std::shared_ptr<pten::TensorBase> GetTensorBase() {
    // Construct allocation only once.
79 80
    if (var_.IsInitialized()) {
      if (var_.IsType<paddle::framework::LoDTensor>()) {
81
        return SetImplWithLegacyTensor<pten::DenseTensor>();
82
      } else if (var_.IsType<paddle::framework::Tensor>()) {
83 84 85
        return SetImplWithLegacyTensor<pten::DenseTensor>();
      } else if (var_.IsType<pten::SelectedRows>()) {
        return SetImplWithSelectedRows();
J
Jiabin Yang 已提交
86
      } else {
87 88 89 90
        PADDLE_THROW(paddle::platform::errors::Fatal(
            "Unable to fetch underlying tensor "
            "from EagerTensor, only LoDTensor and "
            "Tensor are supported for now"));
J
Jiabin Yang 已提交
91
      }
92 93 94 95 96
    } else {
      PADDLE_THROW(paddle::platform::errors::Fatal(
          "Can not Sync EagerTensor %s whose paddle::framework::Variable is "
          "not initialized!",
          name()));
J
Jiabin Yang 已提交
97 98
    }
  }
99 100 101
  const paddle::framework::Variable& Var() const { return var_; }

  paddle::framework::Variable* MutableVar() { return &var_; }
J
Jiabin Yang 已提交
102 103 104

  void ResetVar(const paddle::framework::Variable& src) { var_ = src; }

105
  const std::string& name() const { return name_; }
106

107
  void set_name(const std::string& name) { name_ = name; }
108

J
Jiabin Yang 已提交
109
 private:
110 111
  template <typename LEGACY_TYPE>
  std::shared_ptr<pten::TensorBase> SetImplWithLegacyTensor() {
J
Jiabin Yang 已提交
112
    const auto& framework_tensor = var_.Get<LEGACY_TYPE>();
113 114
    VLOG(8) << "Sync Var to tensor for: " << name();
    return std::make_shared<LEGACY_TYPE>(std::move(framework_tensor));
J
Jiabin Yang 已提交
115 116
  }

117 118 119 120 121 122 123 124 125
  std::shared_ptr<pten::TensorBase> SetImplWithSelectedRows() {
    auto* selected_rows = var_.GetMutable<pten::SelectedRows>();
    auto res = std::make_shared<pten::SelectedRows>(selected_rows->rows_,
                                                    selected_rows->height_);
    res->value_.reset(selected_rows->value_.release());
    res->id_to_index_ = std::move(selected_rows->id_to_index_);
    res->rwlock_.reset(selected_rows->rwlock_.release());
    return res;
  }
126

127 128
 private:
  std::string name_{""};
J
Jiabin Yang 已提交
129 130 131
  paddle::framework::Variable var_;
};
}  // namespace egr