From 45eabb8cf23d6de3e7d3b62c78d3ab7ab1ebc7ce Mon Sep 17 00:00:00 2001 From: Cao Ying Date: Fri, 3 Nov 2017 17:33:20 -0500 Subject: [PATCH] Add the crf_decoding operator. (#5352) * proj init. * add unittest and implementation. --- paddle/operators/crf_decoding_op.cc | 136 ++++++++++++++++ paddle/operators/crf_decoding_op.h | 127 +++++++++++++++ paddle/operators/cross_entropy_op.cc | 5 +- paddle/operators/linear_chain_crf_op.cc | 65 ++++---- paddle/operators/linear_chain_crf_op.h | 4 +- .../framework/tests/test_crf_decoding_op.py | 146 ++++++++++++++++++ 6 files changed, 447 insertions(+), 36 deletions(-) create mode 100644 paddle/operators/crf_decoding_op.cc create mode 100644 paddle/operators/crf_decoding_op.h create mode 100644 python/paddle/v2/framework/tests/test_crf_decoding_op.py diff --git a/paddle/operators/crf_decoding_op.cc b/paddle/operators/crf_decoding_op.cc new file mode 100644 index 0000000000..d1ce74c4b9 --- /dev/null +++ b/paddle/operators/crf_decoding_op.cc @@ -0,0 +1,136 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/crf_decoding_op.h" + +namespace paddle { +namespace operators { +class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker { + public: + CRFDecodingOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Emission", + "(LoDTensor, default: LoDTensor). A LoDTensor with shape " + "[N x D] where N is the size of the mini-batch and D is the total " + "tag number. This input is the unscaled emission weight matrix of " + "the linear_chain_crf operator."); + AddInput( + "Transition", + "(Tensor, default: Tensor). A Tensor with shape [(D + 2) x D]. " + "This input is the transition weights learned by the linear_chain_crf " + "operator, denoted as w. The 1st row of w are transition weights for " + "the start mask. The 2nd row of w are transition weights for the end " + "mask. Transition weights between other tags begin from the 3rd row of " + "w. See more details in comments of the linear_chain_crf operator."); + AddInput( + "Label", + "(LoDTensor, LoDTensor). The ground truth with shape " + "[N x 1]. This input is optional. See more details in the operator's " + "comments.") + .AsDispensable(); + AddOutput("ViterbiPath", + "(LoDTensor, LoDTensor). The decoding results. What to " + "return changes depending on whether the Input(Label) (the groud " + "truth) is given. See more details in the operator's comment."); + AddComment(R"DOC( +The crf_decoding operator reads the emission feature weights and the transition +freature weights learned by the linear_chain_crf operator. It implements the +Viterbi algorithm which is a dynamic programming algorithm for finding the most +likely sequence of hidden states, called the Viterbi path, that results in a +sequence of observed tags. + +The output of this operator changes according to whether Input(Label) is given: + +1. Input(Label) is given: + +This happens in training. This operator is used to co-work with the chunk_eval +operator. + +When Input(Label) is given, the crf_decoding operator returns a row vector +with shape [N x 1] whose values are fixed to be 0, indicating an incorrect +prediction, or 1 indicating a tag is correctly predicted. Such an ouput is the +input to chunk_eval operator. + +2. Input(Label) is not given: + +This is the standard decoding process. + +The crf_decoding operator returns a row vecotr with shape [N x 1] whose values +range from 0 to maximum tag number - 1. Each element indicates an index of a +predicted tag. +)DOC"); + } +}; + +class CRFDecodingOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Emission"), + "Input(Emission) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Transition"), + "Input(Transition) should be not null."); + + PADDLE_ENFORCE(ctx->HasOutput("ViterbiPath"), + "Output(ViterbiPath) should be not null."); + + auto emission_dims = ctx->GetInputDim("Emission"); + PADDLE_ENFORCE_EQ(emission_dims.size(), 2UL, + "The Input(Emission) should be a 2-D tensor."); + PADDLE_ENFORCE(emission_dims[0], "An empty mini-batch is not allowed."); + + auto transition_dims = ctx->GetInputDim("Transition"); + PADDLE_ENFORCE_EQ(transition_dims.size(), 2UL, + "The Input(Transition) should be a 2-D tensor."); + PADDLE_ENFORCE_EQ( + transition_dims[0] - 2, transition_dims[1], + "An invalid dimension for the Input(Transition), which should " + "be a 2-D tensor with shape [(D + 2) x D]."); + PADDLE_ENFORCE_EQ( + emission_dims[1], transition_dims[1], + "The 2nd dimension of the Input(Emission) and the Input(Transition) " + "should be equal to the tag number."); + + if (ctx->HasInput("Label")) { + auto label_dims = ctx->GetInputDim("Label"); + PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL, + "The Input(Label) should be a 2-D tensor with the 2nd " + "dimensions fixed to 1."); + PADDLE_ENFORCE_EQ( + emission_dims[0], label_dims[0], + "The height of Input(Emission) and the height of Input(Label) " + "should be the same."); + } + + ctx->ShareLoD("Emission", /*->*/ "ViterbiPath"); + ctx->SetOutputDim("ViterbiPath", {emission_dims[0], 1}); + } + + protected: + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("Emission")->type()); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(crf_decoding, ops::CRFDecodingOp, + ops::CRFDecodingOpMaker); +REGISTER_OP_CPU_KERNEL( + crf_decoding, ops::CRFDecodingOpKernel, + ops::CRFDecodingOpKernel); diff --git a/paddle/operators/crf_decoding_op.h b/paddle/operators/crf_decoding_op.h new file mode 100644 index 0000000000..526e0c5dcb --- /dev/null +++ b/paddle/operators/crf_decoding_op.h @@ -0,0 +1,127 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using framework::LoDTensor; +using framework::LoD; +using framework::Tensor; + +template +class CRFDecodingOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "The crf_decoding operator can only run on CPU."); + + auto* emission_weights = ctx.Input("Emission"); + auto* transition_weights = ctx.Input("Transition"); + auto* label = ctx.Input("Label"); + auto* decoded_path = ctx.Output("ViterbiPath"); + + PADDLE_ENFORCE_EQ(emission_weights->NumLevels(), 1UL, + "The Input(Emission) should be a sequence."); + auto lod = emission_weights->lod(); + PADDLE_ENFORCE(lod.size(), "Input(Emission) must be a sequence."); + const size_t level = 0; + const size_t seq_num = lod[level].size() - 1; + + int* path = decoded_path->mutable_data(platform::CPUPlace()); + math::SetConstant()(ctx.device_context(), + decoded_path, 0); + for (size_t i = 0; i < seq_num; ++i) { + int start_pos = static_cast(lod[level][i]); + int end_pos = static_cast(lod[level][i + 1]); + Tensor decoded_path_one_seq = decoded_path->Slice(start_pos, end_pos); + Decode(emission_weights->Slice(start_pos, end_pos), *transition_weights, + &decoded_path_one_seq); + } + + if (label) { + PADDLE_ENFORCE_EQ(label->NumLevels(), 1UL, + "The Input(Label) should be a sequence."); + const int* label_value = label->data(); + size_t batch_size = emission_weights->dims()[0]; + for (size_t i = 0; i < batch_size; ++i) { + path[i] = label_value[i] == path[i] ? 1 : 0; + } + } + } + + private: + void Decode(const Tensor& emission_weights, const Tensor& transition_weights, + Tensor* decoded_path) const { + auto emission_dims = emission_weights.dims(); + const size_t seq_len = emission_dims[0]; + const size_t tag_num = emission_dims[1]; + + const size_t state_trans_base_idx = 2; + + const T* x = emission_weights.data(); + const T* w = transition_weights.data(); + int* path = decoded_path->data(); + + // alpha is a memo table. An element alpha(k, v) records the score of the + // best sequence of tags from position 1 to position k with v being the end + // tag. + Tensor alpha; + T* alpha_value = alpha.mutable_data(emission_dims, platform::CPUPlace()); + Tensor track; + int* track_value = + track.mutable_data(emission_dims, platform::CPUPlace()); + + for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i]; + + for (size_t k = 1; k < seq_len; ++k) { + for (size_t i = 0; i < tag_num; ++i) { + T max_score = -std::numeric_limits::max(); + int max_j = 0; + for (size_t j = 0; j < tag_num; ++j) { + T score = alpha_value[(k - 1) * tag_num + j] + + w[(j + state_trans_base_idx) * tag_num + i]; + if (score > max_score) { + max_score = score; + max_j = j; + } + } + + alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i]; + track_value[k * tag_num + i] = max_j; + } + } + + T max_score = -std::numeric_limits::max(); + int max_i = 0; + for (size_t i = 0; i < tag_num; ++i) { + T score = alpha_value[(seq_len - 1) * tag_num + i] + w[tag_num + i]; + if (score > max_score) { + max_score = score; + max_i = i; + } + } + path[seq_len - 1] = max_i; + for (int k = seq_len - 1; k >= 1; --k) { + path[k - 1] = max_i = track_value[k * tag_num + max_i]; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index 3ed41933b1..24df1fcada 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -49,7 +49,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel { } protected: - // Explicitly set that data type of the output of the cross_entropy operator + // Explicitly set that the data type of computation kernel of cross_entropy // is determined by its input "X". framework::DataType IndicateDataType( const framework::ExecutionContext& ctx) const override { @@ -96,7 +96,8 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { } protected: - // CrossEntropy's data type just determined by "X" + // Explicitly set that the data type of computation kernel of cross_entropy + // is determined by its input "X". framework::DataType IndicateDataType( const framework::ExecutionContext& ctx) const override { return framework::ToDataType(ctx.Input("X")->type()); diff --git a/paddle/operators/linear_chain_crf_op.cc b/paddle/operators/linear_chain_crf_op.cc index 605dbba5af..6864e3b0b7 100644 --- a/paddle/operators/linear_chain_crf_op.cc +++ b/paddle/operators/linear_chain_crf_op.cc @@ -22,43 +22,44 @@ class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker { LinearChainCRFOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput( - "Emission", - "(LoDTensor, default: LoDTensor). " - "The unscaled emission weight matrix for the linear chain CRF. " - "This input is a LoDTensor with shape [N x D] where N is the size of " - "the mini-batch and D is the total tag number."); - AddInput( - "Transition", - "(Tensor, default: Tensor). A Tensor with shape [(D + 2) x D]. " - "The learnable parameter for the linear_chain_crf operator. " - "See more details in the operator's comments."); - AddInput( - "Label", - "(LoDTensor, default: LoDTensor). The ground truth which is a 2-D " - "LoDTensor with shape [N x 1], where N is the total element number in " - "a mini-batch."); + AddInput("Emission", + "(LoDTensor, default: LoDTensor). " + "A 2-D LoDTensor with shape [N x D] where N is the size of the " + "mini-batch and D is the total tag number. The unscaled emission " + "weight matrix for the linear chain CRF. "); + AddInput("Transition", + "(Tensor, default: Tensor). A 2-D Tensor with shape " + "[(D + 2) x D]. The learnable parameter for the linear_chain_crf " + "operator. See more details in the operator's comments."); + AddInput("Label", + "(LoDTensor, default: LoDTensor). A LoDTensor with shape " + "[N x 1], where N is the total element number in a mini-batch. " + "The ground truth."); AddOutput( "Alpha", - "Tensor, default: Tensor. The forward vectors for the entire " - "batch. A two dimensional tensor with shape [N x D], " - "denoted as \f$\alpha\f$. \f$\alpha$\f is a memo table used to " - "calculate the normalization factor in CRF. \f$\alpha[k, v]$\f stores " - "the unnormalized probabilites of all possible unfinished sequences of " - "tags that end at position \f$k$\f with tag \f$v$\f. For each \f$k$\f, " + "(Tensor, default: Tensor). A 2-D Tensor with shape [N x D]. " + "The forward vectors for the entire batch. Denote it as \f$\alpha\f$. " + "\f$\alpha$\f is a memo table used to calculate the normalization " + "factor in CRF. \f$\alpha[k, v]$\f stores the unnormalized " + "probabilites of all possible unfinished sequences of tags that end at " + "position \f$k$\f with tag \f$v$\f. For each \f$k$\f, " "\f$\alpha[k, v]$\f is a vector of length \f$D$\f with a component for " "each tag value \f$v$\f. This vector is called a forward vecotr and " "will also be used in backward computations.") .AsIntermediate(); - AddOutput("EmissionExps", - "The exponentials of Input(Emission). This is an intermediate " - "computational result in forward computation, and will be reused " - "in backward computation.") + AddOutput( + "EmissionExps", + "(Tensor, default: Tensor). A 2-D Tensor with shape [N x D]. " + "The exponentials of Input(Emission). This is an intermediate " + "computational result in forward computation, and will be reused in " + "backward computation.") .AsIntermediate(); - AddOutput("TransitionExps", - "The exponentials of Input(Transition). This is an intermediate " - "computational result in forward computation, and will be reused " - "in backward computation.") + AddOutput( + "TransitionExps", + "(Tensor, default: Tensor). A 2-D Tensor with shape " + "[(D + 2) x D]. The exponentials of Input(Transition). This is an " + "intermediate computational result in forward computation, and " + "will be reused in backward computation.") .AsIntermediate(); AddOutput( "LogLikelihood", @@ -179,8 +180,8 @@ class LinearChainCRFOp : public framework::OperatorWithKernel { } protected: - // Explicitly set that the data type of output of the linear_chain_crf - // operator is determined by its input "Emission". + // Explicitly set that the data type of computation kernel of linear_chain_crf + // is determined by its input "Emission". framework::DataType IndicateDataType( const framework::ExecutionContext& ctx) const override { return framework::ToDataType(ctx.Input("Emission")->type()); diff --git a/paddle/operators/linear_chain_crf_op.h b/paddle/operators/linear_chain_crf_op.h index 56fb0c9102..ddf7398175 100644 --- a/paddle/operators/linear_chain_crf_op.h +++ b/paddle/operators/linear_chain_crf_op.h @@ -134,7 +134,7 @@ class LinearChainCRFOpKernel : public framework::OpKernel { Tensor emission_row_max; emission_row_max.mutable_data( - framework::make_ddim({static_cast(batch_size), 1}), + framework::make_ddim({static_cast(batch_size), 1}), platform::CPUPlace()); auto place = ctx.GetEigenDevice(); @@ -273,7 +273,7 @@ class LinearChainCRFOpKernel : public framework::OpKernel { const int* lbl = label.data(); PADDLE_ENFORCE_LT( - *std::max_element(lbl, lbl + seq_length), tag_num, + static_cast(*std::max_element(lbl, lbl + seq_length)), tag_num, "An invalid tag label that execesses the largest tag number."); // Calculate the nominator part, which depends on the label sequence. diff --git a/python/paddle/v2/framework/tests/test_crf_decoding_op.py b/python/paddle/v2/framework/tests/test_crf_decoding_op.py new file mode 100644 index 0000000000..ee2b996bf4 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_crf_decoding_op.py @@ -0,0 +1,146 @@ +import unittest +import random +import numpy as np + +from op_test import OpTest + + +class CRFDecoding(object): + def __init__(self, emission_weights, transition_weights, + seq_start_positions): + assert (emission_weights.shape[0] == seq_start_positions[-1]) + self.tag_num = emission_weights.shape[1] + self.seq_num = len(seq_start_positions) - 1 + + self.seq_start_positions = seq_start_positions + self.x = emission_weights + + self.a = transition_weights[0, :] + self.b = transition_weights[1, :] + self.w = transition_weights[2:, :] + + self.track = np.zeros( + (seq_start_positions[-1], self.tag_num), dtype="int32") + self.decoded_path = np.zeros( + (seq_start_positions[-1], 1), dtype="int32") + + def _decode_one_sequence(self, decoded_path, x): + seq_len, tag_num = x.shape + alpha = np.zeros((seq_len, tag_num), dtype="float64") + track = np.zeros((seq_len, tag_num), dtype="int32") + + for i in range(tag_num): + alpha[0, i] = self.a[i] + x[0, i] + + for k in range(1, seq_len): + for i in range(tag_num): + max_score = -np.finfo("float64").max + max_idx = 0 + for j in range(tag_num): + score = alpha[k - 1, j] + self.w[j, i] + if score > max_score: + max_score = score + max_idx = j + alpha[k, i] = max_score + x[k, i] + track[k, i] = max_idx + + max_score = -np.finfo("float64").max + max_idx = 0 + for i in range(tag_num): + score = alpha[seq_len - 1, i] + self.b[i] + if score > max_score: + max_score = score + max_idx = i + + decoded_path[-1] = max_idx + for i in range(seq_len - 1, 0, -1): + decoded_path[i - 1] = max_idx = track[i, max_idx] + + def decode(self): + for i in range(self.seq_num): + start = self.seq_start_positions[i] + end = self.seq_start_positions[i + 1] + self._decode_one_sequence(self.decoded_path[start:end, :], + self.x[start:end, :]) + return self.decoded_path + + +class TestCRFDecodingOp1(OpTest): + """ + Compare the dynamic program with random generated parameters and inputs + with grouth truth not being given. + """ + + def set_test_data(self): + SEQ_NUM = 3 + TAG_NUM = 17 + MAX_SEQ_LEN = 10 + + lod = [[0]] + for i in range(SEQ_NUM): + lod[-1].append(lod[-1][-1] + random.randint(1, MAX_SEQ_LEN)) + emission = np.random.uniform(-1, 1, + [lod[-1][-1], TAG_NUM]).astype("float64") + transition = np.random.uniform(-0.5, 0.5, + [TAG_NUM + 2, TAG_NUM]).astype("float64") + + self.inputs = { + "Emission": (emission, lod), + "Transition": transition, + } + + decoder = CRFDecoding(emission, transition, lod[0]) + decoded_path = decoder.decode() + + self.outputs = {"ViterbiPath": decoded_path} + + def setUp(self): + self.op_type = "crf_decoding" + self.set_test_data() + + def test_check_output(self): + self.check_output() + + +class TestCRFDecodingOp2(OpTest): + """ + Compare the dynamic program with brute force computation with + ground truth being given. + """ + + def setUp(self): + self.op_type = "crf_decoding" + TAG_NUM = 5 + + lod = [[0, 1, 3, 6, 10]] + transition = np.repeat( + np.arange( + TAG_NUM, dtype="float64").reshape(1, TAG_NUM), + TAG_NUM + 2, + axis=0) + emission = np.repeat( + np.arange( + TAG_NUM, dtype="float64").reshape(1, TAG_NUM), + lod[-1][-1], + axis=0) + + labels = np.random.randint( + low=0, high=TAG_NUM, size=(lod[-1][-1], 1), dtype="int32") + predicted_labels = np.ones( + (lod[-1][-1], 1), dtype="int32") * (TAG_NUM - 1) + expected_output = (labels == predicted_labels).astype("int32") + + self.inputs = { + "Emission": (emission, lod), + "Transition": transition, + "Label": (labels, lod) + } + + self.outputs = {"ViterbiPath": expected_output} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() -- GitLab