scope_test.cc 1.6 KB
Newer Older
Q
qiaolongfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Q
qiaolongfei 已提交
15
#include "paddle/framework/scope.h"
16
#include "glog/logging.h"
Q
qiaolongfei 已提交
17 18
#include "gtest/gtest.h"

Y
Yi Wang 已提交
19 20
// Pull the two framework types referenced by the tests in this file into the
// global namespace so they can be written unqualified.
using paddle::framework::Scope;
using paddle::framework::Variable;
Q
qiaolongfei 已提交
21

Y
Yi Wang 已提交
22 23 24 25
// Checks how a variable name resolves when it is created both in a scope `s`
// and in one of the scopes obtained from `s.NewScope()`:
//   * the two Var("a") calls are expected to produce distinct variables;
//   * each scope that created "a" is expected to find its own variable;
//   * a scope from `s.NewScope()` that never created "a" is expected to find
//     the variable that was created in `s`.
TEST(Scope, VarsShadowing) {
  Scope s;
  Scope& ss1 = s.NewScope();
  Scope& ss2 = s.NewScope();

  Variable* v0 = s.Var("a");
  Variable* v1 = ss1.Var("a");

  // Same name, different scopes: the returned pointers must differ.
  EXPECT_NE(v0, v1);

  EXPECT_EQ(v0, s.FindVar("a"));
  EXPECT_EQ(v1, ss1.FindVar("a"));
  // ss2 did not create "a" itself, so the lookup should yield s's variable.
  EXPECT_EQ(v0, ss2.FindVar("a"));
}
Q
qiaolongfei 已提交
36

Y
Yi Wang 已提交
37 38 39
// Checks FindVar() before and after a variable is created in `ss`, a scope
// obtained from `s.NewScope()`:
//   * before any Var() call, both scopes are expected to return nullptr;
//   * after ss.Var("a"), `ss` is expected to find it while `s` still is not.
TEST(Scope, FindVar) {
  Scope s;
  Scope& ss = s.NewScope();

  // Nothing has been created yet, so neither lookup should succeed.
  EXPECT_EQ(nullptr, s.FindVar("a"));
  EXPECT_EQ(nullptr, ss.FindVar("a"));

  ss.Var("a");

  // The variable was created through `ss` only; `s` must not see it.
  EXPECT_EQ(nullptr, s.FindVar("a"));
  EXPECT_NE(nullptr, ss.FindVar("a"));
}
Q
qiaolongfei 已提交
49

Y
Yi Wang 已提交
50 51 52
// Checks FindScope() for a variable created in `s`: both `s` itself and `ss`
// (a scope obtained from `s.NewScope()`) are expected to report `&s` as the
// scope associated with that variable.
TEST(Scope, FindScope) {
  Scope s;
  Scope& ss = s.NewScope();
  Variable* v = s.Var("a");

  EXPECT_EQ(&s, s.FindScope(v));
  // Asking through `ss` should still resolve to `s`, where `v` was created.
  EXPECT_EQ(&s, ss.FindScope(v));
}
58 59 60 61 62 63 64 65 66 67 68 69 70 71

// Checks that, for a scope in which only the variable "a" was created, the
// names returned by GetAllNames() concatenate to exactly "a".
TEST(Scope, GetAllNames) {
  Scope scope;
  Variable* created = scope.Var("a");
  EXPECT_EQ(&scope, scope.FindScope(created));

  // Join every reported name into one string so that a single comparison
  // covers the whole result.
  std::string joined;
  std::vector<std::string> names = scope.GetAllNames();
  for (const std::string& name : names) {
    joined.append(name);
  }

  EXPECT_STREQ("a", joined.c_str());
}