eager_tensor.h 4.8 KB
Newer Older
J
Jiabin Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
// framework deps
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
// pten deps
21 22 23 24
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/api_declare.h"
#include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/core/compat/convert_utils.h"
J
Jiabin Yang 已提交
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
/**
 * This class is used by Eager mode for now. It's painful to do this in Eager
 * mode; the better choice is to use paddle::experimental::Tensor directly.
 * However, we have a bunch of nested kernel code that uses
 * paddle::framework::Variable in its inner logic, so we have to provide a
 * variable in paddle::framework::ExecutionContext to support it. We should
 * remove this as soon as we finish our latest Pten Lib, and use
 * paddle::experimental::Tensor instead.
 *
 * Note: Keep this class as clean as possible.
 * This class should only support methods declared in
 * paddle::experimental::Tensor plus the access methods of
 * paddle::framework::Variable; no more members are acceptable.
 **/

namespace egr {
43
class EagerVariable final {
J
Jiabin Yang 已提交
44
 public:
45 46
  /* Default constructor and name constructor should only be used for contruct
   * output and in fluid*/
47
  EagerVariable() = default;
48

49
  explicit EagerVariable(const std::string& name) : name_(name) {}
J
Jiabin Yang 已提交
50

51
  explicit EagerVariable(const paddle::experimental::Tensor& tensor)
52 53 54
      : name_(tensor.name()) {
    if (tensor.defined()) {
      if (tensor.is_dense_tensor()) {
55
        ConstructVariableFromTensor<phi::DenseTensor>(tensor);
56
      } else if (tensor.is_selected_rows()) {
57
        ConstructVariableFromTensor<phi::SelectedRows>(tensor);
J
Jiabin Yang 已提交
58 59
      } else {
        PADDLE_THROW(paddle::platform::errors::Fatal(
60
            "Unrecognized egr::EagerVariable type, only "
61
            "DenseTensor and SelectedRows are supported for now."));
J
Jiabin Yang 已提交
62
      }
63
    } else {
64
      VLOG(6) << "Build Empty EagerVariable with name " << name_;
J
Jiabin Yang 已提交
65 66
    }
  }
67

68 69
  /** Part 11: Construct paddle::framework::Variable with phi::Tensor **/
  std::shared_ptr<phi::TensorBase> GetTensorBase() {
70
    // Construct allocation only once.
71
    if (var_.IsInitialized()) {
72 73
      if (var_.IsType<paddle::framework::LoDTensor>() ||
          var_.IsType<paddle::framework::Tensor>()) {
74 75 76
        return SetImplWithLegacyTensor<phi::DenseTensor>();
      } else if (var_.IsType<phi::SelectedRows>()) {
        return SetImplWithLegacyTensor<phi::SelectedRows>();
J
Jiabin Yang 已提交
77
      } else {
78 79
        PADDLE_THROW(paddle::platform::errors::Fatal(
            "Unable to fetch underlying tensor "
80
            "from EagerVariable, only LoDTensor and "
81
            "Tensor are supported for now"));
J
Jiabin Yang 已提交
82
      }
83 84
    } else {
      PADDLE_THROW(paddle::platform::errors::Fatal(
85
          "Can not Sync EagerVariable %s whose paddle::framework::Variable is "
86 87
          "not initialized!",
          name()));
J
Jiabin Yang 已提交
88 89
    }
  }
90 91 92
  const paddle::framework::Variable& Var() const { return var_; }

  paddle::framework::Variable* MutableVar() { return &var_; }
J
Jiabin Yang 已提交
93 94 95

  void ResetVar(const paddle::framework::Variable& src) { var_ = src; }

96
  const std::string& name() const { return name_; }
97

98
  void set_name(const std::string& name) { name_ = name; }
99

J
Jiabin Yang 已提交
100
 private:
J
Jiabin Yang 已提交
101
  template <typename VarType>
102
  std::shared_ptr<phi::TensorBase> SetImplWithLegacyTensor() {
J
Jiabin Yang 已提交
103
    const auto& framework_tensor = var_.Get<VarType>();
104
    VLOG(8) << "Sync Var to tensor for: " << name();
J
Jiabin Yang 已提交
105
    return std::make_shared<VarType>(framework_tensor);
106
  }
107

J
Jiabin Yang 已提交
108
  template <typename VarType>
109
  void ConstructVariableFromTensor(const paddle::experimental::Tensor& tensor) {
J
Jiabin Yang 已提交
110
    auto* framework_tensor = var_.GetMutable<VarType>();
111
    // Contruct framework::Tensor from egr::EagerVariable
J
Jiabin Yang 已提交
112
    auto tensor_dense = std::dynamic_pointer_cast<VarType>(tensor.impl());
113 114 115
    PADDLE_ENFORCE_EQ(
        (tensor_dense.get() && tensor_dense), true,
        paddle::platform::errors::Fatal(
116
            "Tensor %s does not hold phi::SelectedRows or phi::DenseTensor. "
117 118 119 120 121 122
            "Or it holds empty impl, this should not happend since we should "
            "treat all kinds of tensor as what they are.",
            tensor.name()));
    *framework_tensor = *tensor_dense;
  }

123 124
 private:
  std::string name_{""};
J
Jiabin Yang 已提交
125 126 127
  paddle::framework::Variable var_;
};
}  // namespace egr