// net_op_test.cc: unit tests for paddle::framework::NetOp.
#include <unordered_set>
#include <vector>

#include <gtest/gtest.h>
#include <paddle/framework/net.h>
#include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h>
// Operator registrations pulled into this test binary.
// NOTE(review): the active test below does not exercise any of these ops (it
// drives TestOp directly); `mul` and `add_two` are referenced only by the
// commented-out AddBackwardOp test. Presumably USE_OP forces each op's
// registration to be linked in — confirm before pruning this list.
USE_OP(add_two);
USE_OP(mul);
USE_OP(sigmoid);
USE_OP(softmax);
namespace paddle {
namespace framework {
static int infer_shape_cnt = 0;
static int run_cnt = 0;

D
dongzhihong 已提交
17
class TestOp : public OperatorBase {
Q
Qiao Longfei 已提交
18
 public:
Y
Yu Yang 已提交
19
  void InferShape(const framework::Scope& scope) const override {
Q
Qiao Longfei 已提交
20 21
    ++infer_shape_cnt;
  }
Y
Yu Yang 已提交
22
  void Run(const framework::Scope& scope,
Q
Qiao Longfei 已提交
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
           const paddle::platform::DeviceContext& dev_ctx) const override {
    ++run_cnt;
  }
};

// Asserts that `actual` contains exactly the elements of `expected`, in any
// order. Multiplicity counts: {"a", "a"} does NOT match {"a", "b"}.
//
// T must be hashable with std::hash<T> and equality-comparable.
//
// This helper uses gtest ASSERT_* macros, so a failure only returns from the
// helper itself; wrap calls in ASSERT_NO_FATAL_FAILURE to also stop the
// calling test.
template <typename T>
void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
                                  const std::vector<T>& actual) {
  ASSERT_EQ(expected.size(), actual.size());
  // A multiset (not a set) so each expected element can be matched at most
  // once; with a plain set, duplicates in `actual` could hide a missing
  // element while the sizes still agree.
  std::unordered_multiset<T> remaining(expected.begin(), expected.end());
  for (const auto& act : actual) {
    auto it = remaining.find(act);
    ASSERT_NE(remaining.end(), it)
        << "element of `actual` is missing from (or repeated beyond) "
           "`expected`";
    remaining.erase(it);  // erase this one occurrence only
  }
}

// Builds a NetOp from two TestOps and checks that:
//  - CompleteAddOp() yields the expected net inputs/outputs and marks the
//    intermediate variable "y" as the single temporary;
//  - one InferShape()/Run() on the net reaches both child ops;
//  - AddOp() throws once the net has been completed.
TEST(OpKernel, all) {
  // The counters are file-level statics shared by every use of TestOp.
  // Reset them so the exact-count checks below hold under --gtest_repeat
  // and regardless of which tests ran earlier in this binary.
  infer_shape_cnt = 0;
  run_cnt = 0;

  auto net = std::make_shared<NetOp>();
  ASSERT_NE(net, nullptr);

  // op1: (x, w1, b1) -> y
  auto op1 = std::make_shared<TestOp>();
  op1->inputs_ = {"x", "w1", "b1"};
  op1->outputs_ = {"y"};
  net->AddOp(op1);

  // op2: (y, w2, b2) -> z, consuming op1's output.
  auto op2 = std::make_shared<TestOp>();
  op2->inputs_ = {"y", "w2", "b2"};
  op2->outputs_ = {"z"};
  net->AddOp(op2);

  net->CompleteAddOp();
  // "y" is produced inside the net, so it is expected among the outputs
  // only, and to be the one output flagged in the "temporary_index" attr.
  ASSERT_NO_FATAL_FAILURE(AssertSameVectorWithoutOrder(
      {"x", "w1", "b1", "w2", "b2"}, net->inputs_));
  ASSERT_NO_FATAL_FAILURE(AssertSameVectorWithoutOrder({"y", "z"},
                                                       net->outputs_));
  auto tmp_idx_iter = net->attrs_.find("temporary_index");
  ASSERT_NE(net->attrs_.end(), tmp_idx_iter);
  auto& tmp_idx = boost::get<std::vector<int>>(tmp_idx_iter->second);
  ASSERT_EQ(1UL, tmp_idx.size());
  ASSERT_EQ("y", net->outputs_[tmp_idx[0]]);

  Scope scope;
  platform::CPUDeviceContext dev_ctx;

  net->InferShape(scope);
  net->Run(scope, dev_ctx);
  ASSERT_EQ(2, infer_shape_cnt);
  ASSERT_EQ(2, run_cnt);

  // A completed net must refuse further ops.
  ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
}

//! TODO(yuyang18): Refine Backward Op.
// TEST(AddBackwardOp, TestGradOp) {
//  auto net = std::make_shared<NetOp>();
//  ASSERT_NE(net, nullptr);
//  net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
//  net->AddOp(
//      framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
//  net->AddOp(
//      framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
//  auto grad_ops = AddBackwardOp(net);
//  for (auto& op : grad_ops->ops_) {
//    op->DebugString();
//  }
// }
}  // namespace framework
}  // namespace paddle