op_registry_test.cc 5.8 KB
Newer Older
1 2
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
Q
Qiao Longfei 已提交
3

4 5
namespace pd = paddle::framework;

6 7
namespace paddle {
namespace framework {
Y
Yu Yang 已提交
8
class CosineOp : public OperatorBase {
9
 public:
Y
Yu Yang 已提交
10
  using OperatorBase::OperatorBase;
D
dzhwinter 已提交
11
  void Run(const Scope& scope, const platform::Place& place) const override {}
12 13 14 15 16 17 18 19 20 21
};

class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("input", "input of cosine op");
    AddOutput("output", "output of cosine op");
    AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0)
F
fengjiayi 已提交
22
        .GreaterThan(0.0);
23 24 25 26
    AddComment("This is cos op");
  }
};

Y
Yu Yang 已提交
27 28
class MyTestOp : public OperatorBase {
 public:
Y
Yu Yang 已提交
29
  using OperatorBase::OperatorBase;
D
dzhwinter 已提交
30
  void Run(const Scope& scope, const platform::Place& place) const override {}
31 32 33 34 35 36
};

class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
Y
Yu Yang 已提交
37 38
    AddInput("input", "input of cosine op").AsDuplicable();
    AddOutput("output", "output of cosine op").AsIntermediate();
39 40 41 42 43 44 45 46 47 48 49
    auto my_checker = [](int i) {
      PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
    };
    AddAttr<int>("test_attr", "a simple test attribute")
        .AddCustomChecker(my_checker);
    AddComment("This is my_test op");
  }
};
}  // namespace framework
}  // namespace paddle

Y
Yu Yang 已提交
50 51
static void BuildVar(const std::string& param_name,
                     std::initializer_list<const char*> arguments,
52
                     paddle::framework::proto::OpDesc::Var* var) {
Y
Yu Yang 已提交
53 54
  var->set_parameter(param_name);
  for (auto& arg_name : arguments) {
Y
Yu Yang 已提交
55
    var->add_arguments(arg_name);
Y
Yu Yang 已提交
56 57
  }
}
F
fengjiayi 已提交
58 59 60 61
// Register the two test operators under the type names the OpRegistry tests
// use in their OpDescs ("cos_sim" and "my_test_op"), without a gradient op.
REGISTER_OP_WITHOUT_GRADIENT(cos_sim, paddle::framework::CosineOp,
                             paddle::framework::CosineOpProtoAndCheckerMaker);
REGISTER_OP_WITHOUT_GRADIENT(my_test_op, paddle::framework::MyTestOp,
                             paddle::framework::MyTestOpProtoAndCheckerMaker);
Y
Yu Yang 已提交
62

63
// A `scale` attribute written into the OpDesc must be readable back from the
// created operator with the same value.
TEST(OpRegistry, CreateOp) {
  const float scale = 3.3;

  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("cos_sim");
  BuildVar("input", {"aa"}, op_desc.add_inputs());
  BuildVar("output", {"bb"}, op_desc.add_outputs());

  auto* scale_attr = op_desc.mutable_attrs()->Add();
  scale_attr->set_name("scale");
  scale_attr->set_type(paddle::framework::proto::AttrType::FLOAT);
  scale_attr->set_f(scale);

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  paddle::framework::Scope scope;
  paddle::platform::CPUPlace cpu_place;
  op->Run(scope, cpu_place);

  ASSERT_EQ(op->Attr<float>("scale"), scale);
}

// Creating `cos_sim` with a negative `scale` must be rejected by the
// attribute's checker with an EnforceNotMet whose message starts with
// "larger_than check fail".
TEST(OpRegistry, IllegalAttr) {
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("cos_sim");
  BuildVar("input", {"aa"}, op_desc.add_inputs());
  BuildVar("output", {"bb"}, op_desc.add_outputs());

  auto* attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
  attr->set_type(paddle::framework::proto::AttrType::FLOAT);
  attr->set_f(-2.0);

  bool caught = false;
  try {
    paddle::framework::OpRegistry::CreateOp(op_desc);
  } catch (const paddle::platform::EnforceNotMet& err) {
    // Catch by const reference: no copy of the exception object.
    caught = true;
    const std::string msg = "larger_than check fail";
    // Only the prefix is compared; what() may carry extra trailing context.
    ASSERT_EQ(std::string(err.what()).substr(0, msg.size()), msg);
  }
  ASSERT_TRUE(caught);
}

// Without an explicit `scale` attribute, the created `cos_sim` operator must
// report the default registered by its proto maker (1.0).
TEST(OpRegistry, DefaultValue) {
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("cos_sim");
  BuildVar("input", {"aa"}, op_desc.add_inputs());
  BuildVar("output", {"bb"}, op_desc.add_outputs());

  // All required proto fields are set even though no attribute was added.
  ASSERT_TRUE(op_desc.IsInitialized());

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  paddle::framework::Scope scope;
  paddle::platform::CPUPlace cpu_place;
  op->Run(scope, cpu_place);

  const float default_scale = op->Attr<float>("scale");
  ASSERT_EQ(default_scale, 1.0);
}

// Exercises the custom checker on `test_attr` of `my_test_op`: creation must
// fail while the attribute is missing or odd, and succeed (preserving the
// value) once it is even.
TEST(OpRegistry, CustomChecker) {
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("my_test_op");
  BuildVar("input", {"ii"}, op_desc.add_inputs());
  BuildVar("output", {"oo"}, op_desc.add_outputs());

  // attr 'test_attr' is not set
  bool caught = false;
  try {
    paddle::framework::OpRegistry::CreateOp(op_desc);
  } catch (const paddle::platform::EnforceNotMet& err) {
    caught = true;
    const std::string msg = "Attribute 'test_attr' is required!";
    // Only the prefix is compared; what() may carry extra trailing context.
    ASSERT_EQ(std::string(err.what()).substr(0, msg.size()), msg);
  }
  ASSERT_TRUE(caught);

  // set 'test_attr' to an illegal (odd) value
  auto* attr = op_desc.mutable_attrs()->Add();
  attr->set_name("test_attr");
  attr->set_type(paddle::framework::proto::AttrType::INT);
  attr->set_i(3);
  caught = false;
  try {
    paddle::framework::OpRegistry::CreateOp(op_desc);
  } catch (const paddle::platform::EnforceNotMet& err) {
    caught = true;
    const std::string msg = "'test_attr' must be even!";
    ASSERT_EQ(std::string(err.what()).substr(0, msg.size()), msg);
  }
  ASSERT_TRUE(caught);

  // set 'test_attr' to a legal (even) value; the attribute list is cleared
  // first so the desc carries a single 'test_attr' entry.
  op_desc.mutable_attrs()->Clear();
  attr = op_desc.mutable_attrs()->Add();
  attr->set_name("test_attr");
  attr->set_type(paddle::framework::proto::AttrType::INT);
  attr->set_i(4);
  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;
  op->Run(scope, cpu_place);
  int test_attr = op->Attr<int>("test_attr");
  ASSERT_EQ(test_attr, 4);
}
174 175 176 177 178 179 180 181 182 183

// CosineOp augmented with the boilerplate produced by DEFINE_OP_CONSTRUCTOR
// and DEFINE_OP_CLONE_METHOD (presumably a forwarding constructor and a
// Clone() override — confirm against the macro definitions). Used as the
// operator type in TEST(OperatorRegistrar, Test).
class CosineOpComplete : public paddle::framework::CosineOp {
 public:
  DEFINE_OP_CONSTRUCTOR(CosineOpComplete, paddle::framework::CosineOp);
  DEFINE_OP_CLONE_METHOD(CosineOpComplete);
};

// Only checks that an OperatorRegistrar can be instantiated for
// CosineOpComplete under the type name "cos"; no assertions are made.
TEST(OperatorRegistrar, Test) {
  paddle::framework::OperatorRegistrar<
      CosineOpComplete, paddle::framework::CosineOpProtoAndCheckerMaker>
      registrar("cos");
}