From 735737d28369d6040d0bacbae9973052e51cd7af Mon Sep 17 00:00:00 2001 From: caoying03 Date: Fri, 29 Sep 2017 21:33:19 +0800 Subject: [PATCH] initialize crf opreator. --- paddle/operators/crf_op.cc | 48 +++++++++++++++++++ paddle/operators/crf_op.h | 41 ++++++++++++++++ .../paddle/v2/framework/tests/test_crf_op.py | 13 +++++ 3 files changed, 102 insertions(+) create mode 100644 paddle/operators/crf_op.cc create mode 100644 paddle/operators/crf_op.h create mode 100644 python/paddle/v2/framework/tests/test_crf_op.py diff --git a/paddle/operators/crf_op.cc b/paddle/operators/crf_op.cc new file mode 100644 index 00000000000..21ffcf48c0c --- /dev/null +++ b/paddle/operators/crf_op.cc @@ -0,0 +1,48 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/crf_op.h" + +namespace paddle { +namespace operators { + +class CrfOpMaker : public framework::OpProtoAndCheckerMaker { + public: + CrfOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) {} +}; + +class CrfOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase* ctx) const override {} +}; + +class CrfGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase* ctx) const override {} +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(crf, ops::CrfOp, ops::CrfOpMaker, crf_grad, ops::CrfGradOp); +REGISTER_OP_CPU_KERNEL(crf, ops::CrfOpKernel); +REGISTER_OP_CPU_KERNEL(crf_grad, ops::CrfGradOpKernel); diff --git a/paddle/operators/crf_op.h b/paddle/operators/crf_op.h new file mode 100644 index 00000000000..cb34c5c6a34 --- /dev/null +++ b/paddle/operators/crf_op.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class CrfOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "This kernel only runs on CPU."); + } +}; + +template +class CrfGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "This kernel only runs on CPU."); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_crf_op.py b/python/paddle/v2/framework/tests/test_crf_op.py new file mode 100644 index 00000000000..47c9341fa0b --- /dev/null +++ b/python/paddle/v2/framework/tests/test_crf_op.py @@ -0,0 +1,13 @@ +import unittest +import numpy as np + + +class TestCrfOp(OpTest): + def setUp(self): + self.op_type = "crf" + batch_size = 3 + class_num = 37 + + +if __name__ == "__main__": + unittest.main() -- GitLab