// eager_tensor.h (last committed by Jiabin Yang)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <memory>
#include <string>

// framework deps
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
// Phi deps
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/core/compat/convert_utils.h"
/**
 * This class is used by Eager mode for now. It's painful to do this in Eager
 * Mode; the better choice is to use paddle::experimental::Tensor directly.
 * However, we have a bunch of nested kernel code, and it uses
 * paddle::framework::Variable in its inner logic. So, we have to provide a
 * Variable in paddle::framework::ExecutionContext to support it. We should
 * remove this as soon as we finish our latest Phi Lib, and use
 * paddle::experimental::Tensor instead.
 *
 * Note: Keep this class as clean as possible.
 * This class should only support methods declared in
 * paddle::experimental::Tensor, plus access methods for
 * paddle::framework::Variable; no more members are acceptable.
 * **/

namespace egr {
42
class EagerVariable final {
J
Jiabin Yang 已提交
43
 public:
44 45
  /* Default constructor and name constructor should only be used for contruct
   * output and in fluid*/
46
  EagerVariable() = default;
47

48
  explicit EagerVariable(const std::string& name) : name_(name) {}
J
Jiabin Yang 已提交
49

50
  explicit EagerVariable(const paddle::experimental::Tensor& tensor)
51 52 53
      : name_(tensor.name()) {
    if (tensor.defined()) {
      if (tensor.is_dense_tensor()) {
54
        ConstructVariableFromTensor<phi::DenseTensor>(tensor);
55
      } else if (tensor.is_selected_rows()) {
56
        ConstructVariableFromTensor<phi::SelectedRows>(tensor);
J
Jiabin Yang 已提交
57 58
      } else {
        PADDLE_THROW(paddle::platform::errors::Fatal(
59
            "Unrecognized egr::EagerVariable type, only "
60
            "DenseTensor and SelectedRows are supported for now."));
J
Jiabin Yang 已提交
61
      }
62
    } else {
63
      VLOG(6) << "Build Empty EagerVariable with name " << name_;
J
Jiabin Yang 已提交
64 65
    }
  }
66

67 68
  /** Part 11: Construct paddle::framework::Variable with phi::Tensor **/
  std::shared_ptr<phi::TensorBase> GetTensorBase() {
69
    // Construct allocation only once.
70
    if (var_.IsInitialized()) {
71 72
      if (var_.IsType<paddle::framework::LoDTensor>() ||
          var_.IsType<paddle::framework::Tensor>()) {
73 74 75
        return SetImplWithLegacyTensor<phi::DenseTensor>();
      } else if (var_.IsType<phi::SelectedRows>()) {
        return SetImplWithLegacyTensor<phi::SelectedRows>();
J
Jiabin Yang 已提交
76
      } else {
77 78
        PADDLE_THROW(paddle::platform::errors::Fatal(
            "Unable to fetch underlying tensor "
79
            "from EagerVariable, only LoDTensor and "
80
            "Tensor are supported for now"));
J
Jiabin Yang 已提交
81
      }
82 83
    } else {
      PADDLE_THROW(paddle::platform::errors::Fatal(
84
          "Can not Sync EagerVariable %s whose paddle::framework::Variable is "
85 86
          "not initialized!",
          name()));
J
Jiabin Yang 已提交
87 88
    }
  }
89 90 91
  const paddle::framework::Variable& Var() const { return var_; }

  paddle::framework::Variable* MutableVar() { return &var_; }
J
Jiabin Yang 已提交
92 93 94

  void ResetVar(const paddle::framework::Variable& src) { var_ = src; }

95
  const std::string& name() const { return name_; }
96

97
  void set_name(const std::string& name) { name_ = name; }
98

J
Jiabin Yang 已提交
99
 private:
J
Jiabin Yang 已提交
100
  template <typename VarType>
101
  std::shared_ptr<phi::TensorBase> SetImplWithLegacyTensor() {
J
Jiabin Yang 已提交
102
    const auto& framework_tensor = var_.Get<VarType>();
103
    VLOG(8) << "Sync Var to tensor for: " << name();
J
Jiabin Yang 已提交
104
    return std::make_shared<VarType>(framework_tensor);
105
  }
106

J
Jiabin Yang 已提交
107
  template <typename VarType>
108
  void ConstructVariableFromTensor(const paddle::experimental::Tensor& tensor) {
J
Jiabin Yang 已提交
109
    auto* framework_tensor = var_.GetMutable<VarType>();
110
    // Contruct framework::Tensor from egr::EagerVariable
J
Jiabin Yang 已提交
111
    auto tensor_dense = std::dynamic_pointer_cast<VarType>(tensor.impl());
112 113 114
    PADDLE_ENFORCE_EQ(
        (tensor_dense.get() && tensor_dense), true,
        paddle::platform::errors::Fatal(
115
            "Tensor %s does not hold phi::SelectedRows or phi::DenseTensor. "
116 117 118 119 120 121
            "Or it holds empty impl, this should not happend since we should "
            "treat all kinds of tensor as what they are.",
            tensor.name()));
    *framework_tensor = *tensor_dense;
  }

122 123
 private:
  std::string name_{""};
J
Jiabin Yang 已提交
124 125 126
  paddle::framework::Variable var_;
};
}  // namespace egr