// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h"

namespace paddle {
namespace imperative {

// Operator execution context for imperative (dygraph) mode.
//
// framework::ExecutionContext resolves an operator's inputs and outputs
// through its RuntimeContext. This subclass instead answers every
// input/output query from the slot-name -> variable-list maps passed in
// (`var_map_in` / `var_map_out`). Attribute queries are answered from `attrs`
// first, falling back to `default_attrs`.
//
// Lifetime: all four maps are stored by const reference, not copied. The
// caller must keep them alive for as long as this context is used.
template <typename VarType>
class DygraphExecutionContext : public framework::ExecutionContext {
  using Variable = framework::Variable;

 public:
  DygraphExecutionContext(const framework::OperatorBase& op,
                          const framework::Scope& scope,
                          const platform::DeviceContext& device_context,
                          const framework::RuntimeContext& ctx,
                          const NameVarMap<VarType>& var_map_in,
                          const NameVarMap<VarType>& var_map_out,
                          const framework::AttributeMap& attrs,
                          const framework::AttributeMap& default_attrs)
      : ExecutionContext(op, scope, device_context, ctx),
        var_map_in_(var_map_in),
        var_map_out_(var_map_out),
        attrs_(attrs),
        default_attrs_(default_attrs) {}

  // Returns the name of the first variable in input slot `name`. Returns
  // kEmptyVarName if the slot is empty or its first entry is null.
  // Throws NotFound if the slot does not exist.
  std::string InputName(const std::string& name) const override {
    auto it = var_map_in_.find(name);
    // NotFound (rather than PreconditionNotMet) to match every other slot
    // lookup in this class.
    PADDLE_ENFORCE_NE(
        it,
        var_map_in_.end(),
        platform::errors::NotFound("Can not find [%s] in Input", name));
    return FirstVarNameOf(it->second);
  }

  // Returns the names of all variables in input slot `name`, one per entry.
  // Null entries map to kEmptyVarName. Throws NotFound if the slot does not
  // exist.
  std::vector<std::string> InputNames(const std::string& name) const override {
    auto it = var_map_in_.find(name);
    PADDLE_ENFORCE_NE(
        it,
        var_map_in_.end(),
        platform::errors::NotFound("Can not find [%s] in Input", name));
    return VarNamesOf(it->second);
  }

  // Output-side counterpart of InputName().
  std::string OutputName(const std::string& name) const override {
    auto it = var_map_out_.find(name);
    PADDLE_ENFORCE_NE(
        it,
        var_map_out_.end(),
        platform::errors::NotFound("Can not find [%s] in Output", name));
    return FirstVarNameOf(it->second);
  }

  // Output-side counterpart of InputNames().
  std::vector<std::string> OutputNames(const std::string& name) const override {
    auto it = var_map_out_.find(name);
    PADDLE_ENFORCE_NE(
        it,
        var_map_out_.end(),
        platform::errors::NotFound("Can not find [%s] in Output", name));
    return VarNamesOf(it->second);
  }

  // True if `name` is present in the explicit attributes or in the defaults.
  bool HasAttr(const std::string& name) const override {
    // `default_attrs_` is a reference and so can never be null; a caller with
    // no defaults must pass an empty map. No null check is needed here.
    return attrs_.find(name) != attrs_.end() ||
           default_attrs_.find(name) != default_attrs_.end();
  }

  // Only the explicitly supplied attributes; defaults are not merged in.
  const framework::AttributeMap& Attrs() const override { return attrs_; }

  // Looks up `name` in the explicit attributes, then in the defaults.
  // Throws NotFound if it is in neither.
  const framework::Attribute& GetAttr(const std::string& name) const override {
    auto it = attrs_.find(name);
    if (it != attrs_.end()) {
      return it->second;
    }

    it = default_attrs_.find(name);
    if (it == default_attrs_.end()) {
      PADDLE_THROW(platform::errors::NotFound(
          "Can not find [%s] in attributes of op %s.",
          name,
          this->GetOp().Type()));
    }
    return it->second;
  }

  // Pointers to every input slot name. The pointers refer to keys of
  // `var_map_in_` and stay valid only while that map is alive and unmodified.
  paddle::small_vector<const std::string*> InNameList() const override {
    paddle::small_vector<const std::string*> vec_temp;
    vec_temp.reserve(var_map_in_.size());
    for (const auto& slot : var_map_in_) {
      vec_temp.push_back(&slot.first);
    }
    return vec_temp;
  }

  // True if input slot `name` exists and has at least one entry. The entries
  // themselves may still be null.
  bool HasInput(const std::string& name) const override {
    auto it = var_map_in_.find(name);
    return it != var_map_in_.end() && !it->second.empty();
  }

  // Same check as HasInput(); dygraph does not distinguish the two.
  bool HasInputs(const std::string& name) const override {
    auto it = var_map_in_.find(name);
    return it != var_map_in_.end() && !it->second.empty();
  }

  // True if output slot `name` exists and has at least one entry.
  bool HasOutput(const std::string& name) const override {
    auto it = var_map_out_.find(name);
    return it != var_map_out_.end() && !it->second.empty();
  }

  // Number of entries (null entries included) in input slot `name`.
  // Throws NotFound if the slot does not exist.
  size_t InputSize(const std::string& name) const override {
    auto it = var_map_in_.find(name);
    PADDLE_ENFORCE_NE(
        it,
        var_map_in_.end(),
        platform::errors::NotFound("Can not find [%s] in Input", name));
    return it->second.size();
  }

  // Number of entries (null entries included) in output slot `name`.
  // Throws NotFound if the slot does not exist.
  size_t OutputSize(const std::string& name) const override {
    auto it = var_map_out_.find(name);
    PADDLE_ENFORCE_NE(
        it,
        var_map_out_.end(),
        platform::errors::NotFound("Can not find [%s] in Output", name));
    return it->second.size();
  }

  // Variable behind the first entry of input slot `name`. Returns nullptr if
  // the slot is missing or empty, or if its first entry is null.
  const Variable* InputVar(const std::string& name) const override {
    auto it = var_map_in_.find(name);
    return it == var_map_in_.end() ? nullptr : FirstVarOf(it->second);
  }

  // Output-side counterpart of InputVar().
  Variable* OutputVar(const std::string& name) const override {
    auto it = var_map_out_.find(name);
    return it == var_map_out_.end() ? nullptr : FirstVarOf(it->second);
  }

  // Variables behind every entry of input slot `name`; null entries yield
  // nullptr. Returns an empty vector if the slot does not exist.
  const std::vector<Variable*> MultiInputVar(
      const std::string& name) const override {
    auto it = var_map_in_.find(name);
    if (it == var_map_in_.end()) {
      return {};
    }
    return VarsOf(it->second);
  }

  // Output-side counterpart of MultiInputVar().
  std::vector<Variable*> MultiOutputVar(
      const std::string& name) const override {
    auto it = var_map_out_.find(name);
    if (it == var_map_out_.end()) {
      return {};
    }
    return VarsOf(it->second);
  }

 private:
  // The variable list stored for a single slot of a NameVarMap.
  using VarList = typename NameVarMap<VarType>::mapped_type;

  // Name of vars[0], or kEmptyVarName if `vars` is empty or vars[0] is null.
  // The emptiness check prevents an out-of-bounds vars[0] read.
  static std::string FirstVarNameOf(const VarList& vars) {
    if (vars.empty() || !vars[0]) {
      return framework::kEmptyVarName;
    }
    return GetNameFromVar(vars[0]);
  }

  // Name of every entry in `vars`; null entries map to kEmptyVarName.
  static std::vector<std::string> VarNamesOf(const VarList& vars) {
    std::vector<std::string> names;
    names.reserve(vars.size());
    for (const auto& var : vars) {
      if (var) {
        names.push_back(GetNameFromVar(var));
      } else {
        names.push_back(framework::kEmptyVarName);
      }
    }
    return names;
  }

  // MutableVar() of vars[0], or nullptr if `vars` is empty or vars[0] is null.
  static Variable* FirstVarOf(const VarList& vars) {
    if (vars.empty() || vars[0] == nullptr) {
      return nullptr;
    }
    return vars[0]->MutableVar();
  }

  // MutableVar() of every entry in `vars`; null entries yield nullptr.
  static std::vector<Variable*> VarsOf(const VarList& vars) {
    std::vector<Variable*> result;
    result.reserve(vars.size());
    for (const auto& var : vars) {
      result.push_back(var ? var->MutableVar() : nullptr);
    }
    return result;
  }

  // Non-owning views; see the lifetime note on the class.
  const NameVarMap<VarType>& var_map_in_;
  const NameVarMap<VarType>& var_map_out_;
  const framework::AttributeMap& attrs_;
  const framework::AttributeMap& default_attrs_;
};

}  // namespace imperative
}  // namespace paddle