op_registry_test.cc 5.3 KB
Newer Older
1 2
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
Q
Qiao Longfei 已提交
3 4 5 6
#include "paddle/framework/operator.h"
#include "paddle/operators/demo_op.h"

using namespace paddle::framework;
7

8 9
namespace paddle {
namespace framework {
Q
Qiao Longfei 已提交
10
class CosineOp : public OperatorWithKernel {
11
 public:
Q
Qiao Longfei 已提交
12 13
  void Run(const OpRunContext* context) const override {
    printf("%s\n", DebugString().c_str());
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
  }
};

// Describes CosineOp: one input, one output, and a float attribute "scale"
// that defaults to 1.0 and is constrained to be larger than 0.0.
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  CosineOpProtoAndCheckerMaker(OpProto* op_proto, OpAttrChecker* attr_checker)
      : OpProtoAndCheckerMaker(op_proto, attr_checker) {
    AddInput("input", "input of cosine op");
    AddOutput("output", "output of cosine op");

    // Configure the attribute in two steps: declare it, then attach the
    // default value and the lower-bound check.
    auto&& scale_attr = AddAttr<float>("scale", "scale of cosine op");
    scale_attr.SetDefault(1.0).LargerThan(0.0);

    // NOTE(review): the proto type here is "cos" while the operator is
    // registered as cos_sim further down -- confirm this is intended.
    AddType("cos");
    AddComment("This is cos op");
  }
};

// Registers CosineOp together with CosineOpProtoAndCheckerMaker. The tests
// in this file create the operator through OpRegistry::CreateOp with the op
// type "cos_sim" (presumably the third macro argument is that lookup name).
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)

Q
Qiao Longfei 已提交
33
class MyTestOp : public OperatorWithKernel {
34
 public:
Q
Qiao Longfei 已提交
35 36 37
  void Run(const OpRunContext* ctx) const override {
    printf("%s\n", DebugString().c_str());
    printf("test_attr = %d\n", ctx->op_->GetAttr<int>("test_attr"));
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
  }
};

// Describes MyTestOp: one input, one output, and an int attribute
// "test_attr" that has no default and must pass a custom evenness check.
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  MyTestOpProtoAndCheckerMaker(OpProto* op_proto, OpAttrChecker* attr_checker)
      : OpProtoAndCheckerMaker(op_proto, attr_checker) {
    // NOTE(review): the descriptions below say "cosine op"; they look
    // copy-pasted from CosineOpProtoAndCheckerMaker -- confirm.
    AddInput("input", "input of cosine op");
    AddOutput("output", "output of cosine op");

    // Custom validation hook: rejects odd values of the attribute.
    auto require_even = [](int i) {
      PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
    };
    auto&& test_attr = AddAttr<int>("test_attr", "a simple test attribute");
    test_attr.AddCustomChecker(require_even);

    AddType("my_test_op");
    AddComment("This is my_test op");
  }
};

// Registers MyTestOp together with MyTestOpProtoAndCheckerMaker. The
// CustomChecker test creates the operator through OpRegistry::CreateOp with
// the op type "my_test_op".
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)
}  // namespace framework
}  // namespace paddle

61 62 63 64 65 66
// Creates a "cos_sim" operator with an explicit float "scale" attribute,
// runs it once, and checks that the attribute reads back unchanged.
TEST(OpRegistry, CreateOp) {
  paddle::framework::OpDesc desc;
  desc.set_type("cos_sim");
  desc.add_inputs("aa");
  desc.add_outputs("bb");

  float expected_scale = 3.3;
  auto scale_attr = desc.mutable_attrs()->Add();
  scale_attr->set_name("scale");
  scale_attr->set_type(paddle::framework::AttrType::FLOAT);
  scale_attr->set_f(expected_scale);

  // NOTE(review): the returned raw pointer is never released here; whether
  // the caller owns it is not visible from this file -- confirm.
  paddle::framework::OperatorBase* created_op =
      paddle::framework::OpRegistry::CreateOp(desc);

  auto scope = std::make_shared<Scope>();
  auto dev_ctx = DeviceContext();
  created_op->Run(scope, &dev_ctx);

  float actual_scale = created_op->GetAttr<float>("scale");
  ASSERT_EQ(actual_scale, expected_scale);
}

// Verifies that creating a "cos_sim" operator with a negative "scale"
// attribute is rejected: CreateOp must throw EnforceNotMet, and the message
// must start with "larger_than check fail".
TEST(OpRegistry, IllegalAttr) {
  paddle::framework::OpDesc op_desc;
  op_desc.set_type("cos_sim");
  op_desc.add_inputs("aa");
  op_desc.add_outputs("bb");

  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
  attr->set_type(paddle::framework::AttrType::FLOAT);
  attr->set_f(-2.0);

  bool caught = false;
  try {
    paddle::framework::OperatorBase* op __attribute__((unused)) =
        paddle::framework::OpRegistry::CreateOp(op_desc);
  } catch (const paddle::framework::EnforceNotMet& err) {
    // Caught by const reference so the exception object is neither copied
    // nor sliced.
    caught = true;
    const std::string expected_prefix = "larger_than check fail";
    // Only the leading part of the message is pinned; any text that follows
    // the prefix is ignored. substr clamps if the message is shorter.
    const std::string actual_prefix =
        std::string(err.what()).substr(0, expected_prefix.size());
    ASSERT_EQ(actual_prefix, expected_prefix);
  }
  ASSERT_TRUE(caught);
}

// Creates a "cos_sim" operator without setting "scale" and checks that the
// attribute reads back as 1.0 after the operator has been run once.
TEST(OpRegistry, DefaultValue) {
  paddle::framework::OpDesc desc;
  desc.set_type("cos_sim");
  desc.add_inputs("aa");
  desc.add_outputs("bb");

  ASSERT_TRUE(desc.IsInitialized());

  // NOTE(review): the returned raw pointer is never released here; whether
  // the caller owns it is not visible from this file -- confirm.
  paddle::framework::OperatorBase* created_op =
      paddle::framework::OpRegistry::CreateOp(desc);

  auto scope = std::make_shared<Scope>();
  auto dev_ctx = DeviceContext();
  created_op->Run(scope, &dev_ctx);

  ASSERT_EQ(created_op->GetAttr<float>("scale"), 1.0);
}

// Exercises the attribute checks of "my_test_op" in three stages:
//   1. "test_attr" missing      -> CreateOp throws, "is required" message.
//   2. "test_attr" set to 3     -> CreateOp throws, "must be even" message.
//   3. "test_attr" set to 4     -> CreateOp succeeds, attribute reads back 4.
TEST(OpRegistry, CustomChecker) {
  paddle::framework::OpDesc op_desc;
  op_desc.set_type("my_test_op");
  op_desc.add_inputs("ii");
  op_desc.add_outputs("oo");

  // Stage 1: attribute 'test_attr' is not set.
  bool caught = false;
  try {
    paddle::framework::OperatorBase* op __attribute__((unused)) =
        paddle::framework::OpRegistry::CreateOp(op_desc);
  } catch (const paddle::framework::EnforceNotMet& err) {
    // Caught by const reference so the exception object is neither copied
    // nor sliced.
    caught = true;
    const std::string expected_prefix = "Attribute 'test_attr' is required!";
    // Only the leading part of the message is pinned; substr clamps if the
    // message is shorter than the expected prefix.
    const std::string actual_prefix =
        std::string(err.what()).substr(0, expected_prefix.size());
    ASSERT_EQ(actual_prefix, expected_prefix);
  }
  ASSERT_TRUE(caught);

  // Stage 2: set 'test_attr' to an illegal (odd) value.
  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("test_attr");
  attr->set_type(paddle::framework::AttrType::INT);
  attr->set_i(3);
  caught = false;
  try {
    paddle::framework::OperatorBase* op __attribute__((unused)) =
        paddle::framework::OpRegistry::CreateOp(op_desc);
  } catch (const paddle::framework::EnforceNotMet& err) {
    caught = true;
    const std::string expected_prefix = "'test_attr' must be even!";
    const std::string actual_prefix =
        std::string(err.what()).substr(0, expected_prefix.size());
    ASSERT_EQ(actual_prefix, expected_prefix);
  }
  ASSERT_TRUE(caught);

  // Stage 3: set 'test_attr' to a legal (even) value. The attribute list is
  // cleared first so the odd value from stage 2 is not carried over.
  op_desc.mutable_attrs()->Clear();
  attr = op_desc.mutable_attrs()->Add();
  attr->set_name("test_attr");
  attr->set_type(paddle::framework::AttrType::INT);
  attr->set_i(4);
  // NOTE(review): the returned raw pointer is never released here; whether
  // the caller owns it is not visible from this file -- confirm.
  paddle::framework::OperatorBase* op =
      paddle::framework::OpRegistry::CreateOp(op_desc);
  auto dev_ctx = DeviceContext();
  auto scope = std::make_shared<Scope>();
  op->Run(scope, &dev_ctx);
  int test_attr = op->GetAttr<int>("test_attr");
  ASSERT_EQ(test_attr, 4);
}
178 179 180 181 182

// Test binary entry point: hands the command line to GoogleTest, runs every
// registered test, and returns the resulting status as the exit code.
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  const int exit_code = RUN_ALL_TESTS();
  return exit_code;
}