net_op_test.cc 2.6 KB
Newer Older
Y
Yan Chunwei 已提交
1 2
#include "paddle/operators/net_op.h"

#include <algorithm>
#include <memory>
#include <vector>

#include <gtest/gtest.h>

#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
D
dongzhihong 已提交
7

D
dongzhihong 已提交
8
namespace paddle {
Y
Yan Chunwei 已提交
9
namespace operators {
Q
Qiao Longfei 已提交
10 11 12 13

static int infer_shape_cnt = 0;
static int run_cnt = 0;

D
dongzhihong 已提交
14
class TestOp : public OperatorBase {
15
 public:
Y
Yu Yang 已提交
16
  void InferShape(const framework::Scope& scope) const override {
Q
Qiao Longfei 已提交
17 18
    ++infer_shape_cnt;
  }
Y
Yu Yang 已提交
19
  void Run(const framework::Scope& scope,
Q
Qiao Longfei 已提交
20 21 22 23 24
           const paddle::platform::DeviceContext& dev_ctx) const override {
    ++run_cnt;
  }
};

Y
Yu Yang 已提交
25
class EmptyOp : public OperatorBase {
26
 public:
Y
Yu Yang 已提交
27 28
  void InferShape(const Scope& scope) const override {}
  void Run(const Scope& scope,
Y
Yu Yang 已提交
29 30 31
           const platform::DeviceContext& dev_ctx) const override {}
};

Q
Qiao Longfei 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45
/// Asserts that `actual` contains exactly the same elements as `expected`,
/// ignoring order.
///
/// The vectors are compared as multisets: every element must appear the same
/// number of times in both, so e.g. {"a", "b"} vs {"a", "a"} fails. T only
/// needs to be equality-comparable.
///
/// Note: ASSERT_* failures inside this helper return from the helper only;
/// wrap the call in ASSERT_NO_FATAL_FAILURE to abort the calling test too.
template <typename T>
void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
                                  const std::vector<T>& actual) {
  ASSERT_EQ(expected.size(), actual.size());
  // A plain set-membership check would accept duplicates in `actual` that
  // hide missing elements; is_permutation also compares multiplicities.
  // The three-iterator form is safe because the sizes were checked equal.
  ASSERT_TRUE(
      std::is_permutation(expected.begin(), expected.end(), actual.begin()));
}

TEST(OpKernel, all) {
  // infer_shape_cnt and run_cnt are file-level statics that persist across
  // repeated executions of this test (e.g. --gtest_repeat=N). Reset them so
  // the exact-count checks below do not depend on earlier runs.
  infer_shape_cnt = 0;
  run_cnt = 0;

  auto net = std::make_shared<NetOp>();

  // op1: (x, w1, b1) -> y
  auto op1 = std::make_shared<TestOp>();
  op1->inputs_ = {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}};
  op1->outputs_ = {{"Out", {"y"}}};
  net->AddOp(op1);

  // op2: (y, w2, b2) -> z, i.e. it consumes op1's output.
  auto op2 = std::make_shared<TestOp>();
  op2->inputs_ = {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}};
  op2->outputs_ = {{"Out", {"z"}}};
  net->AddOp(op2);

  net->CompleteAddOp();

  // After CompleteAddOp the net's aggregated inputs/outputs are expected
  // under the "__all__" key. "y" is not a net input because op1 produces it.
  // ASSERT_NO_FATAL_FAILURE stops the test if the helper's asserts fail.
  ASSERT_NO_FATAL_FAILURE(AssertSameVectorWithoutOrder(
      {"x", "w1", "b1", "w2", "b2"}, net->inputs_.at("__all__")));
  ASSERT_NO_FATAL_FAILURE(AssertSameVectorWithoutOrder(
      {"y", "z"}, net->outputs_.at("__all__")));

  // "y" is produced by op1 and consumed by op2, so it is expected to be the
  // single temporary. "temporary_index" holds its position in the "__all__"
  // outputs.
  auto tmp_idx_iter = net->attrs_.find("temporary_index");
  ASSERT_NE(net->attrs_.end(), tmp_idx_iter);
  auto& tmp_idx = boost::get<std::vector<int>>(tmp_idx_iter->second);
  ASSERT_EQ(1UL, tmp_idx.size());
  ASSERT_EQ("y", net->outputs_.at("__all__")[tmp_idx[0]]);

  Scope scope;
  platform::CPUDeviceContext dev_ctx;

  // One InferShape/Run on the net should reach each of the two TestOps once.
  net->InferShape(scope);
  net->Run(scope, dev_ctx);
  ASSERT_EQ(2, infer_shape_cnt);
  ASSERT_EQ(2, run_cnt);

  // Once CompleteAddOp has been called, adding further ops must be rejected.
  ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
}

Y
Yan Chunwei 已提交
79
TEST(NetOp, insert_op) {
  NetOp net;

  // One EmptyOp instance is reused for every insertion; this test only checks
  // how many entries end up in ops_.
  auto op = std::make_shared<EmptyOp>();
  op->inputs_ = {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}};
  op->outputs_ = {{"Out", {"y"}}};

  net.AddOp(op);

  // Insert at the front of the op list.
  net.InsertOp(0, op);
  ASSERT_EQ(2UL, net.ops_.size());

  // Insert at index == current size, i.e. just past the last op.
  net.InsertOp(2, op);
  ASSERT_EQ(3UL, net.ops_.size());
}
D
dongzhihong 已提交
90

Y
Yan Chunwei 已提交
91
}  // namespace operators
D
dongzhihong 已提交
92
}  // namespace paddle