net_op_test.cc 2.8 KB
Newer Older
Y
Yan Chunwei 已提交
1 2
#include "paddle/operators/net_op.h"

Q
Qiao Longfei 已提交
3
#include <cstddef>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <gtest/gtest.h>
Y
Yan Chunwei 已提交
4

D
dongzhihong 已提交
5
namespace paddle {
Y
Yan Chunwei 已提交
6
namespace operators {
D
dongzhihong 已提交
7 8
// Short aliases for the framework/platform types that appear in the test
// operators' InferShape()/Run() signatures.
using Scope = framework::Scope;
using DeviceContext = platform::DeviceContext;
Q
Qiao Longfei 已提交
9 10 11 12

// Invocation counters bumped by TestOp::InferShape() and TestOp::Run().
// The unnamed namespace keeps them local to this translation unit.
namespace {
int infer_shape_cnt = 0;
int run_cnt = 0;
}  // namespace

D
dongzhihong 已提交
13
class TestOp : public framework::OperatorBase {
14
 public:
Y
Yu Yang 已提交
15
  using framework::OperatorBase::OperatorBase;
Y
Yu Yang 已提交
16
  DEFINE_OP_CLONE_METHOD(TestOp);
D
dongzhihong 已提交
17 18 19
  void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
  void Run(const Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {
Q
Qiao Longfei 已提交
20 21 22 23
    ++run_cnt;
  }
};

D
dongzhihong 已提交
24
class EmptyOp : public framework::OperatorBase {
25
 public:
Y
Yu Yang 已提交
26
  using framework::OperatorBase::OperatorBase;
Y
Yu Yang 已提交
27
  DEFINE_OP_CLONE_METHOD(EmptyOp);
Y
Yu Yang 已提交
28
  void InferShape(const Scope& scope) const override {}
D
dongzhihong 已提交
29
  void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {}
Y
Yu Yang 已提交
30 31
};

Q
Qiao Longfei 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45
// Asserts that `actual` holds exactly the same elements as `expected`,
// ignoring order but respecting multiplicity (multiset equality). For
// example, {"a", "a", "b"} and {"a", "b", "b"} are reported as different.
//
// T must be usable as a std::unordered_map key (hashable, equality-comparable).
//
// NOTE: fatal ASSERT_* failures here only return from this helper, not from
// the calling test. Wrap calls in ASSERT_NO_FATAL_FAILURE if the test must
// stop on a mismatch.
template <typename T>
void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
                                  const std::vector<T>& actual) {
  ASSERT_EQ(expected.size(), actual.size());

  // How many copies of each expected value are still unmatched.
  std::unordered_map<T, std::size_t> remaining;
  for (const auto& item : expected) {
    ++remaining[item];
  }

  // Every element of `actual` must consume one outstanding copy. Combined
  // with the size check above, this proves the two multisets are equal.
  for (const auto& item : actual) {
    auto it = remaining.find(item);
    ASSERT_TRUE(it != remaining.end() && it->second > 0)
        << "element of `actual` is missing from, or over-represented "
           "relative to, `expected`";
    --it->second;
  }
}

// Builds a two-op chain (x -> y, then y -> z) and checks two things:
// CompleteAddOp() exposes the union of the children's input and output names
// under NetOp::kAll, and OutputVars(false) reports only "z".
TEST(OpKernel, all) {
  auto net = std::make_shared<NetOp>();
  ASSERT_NE(net, nullptr);

  auto op1 = std::shared_ptr<TestOp>(
      new TestOp("test", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
                 {{"Out", {"y"}}}, {}));
  net->AddOp(op1);

  auto op2 = std::shared_ptr<TestOp>(
      new TestOp("test", {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}},
                 {{"Out", {"z"}}}, {}));
  net->AddOp(op2);

  net->CompleteAddOp();

  // The helper uses fatal assertions, which only return from the helper.
  // ASSERT_NO_FATAL_FAILURE propagates such a failure and stops this test.
  ASSERT_NO_FATAL_FAILURE(AssertSameVectorWithoutOrder(
      {"x", "w1", "b1", "w2", "b2"}, net->inputs_.at(NetOp::kAll)));
  ASSERT_NO_FATAL_FAILURE(AssertSameVectorWithoutOrder(
      {"y", "z"}, net->outputs_.at(NetOp::kAll)));

  auto final_outs = net->OutputVars(false);
  ASSERT_EQ(final_outs.size(), 1UL);
  ASSERT_EQ(final_outs[0], "z");
}

Y
Yan Chunwei 已提交
70
// InsertOp() at index 0 and at index 2 (the list size at that moment) must
// each grow ops_ by one element.
TEST(NetOp, insert_op) {
  NetOp net;
  std::shared_ptr<EmptyOp> empty_op(
      new EmptyOp("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
                  {{"Out", {"y"}}}, {}));

  net.AddOp(empty_op);        // ops_: 1 element
  net.InsertOp(0, empty_op);  // front insertion -> 2 elements
  ASSERT_EQ(2UL, net.ops_.size());
  net.InsertOp(2, empty_op);  // index equal to current size -> 3 elements
  ASSERT_EQ(3UL, net.ops_.size());
}
D
dongzhihong 已提交
81

Y
Yu Yang 已提交
82 83 84 85 86
// Cloning a completed NetOp must yield another NetOp whose child ops keep
// their types and their order.
TEST(NetOp, Clone) {
  NetOp net;
  net.AddOp(std::shared_ptr<EmptyOp>(new EmptyOp{"empty", {}, {}, {}}));
  net.AddOp(std::shared_ptr<EmptyOp>(new EmptyOp{"empty2", {}, {}, {}}));
  net.CompleteAddOp(true);

  auto new_net_op = net.Clone();
  ASSERT_NE(new_net_op, nullptr);
  ASSERT_TRUE(new_net_op->IsNetOp());
  // The downcast is safe: IsNetOp() was checked just above.
  auto* new_net = static_cast<NetOp*>(new_net_op.get());

  // Use an unsigned literal, as the other tests do with ops_.size(). A plain
  // `2` makes gtest compare signed with unsigned (-Wsign-compare).
  ASSERT_EQ(2UL, new_net->ops_.size());
  ASSERT_EQ(new_net->ops_[0]->Type(), "empty");
  ASSERT_EQ(new_net->ops_[1]->Type(), "empty2");
}

Y
Yan Chunwei 已提交
96
}  // namespace operators
D
dongzhihong 已提交
97
}  // namespace paddle