diff --git a/paddle/operators/crf_op.cc b/paddle/operators/crf_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..21ffcf48c0c267ff1acfd28de00d742f5ff8b16a --- /dev/null +++ b/paddle/operators/crf_op.cc @@ -0,0 +1,48 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/crf_op.h" + +namespace paddle { +namespace operators { + +class CrfOpMaker : public framework::OpProtoAndCheckerMaker { + public: + CrfOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) {} +}; + +class CrfOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase* ctx) const override {} +}; + +class CrfGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase* ctx) const override {} +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(crf, ops::CrfOp, ops::CrfOpMaker, crf_grad, ops::CrfGradOp); +REGISTER_OP_CPU_KERNEL(crf, ops::CrfOpKernel); +REGISTER_OP_CPU_KERNEL(crf_grad, ops::CrfGradOpKernel); diff --git a/paddle/operators/crf_op.h b/paddle/operators/crf_op.h new file mode 100644 index 0000000000000000000000000000000000000000..cb34c5c6a34b707079a9ffe90f2e5fb6d8f70098 --- /dev/null +++ b/paddle/operators/crf_op.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class CrfOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "This kernel only runs on CPU."); + } +}; + +template +class CrfGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "This kernel only runs on CPU."); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_crf_op.py b/python/paddle/v2/framework/tests/test_crf_op.py new file mode 100644 index 0000000000000000000000000000000000000000..47c9341fa0bcfa6e79c95ee112612e59ae863119 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_crf_op.py @@ -0,0 +1,13 @@ +import unittest +import numpy as np + + +class TestCrfOp(OpTest): + def setUp(self): + self.op_type = "crf" + batch_size = 3 + class_num = 37 + + +if __name__ == "__main__": + unittest.main()