// net_op_test.cc
// Unit tests for paddle::framework::PlainNet.
#include <gtest/gtest.h>
#include <paddle/framework/net.h>
#include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h>

namespace paddle {
namespace framework {

// File-local tallies of how often TestOp::InferShape() and TestOp::Run() have
// been invoked. They are never reset, so they accumulate across tests.
static int infer_shape_cnt{0};
static int run_cnt{0};

class TestOp : public OperatorBase {
Q
Qiao Longfei 已提交
13
 public:
D
dongzhihong 已提交
14 15 16
  void InferShape(const ScopePtr& scope) const override { ++infer_shape_cnt; }
  void Run(const ScopePtr& scope,
           const platform::DeviceContext& dev_ctx) const override {
Q
Qiao Longfei 已提交
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
    ++run_cnt;
  }
};

// Asserts that `actual` holds exactly the same elements as `expected`,
// ignoring order but respecting multiplicity (i.e. `actual` is a permutation
// of `expected`).
//
// T must be hashable and equality-comparable. Failures are reported through
// fatal googletest assertions, which return from this helper only; the caller
// keeps running unless it checks for a fatal failure itself.
template <typename T>
void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
                                  const std::vector<T>& actual) {
  ASSERT_EQ(expected.size(), actual.size());

  // A multiset (rather than a set) keeps one entry per occurrence. Each match
  // is consumed, so {a, a, b} is not mistaken for {a, b, b}.
  std::unordered_multiset<T> remaining(expected.begin(), expected.end());
  for (const auto& item : actual) {
    const auto match = remaining.find(item);
    ASSERT_TRUE(match != remaining.end())
        << "actual contains an element that is missing from expected, or "
           "contains it more often than expected does";
    remaining.erase(match);
  }
}

class PlainNetTest : public testing::Test {
D
dongzhihong 已提交
35
 public:
D
dongzhihong 已提交
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
  virtual void SetUp() {
    net_ = std::make_shared<PlainNet>();
    ASSERT_NE(net_, nullptr);

    auto op1 = std::make_shared<TestOp>();
    op1->inputs_ = {"x", "w1", "b1"};
    op1->outputs_ = {"y"};
    net_->AddOp(op1);

    auto op2 = std::make_shared<TestOp>();
    op2->inputs_ = {"y", "w2", "b2"};
    op2->outputs_ = {"z"};
    net_->AddOp(op2);
    net_->CompleteAddOp();
  }

  virtual void TearDown() {}

D
dongzhihong 已提交
54 55
  virtual void TestBody() {}

D
dongzhihong 已提交
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
  void TestOpKernel() {
    AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, net_->inputs_);
    AssertSameVectorWithoutOrder({"y", "z"}, net_->outputs_);
    auto tmp_idx_iter = net_->attrs_.find("temporary_index");
    ASSERT_NE(net_->attrs_.end(), tmp_idx_iter);
    auto& tmp_idx = boost::get<std::vector<int>>(tmp_idx_iter->second);
    ASSERT_EQ(1UL, tmp_idx.size());
    ASSERT_EQ("y", net_->outputs_[tmp_idx[0]]);

    auto scope = std::make_shared<Scope>();
    platform::CPUDeviceContext dev_ctx;

    net_->InferShape(scope);
    net_->Run(scope, dev_ctx);
    ASSERT_EQ(2, infer_shape_cnt);
    ASSERT_EQ(2, run_cnt);

D
dongzhihong 已提交
73
    auto op2 = std::make_shared<TestOp>();
D
dongzhihong 已提交
74 75 76 77 78 79 80 81 82 83 84 85 86 87
    ASSERT_THROW(net_->AddOp(op2), EnforceNotMet);
  }

  void TestAddBackwardOp() {
    auto grad_ops = AddBackwardOp(net_);
    for (auto& op : grad_ops->ops_) {
      op->DebugString();
    }
  }

 private:
  std::shared_ptr<PlainNet> net_;
};

// Runs the PlainNet structure and dispatch checks from
// PlainNetTest::TestOpKernel().
TEST(OpKernel, all) {
  PlainNetTest net;
  // googletest only invokes SetUp() on fixtures it instantiates itself
  // (TEST_F). This object is created by hand, so populate it explicitly
  // before use.
  net.SetUp();
  net.TestOpKernel();
}

// Runs PlainNetTest::TestAddBackwardOp(), which builds the backward net and
// calls DebugString() on each of its ops.
TEST(AddBackwardOp, TestAddBackwardOp) {
  PlainNetTest net;
  // googletest only invokes SetUp() on fixtures it instantiates itself
  // (TEST_F). This object is created by hand, so populate it explicitly
  // before use.
  net.SetUp();
  net.TestAddBackwardOp();
}

}  // namespace framework
}  // namespace paddle