// eager_tensor.h
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
// std deps
#include <memory>
#include <string>
#include <utility>

// framework deps
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
// pten deps
#include "paddle/pten/api/include/tensor.h"
#include "paddle/pten/api/lib/api_declare.h"
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/core/compat/convert_utils.h"
/**
 * This class is used by Eager mode for now. It's painful to do this in Eager
 * Mode; the better choice is to use paddle::experimental::Tensor directly.
 * However, we have a bunch of nested kernel code, and it uses
 * paddle::framework::Variable in its inner logic. So we have to provide a
 * Variable in paddle::framework::ExecutionContext to support it. We should
 * remove this class as soon as we finish our latest Pten library, and use
 * paddle::experimental::Tensor instead.
 *
 * Note: Keep this class as clean as possible.
 * This class should only support methods declared in
 * paddle::experimental::Tensor, plus accessors for
 * paddle::framework::Variable; no other members are acceptable.
 * **/

namespace egr {
43
class EagerVariable final {
J
Jiabin Yang 已提交
44
 public:
45 46
  /* Default constructor and name constructor should only be used for contruct
   * output and in fluid*/
47
  EagerVariable() = default;
48

49
  explicit EagerVariable(const std::string& name) : name_(name) {}
J
Jiabin Yang 已提交
50

51
  explicit EagerVariable(const paddle::experimental::Tensor& tensor)
52 53 54
      : name_(tensor.name()) {
    if (tensor.defined()) {
      if (tensor.is_dense_tensor()) {
55 56 57
        ConstructVariableFromTensor(tensor);
      } else if (tensor.is_selected_rows()) {
        ConstructVariableFromSelectedRows(tensor);
J
Jiabin Yang 已提交
58 59
      } else {
        PADDLE_THROW(paddle::platform::errors::Fatal(
60
            "Unrecognized egr::EagerVariable type, only "
61
            "DenseTensor and SelectedRows are supported for now."));
J
Jiabin Yang 已提交
62
      }
63
    } else {
64
      VLOG(6) << "Build Empty EagerVariable with name " << name_;
J
Jiabin Yang 已提交
65 66
    }
  }
67 68 69 70

  /** Part 11: Construct paddle::framework::Variable with pten::Tensor **/
  std::shared_ptr<pten::TensorBase> GetTensorBase() {
    // Construct allocation only once.
71
    if (var_.IsInitialized()) {
72 73 74
      if (var_.IsType<paddle::framework::LoDTensor>() ||
          var_.IsType<paddle::framework::Tensor>()) {
        return SetImplWithLegacyTensor();
75
      } else if (var_.IsType<pten::SelectedRows>()) {
76
        return SetImplWithLegacySelectedRows();
J
Jiabin Yang 已提交
77
      } else {
78 79
        PADDLE_THROW(paddle::platform::errors::Fatal(
            "Unable to fetch underlying tensor "
80
            "from EagerVariable, only LoDTensor and "
81
            "Tensor are supported for now"));
J
Jiabin Yang 已提交
82
      }
83 84
    } else {
      PADDLE_THROW(paddle::platform::errors::Fatal(
85
          "Can not Sync EagerVariable %s whose paddle::framework::Variable is "
86 87
          "not initialized!",
          name()));
J
Jiabin Yang 已提交
88 89
    }
  }
90 91 92
  const paddle::framework::Variable& Var() const { return var_; }

  paddle::framework::Variable* MutableVar() { return &var_; }
J
Jiabin Yang 已提交
93 94 95

  void ResetVar(const paddle::framework::Variable& src) { var_ = src; }

96
  const std::string& name() const { return name_; }
97

98
  void set_name(const std::string& name) { name_ = name; }
99

J
Jiabin Yang 已提交
100
 private:
101
  std::shared_ptr<pten::TensorBase> SetImplWithLegacyTensor() {
102
    const auto& framework_tensor = var_.Get<pten::DenseTensor>();
103
    VLOG(8) << "Sync Var to tensor for: " << name();
104
    return std::make_shared<pten::DenseTensor>(framework_tensor);
J
Jiabin Yang 已提交
105 106
  }

107 108 109 110 111 112
  std::shared_ptr<pten::TensorBase> SetImplWithLegacySelectedRows() {
    auto* framework_tensor = var_.GetMutable<pten::SelectedRows>();
    VLOG(8) << "Sync SelectedRows to tensor for: " << name();
    auto res =
        std::make_shared<pten::SelectedRows>(std::move(*framework_tensor));
    var_.Clear();
113 114
    return res;
  }
115

116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
  void ConstructVariableFromTensor(const paddle::experimental::Tensor& tensor) {
    auto* framework_tensor = var_.GetMutable<pten::DenseTensor>();
    // Contruct framework::Tensor from egr::EagerVariable
    auto tensor_dense =
        std::dynamic_pointer_cast<pten::DenseTensor>(tensor.impl());
    PADDLE_ENFORCE_EQ(
        (tensor_dense.get() && tensor_dense), true,
        paddle::platform::errors::Fatal(
            "Tensor %s does not hold pten::SelectedRows or pten::DenseTensor. "
            "Or it holds empty impl, this should not happend since we should "
            "treat all kinds of tensor as what they are.",
            tensor.name()));
    *framework_tensor = *tensor_dense;
  }

  void ConstructVariableFromSelectedRows(
      const paddle::experimental::Tensor& tensor) {
    auto* framework_tensor = var_.GetMutable<pten::SelectedRows>();
    // Contruct framework::Tensor from egr::EagerVariable
    auto tensor_dense =
        std::dynamic_pointer_cast<pten::SelectedRows>(tensor.impl());
    PADDLE_ENFORCE_EQ(
        (tensor_dense.get() && tensor_dense), true,
        paddle::platform::errors::Fatal(
            "Tensor %s does not hold pten::SelectedRows or pten::DenseTensor. "
            "Or it holds empty impl, this should not happend since we should "
            "treat all kinds of tensor as what they are.",
            tensor.name()));
    *framework_tensor = std::move(*tensor_dense);
  }

147 148
 private:
  std::string name_{""};
J
Jiabin Yang 已提交
149 150 151
  paddle::framework::Variable var_;
};
}  // namespace egr