// scope.cc
// Original author: 朔-望.

#include "scope.h"

#include <algorithm>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <vector>

namespace paddle_mobile {
namespace framework {

// Creates a child scope whose parent is this scope, records it in kids_
// (under mutex_), and returns a reference to it.
Scope &Scope::NewScope() const {
  Scope *child = nullptr;
  {
    std::unique_lock<std::mutex> guard(mutex_);
    child = new Scope(this);
    kids_.push_back(child);
  }
  return *child;
}

// Returns the variable called `name` in this scope, creating it if absent.
// Only this scope's own vars_ is searched (parents are not consulted).
// A newly created Variable is stored in vars_ (EraseVars() later deletes
// such entries), and its name_ is pointed at the map's own key string.
Variable *Scope::Var(const std::string &name) {
  if (Variable *existing = FindVarLocally(name)) {
    return existing;
  }
  // Hold the new Variable in a unique_ptr until the map has accepted it, so
  // an exception thrown by the insertion cannot leak it.
  std::unique_ptr<Variable> fresh(new Variable);
  // emplace hands back the iterator to the new entry directly, avoiding a
  // second lookup.
  auto slot = vars_.emplace(name, fresh.get()).first;
  // name_ borrows the map's key string. This relies on vars_ keeping keys at
  // stable addresses (true for node-based std::map / std::unordered_map).
  fresh->name_ = &slot->first;
  return fresh.release();
}

//            Variable* Scope::Var(std::string* name) {
//                auto var_name = string::Sprintf("%p.%d", this, vars_.size());
//                if (name != nullptr) {
//                    *name = var_name;
//                }
//                return Var(var_name);
//            }

// Searches this scope for `name`, then each ancestor in turn through the
// parent chain. Returns the first match found, or nullptr if no scope in the
// chain holds a variable with that name.
Variable *Scope::FindVar(const std::string &name) const {
  for (const Scope *cur = this; cur != nullptr; cur = cur->parent_) {
    if (Variable *hit = cur->FindVarLocally(name)) {
      return hit;
    }
  }
  return nullptr;
}

// Returns the nearest scope (this one, then each ancestor) whose own vars_
// stores exactly the pointer `var`. Returns nullptr if none does.
const Scope *Scope::FindScope(const Variable *var) const {
  for (const Scope *cur = this; cur != nullptr; cur = cur->parent_) {
    for (const auto &entry : cur->vars_) {
      if (entry.second == var) {
        return cur;
      }
    }
  }
  return nullptr;
}

void Scope::DropKids() {
  for (Scope *s : kids_) {
    delete s;
  }
  kids_.clear();
}

// Returns the names of the variables stored directly in this scope (parent
// scopes are not included), in vars_ iteration order.
std::vector<std::string> Scope::LocalVarNames() const {
  std::vector<std::string> names;
  names.reserve(vars_.size());
  for (auto it = vars_.begin(); it != vars_.end(); ++it) {
    names.push_back(it->first);
  }
  return names;
}

// Removes `scope` from this scope's kids_ and deletes it, all under mutex_.
// If `scope` is not one of this scope's kids, nothing happens. This scope
// does not own such a pointer, and calling kids_.erase(kids_.end()) would be
// undefined behavior.
void Scope::DeleteScope(Scope *scope) const {
  std::unique_lock<std::mutex> lock(mutex_);
  auto it = std::find(kids_.begin(), kids_.end(), scope);
  if (it == kids_.end()) {
    return;
  }
  kids_.erase(it);
  // Deletion is synchronous. NOTE(review): the former "deferent" remark
  // presumably referred to deferred/asynchronous deletion, which is not
  // implemented here — confirm intent.
  delete scope;
}

void Scope::EraseVars(const std::vector<std::string> &var_names) {
  std::set<std::string> var_set(var_names.begin(), var_names.end());
  for (auto it = vars_.begin(); it != vars_.end();) {
    if (var_set.find(it->first) != var_set.end()) {
      delete it->second;
      it = vars_.erase(it);
    } else {
      ++it;
    }
  }
}

void Scope::Rename(const std::string &origin_name,
                   const std::string &new_name) const {
  auto origin_it = vars_.find(origin_name);
  if (origin_it == vars_.end()) {
    return;
  }
  auto new_it = vars_.find(new_name);
  if (new_it != vars_.end()) {
    return;
  }
  vars_[new_name] = origin_it->second;
  vars_.erase(origin_it);
}
//
//            std::string Scope::Rename(const std::string& origin_name) const {
//                auto var_name = string::Sprintf("%p.%d", this, vars_.size());
//                Rename(origin_name, var_name);
//                return var_name;
//            }

// Looks `name` up in this scope's own vars_ only (no parent search).
// Returns the stored Variable pointer, or nullptr when the name is absent.
Variable *Scope::FindVarLocally(const std::string &name) const {
  const auto hit = vars_.find(name);
  return hit == vars_.end() ? nullptr : hit->second;
}

// The namespace-closing lines below were last committed by liuruilong.
} // namespace framework
} // namespace paddle_mobile