Commit 978d1288 authored by Yang Yu

Merge branch 'feature/fix_get_all_names' into feature/rnn_gradient_check

@@ -74,17 +74,9 @@ void Scope::DropKids() {
   kids_.clear();
 }
 
-std::vector<std::string> Scope::GetAllNames(bool recursive) const {
-  std::vector<std::string> known_vars(vars_.size());
-  if (recursive) {
-    for (auto& kid : kids_) {
-      auto kid_vars = kid->GetAllNames();
-      for (auto& p : kid_vars) {
-        known_vars.emplace_back(p);
-      }
-    }
-  }
+std::vector<std::string> Scope::LocalVarNames() const {
+  std::vector<std::string> known_vars;
+  known_vars.reserve(this->vars_.size());
   for (auto& p : vars_) {
     known_vars.emplace_back(p.first);
   }
......
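Beyond the rename, this hunk fixes a subtle bug in the old `GetAllNames`: `std::vector<std::string> known_vars(vars_.size())` value-initializes `vars_.size()` empty strings, and every subsequent `emplace_back` appends after them, so the returned list began with a run of empty names. The new code calls `reserve`, which allocates capacity without inserting elements. A minimal standalone sketch of the difference (illustrative only, not part of the commit):

```cpp
#include <cassert>
#include <string>
#include <vector>

int main() {
  // Buggy pattern: the constructor creates 3 default-constructed
  // (empty) strings; emplace_back then appends after them.
  std::vector<std::string> bad(3);
  bad.emplace_back("a");
  assert(bad.size() == 4);   // {"", "", "", "a"}

  // Fixed pattern: reserve allocates capacity without inserting
  // elements, so only the appended names are present.
  std::vector<std::string> good;
  good.reserve(3);
  good.emplace_back("a");
  assert(good.size() == 1);  // {"a"}
  return 0;
}
```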
@@ -66,7 +66,7 @@ class Scope {
   void DropKids();
 
   // enumerate all the variables current contains.
-  std::vector<std::string> GetAllNames(bool recursive = false) const;
+  std::vector<std::string> LocalVarNames() const;
 
   // Rename variable to a new name
   void Rename(const std::string& origin_name,
......
@@ -61,7 +61,7 @@ TEST(Scope, GetAllNames) {
   Variable* v = s.Var("a");
   EXPECT_EQ(&s, s.FindScope(v));
 
-  std::vector<std::string> ans = s.GetAllNames();
+  std::vector<std::string> ans = s.LocalVarNames();
   std::string str;
   for (auto& var : ans) {
     str += var;
......
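The test exercises exactly the contract the rename codifies: `LocalVarNames` enumerates only the current scope, never its kid scopes. A self-contained toy sketch of that local-only semantics (`ToyScope` and its members are illustrative assumptions, not Paddle's actual classes):

```cpp
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Toy stand-in for framework::Scope; mirrors the local-only
// enumeration of Scope::LocalVarNames().
struct ToyScope {
  std::map<std::string, int> vars;            // name -> dummy variable
  std::vector<std::unique_ptr<ToyScope>> kids;

  std::vector<std::string> LocalVarNames() const {
    std::vector<std::string> names;
    names.reserve(vars.size());
    for (const auto& p : vars) names.emplace_back(p.first);
    return names;
  }
};

int main() {
  ToyScope root;
  root.vars["a"] = 0;
  root.kids.push_back(std::make_unique<ToyScope>());
  root.kids[0]->vars["b"] = 0;
  // Prints only "a": the kid scope's "b" is not enumerated.
  for (const auto& n : root.LocalVarNames()) std::cout << n << "\n";
  return 0;
}
```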
@@ -491,7 +491,7 @@ class RecurrentGradOp : public RecurrentBase {
   std::unordered_set<std::string> LocalVarNames(
       const framework::Scope &scope) const {
-    return this->List2Set(scope.GetAllNames(false));
+    return this->List2Set(scope.LocalVarNames());
   }
   static std::vector<std::string> GradVarLists(
       const std::vector<std::string> &var_names) {
......
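The surrounding helpers are collapsed, but their roles follow from the signatures: `List2Set` deduplicates a name list into an `unordered_set`, and `GradVarLists` maps each variable name to its gradient counterpart. A hedged sketch of what such helpers might look like (the `@GRAD` suffix is Paddle's gradient-naming convention; the bodies below are assumptions, not the commit's code):

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Plausible implementations of the collapsed helpers; assumptions
// for illustration, not the actual commit contents.
std::unordered_set<std::string> List2Set(
    const std::vector<std::string>& list) {
  // Deduplicate by constructing the set from the list's range.
  return {list.begin(), list.end()};
}

std::vector<std::string> GradVarLists(
    const std::vector<std::string>& var_names) {
  std::vector<std::string> grads;
  grads.reserve(var_names.size());
  for (const auto& name : var_names) {
    grads.emplace_back(name + "@GRAD");  // Paddle's gradient-name suffix
  }
  return grads;
}
```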