// op_registry_test.cc
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
namespace pd = paddle::framework;

6 7
namespace paddle {
namespace framework {
Y
Yu Yang 已提交
8
class CosineOp : public OperatorBase {
9
 public:
10
  DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase);
Y
Yu Yang 已提交
11
  void Run(const Scope& scope,
Y
Yu Yang 已提交
12
           const platform::DeviceContext& dev_ctx) const override {}
Y
Yu Yang 已提交
13
  void InferShape(const Scope& scope) const override {}
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
};

class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("input", "input of cosine op");
    AddOutput("output", "output of cosine op");
    AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0)
        .LargerThan(0.0);
    AddComment("This is cos op");
  }
};

Y
Yu Yang 已提交
29 30
class MyTestOp : public OperatorBase {
 public:
31
  DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase);
Y
Yu Yang 已提交
32 33
  void InferShape(const Scope& scope) const override {}
  void Run(const Scope& scope,
Y
Yu Yang 已提交
34
           const platform::DeviceContext& dev_ctx) const override {}
35 36 37 38 39 40
};

class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
Y
Yu Yang 已提交
41 42
    AddInput("input", "input of cosine op").SetDuplicable();
    AddOutput("output", "output of cosine op").SetIntermediate();
43 44 45 46 47 48 49 50 51 52 53
    auto my_checker = [](int i) {
      PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
    };
    AddAttr<int>("test_attr", "a simple test attribute")
        .AddCustomChecker(my_checker);
    AddComment("This is my_test op");
  }
};
}  // namespace framework
}  // namespace paddle

Y
Yu Yang 已提交
54 55 56 57 58 59 60 61 62
static void ConstructVars(const std::string& param_name,
                          std::initializer_list<const char*> arguments,
                          paddle::framework::OpDesc::Var* var) {
  var->set_parameter(param_name);
  for (auto& arg_name : arguments) {
    *var->mutable_arguments()->Add() = arg_name;
  }
}

// Register the test operators under the type names the OpDesc instances in
// this file refer to ("cos_sim" and "my_test_op").
REGISTER_OP(cos_sim, paddle::framework::CosineOp,
            paddle::framework::CosineOpProtoAndCheckerMaker);
REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
            paddle::framework::MyTestOpProtoAndCheckerMaker);

68 69 70
TEST(OpRegistry, CreateOp) {
  paddle::framework::OpDesc op_desc;
  op_desc.set_type("cos_sim");
Y
Yu Yang 已提交
71 72
  auto* input = op_desc.add_inputs();
  ConstructVars("input", {"aa"}, input);
Y
Yu Yang 已提交
73

Y
Yu Yang 已提交
74 75
  auto* output = op_desc.add_outputs();
  ConstructVars("output", {"bb"}, output);
76

Q
Qiao Longfei 已提交
77
  float scale = 3.3;
78 79 80
  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
  attr->set_type(paddle::framework::AttrType::FLOAT);
Q
Qiao Longfei 已提交
81
  attr->set_f(scale);
82

Y
Yu Yang 已提交
83
  std::shared_ptr<paddle::framework::OperatorBase> op =
84
      paddle::framework::OpRegistry::CreateOp(op_desc);
Y
Yu Yang 已提交
85
  paddle::framework::Scope scope;
Y
Yu Yang 已提交
86 87
  paddle::platform::CPUDeviceContext dev_ctx;
  op->Run(scope, dev_ctx);
Q
Qiao Longfei 已提交
88 89
  float scale_get = op->GetAttr<float>("scale");
  ASSERT_EQ(scale_get, scale);
90 91 92 93 94
}

// A "scale" value of -2.0 fails the LargerThan(0.0) check declared for
// "cos_sim", so CreateOp must throw EnforceNotMet whose message starts with
// "larger_than check fail".
TEST(OpRegistry, IllegalAttr) {
  paddle::framework::OpDesc op_desc;
  op_desc.set_type("cos_sim");
  auto* input = op_desc.add_inputs();
  ConstructVars("input", {"aa"}, input);

  auto* output = op_desc.add_outputs();
  ConstructVars("output", {"bb"}, output);

  auto* attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
  attr->set_type(paddle::framework::AttrType::FLOAT);
  attr->set_f(-2.0);

  bool caught = false;
  try {
    paddle::framework::OpRegistry::CreateOp(op_desc);
  } catch (const paddle::platform::EnforceNotMet& err) {
    caught = true;
    const std::string msg = "larger_than check fail";
    // Only a prefix of what() is checked; anything after the expected
    // message is ignored.
    ASSERT_EQ(std::string(err.what()).substr(0, msg.size()), msg);
  }
  ASSERT_TRUE(caught);
}

// With no "scale" attribute in the OpDesc, the created "cos_sim" operator
// must report the default declared by its proto maker (1.0).
TEST(OpRegistry, DefaultValue) {
  paddle::framework::OpDesc op_desc;
  op_desc.set_type("cos_sim");
  auto* input = op_desc.add_inputs();
  ConstructVars("input", {"aa"}, input);

  auto* output = op_desc.add_outputs();
  ConstructVars("output", {"bb"}, output);

  ASSERT_TRUE(op_desc.IsInitialized());

  std::shared_ptr<paddle::framework::OperatorBase> op =
      paddle::framework::OpRegistry::CreateOp(op_desc);
  paddle::framework::Scope scope;
  paddle::platform::CPUDeviceContext dev_ctx;
  op->Run(scope, dev_ctx);
  ASSERT_EQ(op->GetAttr<float>("scale"), 1.0f);
}

// Exercises the "test_attr" checks of "my_test_op":
//   1. attribute missing      -> EnforceNotMet "Attribute 'test_attr' is required!"
//   2. attribute odd (3)      -> EnforceNotMet "'test_attr' must be even!"
//   3. attribute even (4)     -> operator is created, runs, and reports 4.
TEST(OpRegistry, CustomChecker) {
  paddle::framework::OpDesc op_desc;
  op_desc.set_type("my_test_op");
  auto* input = op_desc.add_inputs();
  ConstructVars("input", {"ii"}, input);

  auto* output = op_desc.add_outputs();
  ConstructVars("output", {"oo"}, output);

  // 'test_attr' is not set.
  bool caught = false;
  try {
    paddle::framework::OpRegistry::CreateOp(op_desc);
  } catch (const paddle::platform::EnforceNotMet& err) {
    caught = true;
    const std::string msg = "Attribute 'test_attr' is required!";
    // Only a prefix of what() is checked.
    ASSERT_EQ(std::string(err.what()).substr(0, msg.size()), msg);
  }
  ASSERT_TRUE(caught);

  // 'test_attr' set to an illegal (odd) value.
  auto* attr = op_desc.mutable_attrs()->Add();
  attr->set_name("test_attr");
  attr->set_type(paddle::framework::AttrType::INT);
  attr->set_i(3);
  caught = false;
  try {
    paddle::framework::OpRegistry::CreateOp(op_desc);
  } catch (const paddle::platform::EnforceNotMet& err) {
    caught = true;
    const std::string msg = "'test_attr' must be even!";
    ASSERT_EQ(std::string(err.what()).substr(0, msg.size()), msg);
  }
  ASSERT_TRUE(caught);

  // 'test_attr' set to a legal (even) value.
  op_desc.mutable_attrs()->Clear();
  attr = op_desc.mutable_attrs()->Add();
  attr->set_name("test_attr");
  attr->set_type(paddle::framework::AttrType::INT);
  attr->set_i(4);
  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  paddle::platform::CPUDeviceContext dev_ctx;
  paddle::framework::Scope scope;
  op->Run(scope, dev_ctx);
  const int test_attr = op->GetAttr<int>("test_attr");
  ASSERT_EQ(test_attr, 4);
}

// Proto maker that declares the float attribute "scale" twice on purpose, so
// the DuplicatedAttr test can check that Validate() throws EnforceNotMet.
class TestAttrProtoMaker : public pd::OpProtoAndCheckerMaker {
 public:
  TestAttrProtoMaker(pd::OpProto* proto, pd::OpAttrChecker* op_checker)
      : pd::OpProtoAndCheckerMaker(proto, op_checker) {
    // Two identical declarations of the same attribute.
    for (int copy = 0; copy < 2; ++copy) {
      AddAttr<float>("scale", "scale of test op");
    }
  }
};

// Validate() must throw EnforceNotMet for a proto built by
// TestAttrProtoMaker, which declares attribute "scale" twice.
TEST(ProtoMaker, DuplicatedAttr) {
  pd::OpProto op_proto;
  pd::OpAttrChecker op_checker;
  // Direct-initialization: does not rely on the maker being copy/movable.
  TestAttrProtoMaker proto_maker(&op_proto, &op_checker);
  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}

// Proto maker that declares the input "input" twice on purpose, so the
// DuplicatedInOut test can check that Validate() throws EnforceNotMet.
class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {
 public:
  TestInOutProtoMaker(pd::OpProto* proto, pd::OpAttrChecker* op_checker)
      : pd::OpProtoAndCheckerMaker(proto, op_checker) {
    // Two identical declarations of the same input.
    for (int copy = 0; copy < 2; ++copy) {
      AddInput("input", "input of test op");
    }
  }
};

// Validate() must throw EnforceNotMet for a proto built by
// TestInOutProtoMaker, which declares input "input" twice.
TEST(ProtoMaker, DuplicatedInOut) {
  pd::OpProto op_proto;
  pd::OpAttrChecker op_checker;
  // Direct-initialization: does not rely on the maker being copy/movable.
  TestInOutProtoMaker proto_maker(&op_proto, &op_checker);
  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}