//  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/net_op.h"

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include <gtest/gtest.h>
namespace paddle {
namespace operators {
using Scope = framework::Scope;
using DeviceContext = platform::DeviceContext;

namespace {
// Counter bumped by TestOp::RunImpl every time a TestOp executes.
int run_cnt = 0;
}  // namespace

class TestOp : public framework::OperatorBase {
26
 public:
Y
Yu Yang 已提交
27
  using framework::OperatorBase::OperatorBase;
Y
Yu Yang 已提交
28
  DEFINE_OP_CLONE_METHOD(TestOp);
29 30 31 32

 private:
  void RunImpl(const Scope& scope,
               const platform::Place& place) const override {
Q
Qiao Longfei 已提交
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
    ++run_cnt;
  }
};

// Asserts that `actual` holds exactly the same elements as `expected`,
// ignoring order but respecting multiplicity: {"a", "b"} does NOT match
// {"a", "a"}. T must be usable as a std::hash-based container key.
// Like any gtest ASSERT_* in a helper, a failure returns from this function
// only, not from the calling test.
template <typename T>
void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
                                  const std::vector<T>& actual) {
  // Early, readable failure when the lengths differ.
  ASSERT_EQ(expected.size(), actual.size());
  // Compare as multisets. A plain "is every actual element in expected"
  // membership check would let duplicates in `actual` hide missing elements.
  std::unordered_multiset<T> expected_elems(expected.begin(), expected.end());
  std::unordered_multiset<T> actual_elems(actual.begin(), actual.end());
  ASSERT_EQ(expected_elems, actual_elems);
}

// Builds a two-op chain (x -> y -> z) and checks the net's aggregated input
// and output variable names plus what OutputVars(false) reports.
TEST(OpKernel, all) {
  auto net = std::make_shared<NetOp>();
  ASSERT_NE(net, nullptr);

  // Both ops have type "test" and differ only in the variables they use.
  auto make_op = [](const std::string& in, const std::string& weight,
                    const std::string& bias, const std::string& out) {
    return std::unique_ptr<TestOp>(
        new TestOp("test", {{"X", {in}}, {"W", {weight}}, {"b", {bias}}},
                   {{"Out", {out}}}, framework::AttributeMap{}));
  };
  net->AppendOp(make_op("x", "w1", "b1", "y"));
  // The second op consumes y, the first op's output.
  net->AppendOp(make_op("y", "w2", "b2", "z"));

  net->CompleteAddOp();

  const std::vector<std::string> expected_inputs{"x", "w1", "b1", "w2", "b2"};
  AssertSameVectorWithoutOrder(expected_inputs, net->Inputs(NetOp::kAll));
  const std::vector<std::string> expected_outputs{"y", "z"};
  AssertSameVectorWithoutOrder(expected_outputs, net->Outputs(NetOp::kAll));

  // OutputVars(false) is expected to report only z; y feeds the second op.
  auto final_outs = net->OutputVars(false);
  ASSERT_EQ(final_outs.size(), 1UL);
  ASSERT_EQ(final_outs[0], "z");
}

// Exercises AppendOp/InsertOp with both a borrowed op (by reference) and an
// owned op (moved unique_ptr), checking ops_ grows by one each time.
TEST(NetOp, insert_op) {
  NetOp net;
  std::unique_ptr<framework::NOP> empty_op(
      new framework::NOP("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
                         {{"Out", {"y"}}}, framework::AttributeMap{}));

  // Passing *empty_op selects the reference overloads; presumably these copy
  // the op, so empty_op is still owned here for the final move below.
  net.AppendOp(*empty_op);
  net.InsertOp(0, *empty_op);
  ASSERT_EQ(2UL, net.ops_.size());

  // Index 2 equals the current size, i.e. the end position; ownership of
  // empty_op moves into the net.
  net.InsertOp(2, std::move(empty_op));
  ASSERT_EQ(3UL, net.ops_.size());
}
// Cloning a completed NetOp must yield another NetOp holding sub-ops of the
// same types, in the same order.
TEST(NetOp, Clone) {
  const std::vector<std::string> op_types{"empty", "empty2"};

  NetOp net;
  for (const auto& op_type : op_types) {
    net.AppendOp(std::unique_ptr<framework::NOP>(new framework::NOP{
        op_type, framework::VariableNameMap{}, framework::VariableNameMap{},
        framework::AttributeMap{}}));
  }
  net.CompleteAddOp(true);

  auto cloned = net.Clone();
  ASSERT_NE(cloned, nullptr);
  ASSERT_TRUE(cloned->IsNetOp());

  // IsNetOp() was checked just above, so this downcast is safe.
  auto* cloned_net = static_cast<NetOp*>(cloned.get());
  ASSERT_EQ(op_types.size(), cloned_net->ops_.size());
  for (size_t i = 0; i < op_types.size(); ++i) {
    ASSERT_EQ(cloned_net->ops_[i]->Type(), op_types[i]);
  }
}

}  // namespace operators
}  // namespace paddle