// execution_context.h
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h"

namespace paddle {
namespace imperative {

// ExecutionContext used in imperative (dygraph) mode.
//
// The overrides below resolve input/output slots from the name -> variables
// maps given to the constructor, and resolve attributes from `attrs` first and
// `default_attrs` second.
//
// NOTE: the context stores *references* to var_map_in, var_map_out, attrs and
// default_attrs; the caller must keep those objects alive for as long as this
// context is used.
template <typename VarType>
class DygraphExecutionContext : public framework::ExecutionContext {
  using Variable = framework::Variable;
  // The (possibly null) variables bound to a single input/output slot name.
  using VarList = typename NameVarMap<VarType>::mapped_type;

 public:
  DygraphExecutionContext(const framework::OperatorBase& op,
                          const framework::Scope& scope,
                          const platform::DeviceContext& device_context,
                          const framework::RuntimeContext& ctx,
                          const NameVarMap<VarType>& var_map_in,
                          const NameVarMap<VarType>& var_map_out,
                          const framework::AttributeMap& attrs,
                          const framework::AttributeMap& default_attrs)
      : ExecutionContext(op, scope, device_context, ctx),
        var_map_in_(var_map_in),
        var_map_out_(var_map_out),
        attrs_(attrs),
        default_attrs_(default_attrs) {}

  // Name of the first variable in input slot `name`; kEmptyVarName when the
  // slot is empty or its first entry is null. Raises NotFound if the slot is
  // missing.
  std::string InputName(const std::string& name) const override {
    return FirstVarName(GetSlot(var_map_in_, name, "Input"));
  }

  // Names of all variables in input slot `name`; null entries map to
  // kEmptyVarName. Raises NotFound if the slot is missing.
  std::vector<std::string> InputNames(const std::string& name) const override {
    return VarNames(GetSlot(var_map_in_, name, "Input"));
  }

  // Output counterpart of InputName.
  std::string OutputName(const std::string& name) const override {
    return FirstVarName(GetSlot(var_map_out_, name, "Output"));
  }

  // Output counterpart of InputNames.
  std::vector<std::string> OutputNames(const std::string& name) const override {
    return VarNames(GetSlot(var_map_out_, name, "Output"));
  }

  // True if `name` is present in either the explicit or the default attrs.
  bool HasAttr(const std::string& name) const override {
    return attrs_.count(name) != 0 || default_attrs_.count(name) != 0;
  }

  // Only the explicitly passed attributes; defaults are not merged in.
  const framework::AttributeMap& Attrs() const override { return attrs_; }

  // Looks `name` up in the explicit attrs, falling back to default attrs.
  // Raises NotFound if it is in neither.
  const framework::Attribute& GetAttr(const std::string& name) const override {
    auto it = attrs_.find(name);

    if (it == attrs_.end()) {
      it = default_attrs_.find(name);
      if (it == default_attrs_.end()) {
        PADDLE_THROW(platform::errors::NotFound(
            "Can not find [%s] in attributes of op %s.", name,
            this->GetOp().Type()));
      }
    }

    return it->second;
  }

  // Pointers to the input slot names; they point into var_map_in_'s keys.
  paddle::small_vector<const std::string*> InNameList() const override {
    paddle::small_vector<const std::string*> vec_temp;
    vec_temp.reserve(var_map_in_.size());

    for (auto& v : var_map_in_) {
      vec_temp.push_back(&v.first);
    }

    return vec_temp;
  }

  // True if input slot `name` exists and holds at least one (possibly null)
  // entry.
  bool HasInput(const std::string& name) const override {
    const VarList* vars = FindSlot(var_map_in_, name);
    return vars != nullptr && !vars->empty();
  }

  // Same check as HasInput in dygraph mode.
  bool HasInputs(const std::string& name) const override {
    const VarList* vars = FindSlot(var_map_in_, name);
    return vars != nullptr && !vars->empty();
  }

  // True if output slot `name` exists and holds at least one (possibly null)
  // entry.
  bool HasOutput(const std::string& name) const override {
    const VarList* vars = FindSlot(var_map_out_, name);
    return vars != nullptr && !vars->empty();
  }

  // Number of entries (null ones included) in input slot `name`.
  // Raises NotFound if the slot is missing.
  size_t InputSize(const std::string& name) const override {
    return GetSlot(var_map_in_, name, "Input").size();
  }

  // Number of entries (null ones included) in output slot `name`.
  // Raises NotFound if the slot is missing.
  size_t OutputSize(const std::string& name) const override {
    return GetSlot(var_map_out_, name, "Output").size();
  }

  // First variable of input slot `name`, or nullptr if the slot is missing,
  // empty, or its first entry is null.
  const Variable* InputVar(const std::string& name) const override {
    const VarList* vars = FindSlot(var_map_in_, name);
    return vars == nullptr ? nullptr : FirstMutableVar(*vars);
  }

  // First variable of output slot `name`, or nullptr if the slot is missing,
  // empty, or its first entry is null.
  Variable* OutputVar(const std::string& name) const override {
    const VarList* vars = FindSlot(var_map_out_, name);
    return vars == nullptr ? nullptr : FirstMutableVar(*vars);
  }

  // All variables of input slot `name` (null entries stay nullptr); empty if
  // the slot is missing.
  const std::vector<Variable*> MultiInputVar(
      const std::string& name) const override {
    const VarList* vars = FindSlot(var_map_in_, name);
    if (vars == nullptr) {
      return {};
    }
    return MutableVars(*vars);
  }

  // All variables of output slot `name` (null entries stay nullptr); empty if
  // the slot is missing.
  std::vector<Variable*> MultiOutputVar(
      const std::string& name) const override {
    const VarList* vars = FindSlot(var_map_out_, name);
    if (vars == nullptr) {
      return {};
    }
    return MutableVars(*vars);
  }

 private:
  // Returns the slot `name` of `var_map`, or nullptr if it does not exist.
  static const VarList* FindSlot(const NameVarMap<VarType>& var_map,
                                 const std::string& name) {
    auto it = var_map.find(name);
    return it == var_map.end() ? nullptr : &it->second;
  }

  // Returns the slot `name` of `var_map`; raises NotFound naming `direction`
  // ("Input" / "Output") if it does not exist.
  static const VarList& GetSlot(const NameVarMap<VarType>& var_map,
                                const std::string& name,
                                const char* direction) {
    auto it = var_map.find(name);
    PADDLE_ENFORCE_NE(
        it, var_map.end(),
        platform::errors::NotFound("Can not find [%s] in %s", name, direction));
    return it->second;
  }

  // Name of vars[0], or kEmptyVarName when `vars` is empty or vars[0] is null.
  // The empty check avoids indexing past the end of an empty slot.
  static std::string FirstVarName(const VarList& vars) {
    return (!vars.empty() && vars[0]) ? GetNameFromVar(vars[0])
                                      : framework::kEmptyVarName;
  }

  // Names of every entry in `vars`, with kEmptyVarName for null entries.
  static std::vector<std::string> VarNames(const VarList& vars) {
    std::vector<std::string> names;
    names.reserve(vars.size());
    for (const auto& var : vars) {
      if (var) {
        names.push_back(GetNameFromVar(var));
      } else {
        names.push_back(framework::kEmptyVarName);
      }
    }
    return names;
  }

  // vars[0]->MutableVar(), or nullptr when `vars` is empty or vars[0] is null.
  static Variable* FirstMutableVar(const VarList& vars) {
    return (vars.empty() || vars[0] == nullptr) ? nullptr
                                                : vars[0]->MutableVar();
  }

  // MutableVar() of every entry in `vars`, with nullptr for null entries.
  static std::vector<Variable*> MutableVars(const VarList& vars) {
    std::vector<Variable*> res;
    res.reserve(vars.size());
    for (const auto& var : vars) {
      res.push_back(var ? var->MutableVar() : nullptr);
    }
    return res;
  }

  // Borrowed (non-owning) references; see the class comment on lifetime.
  const NameVarMap<VarType>& var_map_in_;
  const NameVarMap<VarType>& var_map_out_;
  const framework::AttributeMap& attrs_;
  const framework::AttributeMap& default_attrs_;
};

}  // namespace imperative
}  // namespace paddle