提交 be441f7d 编写于 作者: Q Qiao Longfei 提交者: GitHub

test OpKernel (#2820)

Add unit test for OpKernel
上级 555b0a72
......@@ -12,7 +12,7 @@ cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(operator SRCS operator.cc DEPS op_desc protobuf)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry place)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
......
......@@ -147,13 +147,13 @@ class OpRegisterHelper {
}
};
#define REGISTER_OP(__op_class, __op_maker_class, __op_type) \
class __op_class##Register { \
#define REGISTER_OP(type, op_class, op_maker_class) \
class op_class##Register { \
private: \
const static OpRegisterHelper<__op_class, __op_maker_class> reg; \
const static OpRegisterHelper<op_class, op_maker_class> reg; \
}; \
const OpRegisterHelper<__op_class, __op_maker_class> \
__op_class##Register::reg(#__op_type);
const OpRegisterHelper<op_class, op_maker_class> op_class##Register::reg( \
#type)
} // namespace framework
} // namespace paddle
......@@ -26,7 +26,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
}
};
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
REGISTER_OP(cos_sim, CosineOp, CosineOpProtoAndCheckerMaker);
class MyTestOp : public OperatorBase {
public:
......@@ -53,7 +53,7 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
}
};
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)
REGISTER_OP(my_test_op, MyTestOp, MyTestOpProtoAndCheckerMaker);
} // namespace framework
} // namespace paddle
......
......@@ -45,7 +45,7 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
}
};
REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator)
REGISTER_OP(test_operator, OperatorTest, OperatorTestProtoAndCheckerMaker);
TEST(OperatorBase, all) {
OpDesc op_desc;
......@@ -69,5 +69,55 @@ TEST(OperatorBase, all) {
delete op;
}
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
AddOutput("output", "output of test op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("test_operator");
AddComment("This is test op");
}
};
class OpWithKernelTest : public OperatorWithKernel {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
};
class CPUKernelTest : public OpKernel {
public:
void Compute(const KernelContext& context) const {
float scale = context.op_.GetAttr<float>("scale");
ASSERT_NEAR(scale, 3.14, 1e-5);
std::cout << "this is cpu kernel" << std::endl;
std::cout << context.op_.DebugString() << std::endl;
}
};
REGISTER_OP(op_with_kernel, OpWithKernelTest, OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_KERNEL(op_with_kernel, platform::CPUPlace, CPUKernelTest);
TEST(OpKernel, all) {
OpDesc op_desc;
op_desc.set_type("op_with_kernel");
*op_desc.mutable_inputs()->Add() = "IN1";
*op_desc.mutable_outputs()->Add() = "OUT1";
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(3.14);
platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>();
OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context);
delete op;
}
} // namespace framework
} // namespace paddle
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册