op_registry_test.cc 6.7 KB
Newer Older
1 2
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
Q
Qiao Longfei 已提交
3

4 5
namespace pd = paddle::framework;

6 7
namespace paddle {
namespace framework {
Y
Yu Yang 已提交
8
class CosineOp : public OperatorBase {
9
 public:
Q
Qiao Longfei 已提交
10
  void Run(const ScopePtr& scope,
Y
Yu Yang 已提交
11
           const platform::DeviceContext& dev_ctx) const override {}
Q
Qiao Longfei 已提交
12
  void InferShape(const ScopePtr& scope) const override {}
13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
};

// Describes the "cos_sim" operator: one input, one output, and a float
// attribute "scale" that defaults to 1.0 and must be strictly greater than 0.
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("input", "input of cosine op");
    AddOutput("output", "output of cosine op");
    // Float literals match the attribute type exactly (both values are
    // exactly representable, so nothing changes numerically).
    AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0f)
        .LargerThan(0.0f);
    AddComment("This is cos op");
  }
};

Y
Yu Yang 已提交
28 29
class MyTestOp : public OperatorBase {
 public:
Q
Qiao Longfei 已提交
30 31
  void InferShape(const ScopePtr& scope) const override {}
  void Run(const ScopePtr& scope,
Y
Yu Yang 已提交
32
           const platform::DeviceContext& dev_ctx) const override {}
33 34 35 36 37 38
};

// Describes the "my_test_op" operator. It declares:
//   - an input declared via AddInputs (presumably accepts several
//     variables; TODO confirm),
//   - an output flagged as temporary,
//   - an int attribute "test_attr" with no default, guarded by a custom
//     checker that only accepts even values.
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Descriptions previously said "cosine op" (copy-paste from the cosine
    // maker); they now name the operator they actually belong to.
    AddInputs("input", "input of my_test op");
    AddOutput("output", "output of my_test op",
              /*temporary*/ true);
    // Rejects odd values. The CustomChecker test matches this exact message.
    auto my_checker = [](int i) {
      PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
    };
    AddAttr<int>("test_attr", "a simple test attribute")
        .AddCustomChecker(my_checker);
    AddComment("This is my_test op");
  }
};
}  // namespace framework
}  // namespace paddle

// Register both test operators under the type names that the tests below put
// into OpDesc::type ("cos_sim" and "my_test_op").
REGISTER_OP(cos_sim, pd::CosineOp, pd::CosineOpProtoAndCheckerMaker);
REGISTER_OP(my_test_op, pd::MyTestOp, pd::MyTestOpProtoAndCheckerMaker);

58 59 60 61 62 63
TEST(OpRegistry, CreateOp) {
  paddle::framework::OpDesc op_desc;
  op_desc.set_type("cos_sim");
  op_desc.add_inputs("aa");
  op_desc.add_outputs("bb");

Q
Qiao Longfei 已提交
64
  float scale = 3.3;
65 66 67
  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
  attr->set_type(paddle::framework::AttrType::FLOAT);
Q
Qiao Longfei 已提交
68
  attr->set_f(scale);
69

Q
Qiao Longfei 已提交
70
  paddle::framework::OperatorPtr op =
71
      paddle::framework::OpRegistry::CreateOp(op_desc);
Y
Yu Yang 已提交
72
  auto scope = std::make_shared<paddle::framework::Scope>();
Y
Yu Yang 已提交
73 74
  paddle::platform::CPUDeviceContext dev_ctx;
  op->Run(scope, dev_ctx);
Q
Qiao Longfei 已提交
75 76
  float scale_get = op->GetAttr<float>("scale");
  ASSERT_EQ(scale_get, scale);
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
}

// A negative "scale" violates the LargerThan(0.0) constraint declared for
// "cos_sim". CreateOp must therefore throw EnforceNotMet, and the message
// must start with "larger_than check fail".
TEST(OpRegistry, IllegalAttr) {
  paddle::framework::OpDesc op_desc;
  op_desc.set_type("cos_sim");
  op_desc.add_inputs("aa");
  op_desc.add_outputs("bb");

  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
  attr->set_type(paddle::framework::AttrType::FLOAT);
  attr->set_f(-2.0);

  bool caught = false;
  try {
    paddle::framework::OperatorPtr op __attribute__((unused)) =
        paddle::framework::OpRegistry::CreateOp(op_desc);
  } catch (const paddle::framework::EnforceNotMet& err) {
    // Caught by const reference: no copy of the exception, no slicing.
    caught = true;
    // Only the leading characters are compared, as before. A string
    // comparison gives a readable diff if the prefix does not match.
    const std::string msg = "larger_than check fail";
    const std::string err_msg = err.what();
    ASSERT_EQ(err_msg.substr(0, msg.size()), msg);
  }
  ASSERT_TRUE(caught);
}

// When the OpDesc omits "scale", the created "cos_sim" operator must report
// the maker's declared default of 1.0.
TEST(OpRegistry, DefaultValue) {
  paddle::framework::OpDesc op_desc;
  op_desc.set_type("cos_sim");
  op_desc.add_inputs("aa");
  op_desc.add_outputs("bb");
  ASSERT_TRUE(op_desc.IsInitialized());

  paddle::framework::OperatorPtr op =
      paddle::framework::OpRegistry::CreateOp(op_desc);

  paddle::platform::CPUDeviceContext dev_ctx;
  auto scope = std::make_shared<paddle::framework::Scope>();
  op->Run(scope, dev_ctx);

  ASSERT_EQ(op->GetAttr<float>("scale"), 1.0f);
}

121 122 123 124 125 126 127 128
static void SetInputFormat(paddle::framework::OpDesc* desc) {
  auto attr = desc->add_attrs();
  attr->set_name("input_format");
  attr->set_type(paddle::framework::INTS);
  attr->mutable_ints()->Add(0);
  attr->mutable_ints()->Add(1);
}

129 130 131 132 133
TEST(OpRegistry, CustomChecker) {
  paddle::framework::OpDesc op_desc;
  op_desc.set_type("my_test_op");
  op_desc.add_inputs("ii");
  op_desc.add_outputs("oo");
134
  SetInputFormat(&op_desc);
135 136 137 138

  // attr 'test_attr' is not set
  bool caught = false;
  try {
Q
Qiao Longfei 已提交
139
    paddle::framework::OperatorPtr op __attribute__((unused)) =
140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
        paddle::framework::OpRegistry::CreateOp(op_desc);
  } catch (paddle::framework::EnforceNotMet err) {
    caught = true;
    std::string msg = "Attribute 'test_attr' is required!";
    const char* err_msg = err.what();
    for (size_t i = 0; i < msg.length(); ++i) {
      ASSERT_EQ(err_msg[i], msg[i]);
    }
  }
  ASSERT_TRUE(caught);

  // set 'test_attr' set to an illegal value
  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("test_attr");
  attr->set_type(paddle::framework::AttrType::INT);
  attr->set_i(3);
  caught = false;
  try {
Q
Qiao Longfei 已提交
158
    paddle::framework::OperatorPtr op __attribute__((unused)) =
159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175
        paddle::framework::OpRegistry::CreateOp(op_desc);
  } catch (paddle::framework::EnforceNotMet err) {
    caught = true;
    std::string msg = "'test_attr' must be even!";
    const char* err_msg = err.what();
    for (size_t i = 0; i < msg.length(); ++i) {
      ASSERT_EQ(err_msg[i], msg[i]);
    }
  }
  ASSERT_TRUE(caught);

  // set 'test_attr' set to a legal value
  op_desc.mutable_attrs()->Clear();
  attr = op_desc.mutable_attrs()->Add();
  attr->set_name("test_attr");
  attr->set_type(paddle::framework::AttrType::INT);
  attr->set_i(4);
176
  SetInputFormat(&op_desc);
Q
Qiao Longfei 已提交
177
  paddle::framework::OperatorPtr op =
178
      paddle::framework::OpRegistry::CreateOp(op_desc);
Y
Yu Yang 已提交
179
  paddle::platform::CPUDeviceContext dev_ctx;
Y
Yu Yang 已提交
180
  auto scope = std::make_shared<paddle::framework::Scope>();
Y
Yu Yang 已提交
181
  op->Run(scope, dev_ctx);
Q
Qiao Longfei 已提交
182 183
  int test_attr = op->GetAttr<int>("test_attr");
  ASSERT_EQ(test_attr, 4);
D
dongzhihong 已提交
184
}
185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216

class TestAttrProtoMaker : public pd::OpProtoAndCheckerMaker {
 public:
  TestAttrProtoMaker(pd::OpProto* proto, pd::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddAttr<float>("scale", "scale of test op");
    AddAttr<float>("scale", "scale of test op");
  }
};

// Validate() must throw EnforceNotMet when an attribute name is declared twice.
TEST(ProtoMaker, DuplicatedAttr) {
  pd::OpProto op_proto;
  pd::OpAttrChecker op_checker;
  // Direct-initialize: no temporary maker is created and then copied/moved.
  TestAttrProtoMaker proto_maker(&op_proto, &op_checker);
  ASSERT_THROW(proto_maker.Validate(), paddle::framework::EnforceNotMet);
}

// Proto maker that declares the input "input" twice on purpose. The
// DuplicatedInOut test checks that Validate() rejects the duplicate name.
class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {
 public:
  TestInOutProtoMaker(pd::OpProto* proto, pd::OpAttrChecker* op_checker)
      : pd::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("input", "input of test op");
    // Intentional duplicate of the declaration above.
    AddInput("input", "input of test op");
  }
};

// Validate() must throw EnforceNotMet when an input name is declared twice.
TEST(ProtoMaker, DuplicatedInOut) {
  pd::OpProto op_proto;
  pd::OpAttrChecker op_checker;
  // Direct-initialize: no temporary maker is created and then copied/moved.
  TestInOutProtoMaker proto_maker(&op_proto, &op_checker);
  ASSERT_THROW(proto_maker.Validate(), paddle::framework::EnforceNotMet);
}