diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index aa46829fdde82b58a649108bf708901299cd8153..3be26fdc4fb6ebdd0ec427a2248b0f97d9edff01 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -37,32 +37,32 @@ ExecutionContext::GetEigenDevice() const { std::string OperatorBase::Input(const std::string& name) const { auto& ins = Inputs(name); PADDLE_ENFORCE_LE(ins.size(), 1UL, - "Op %s input %s should contain only one variable", type_, - name); + "Operator %s's input %s should contain only one variable.", + type_, name); return ins.empty() ? kEmptyVarName : ins[0]; } const std::vector& OperatorBase::Inputs( const std::string& name) const { auto it = inputs_.find(name); - PADDLE_ENFORCE(it != inputs_.end(), "Op %s do not have input %s", type_, - name); + PADDLE_ENFORCE(it != inputs_.end(), "Operator %s does not have the input %s.", + type_, name); return it->second; } std::string OperatorBase::Output(const std::string& name) const { auto& outs = Outputs(name); PADDLE_ENFORCE_LE(outs.size(), 1UL, - "Op %s output %s should contain only one variable", type_, - name); + "Operator %s's output %s should contain only one variable.", + type_, name); return outs.empty() ? 
kEmptyVarName : outs[0]; } const std::vector& OperatorBase::Outputs( const std::string& name) const { auto it = outputs_.find(name); - PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output called %s", - type_, name); + PADDLE_ENFORCE(it != outputs_.end(), + "Operator %s does not have an output called %s.", type_, name); return it->second; } diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 93885fa3028e072bc0bd021ea9287087678f3621..b8a7040ed024fc7b19980beef3d8b367dfdd7f50 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -427,7 +427,8 @@ class OperatorWithKernel : public OperatorBase { int tmp = static_cast(ToDataType(t->type())); VLOG(3) << "Input " << ipt_name << " with data_type " << tmp; PADDLE_ENFORCE(tmp == data_type || data_type == -1, - "DataType of Paddle Op %s must be same.", Type()); + "DataType of Paddle Op %s must be the same.", + Type()); data_type = tmp; } } diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 7b9a5b75e1087a1cc3b6c6c7a6e4dc185c32dd42..9eab67561a42b3fb4e22d8475ad5eeb146a72f1c 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -118,10 +118,12 @@ class Tensor { const platform::DeviceContext& ctx); /** - * @brief Return the slice of the tensor. + * @brief Return a sub-tensor of the given tensor. * - * @param[in] begin_idx The begin index of the slice. - * @param[in] end_idx The end index of the slice. + * @param[in] begin_idx The index of the start row(inclusive) to slice. + * The index number begins from 0. + * @param[in] end_idx The index of the end row(exclusive) to slice. + * The index number begins from 0. 
*/ inline Tensor Slice(const int& begin_idx, const int& end_idx) const; diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 29ac683f48fcde4dd3b5ad7f04b5d1d7434706ba..bcccdd5881775e199297dce7e70aaf6aae62d95a 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -112,9 +112,10 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type) { if (holder_ != nullptr) { holder_->set_type(type); } - PADDLE_ENFORCE_GT(numel(), 0, - "Tensor's numel must be larger than zero to call " - "Tensor::mutable_data. Call Tensor::set_dim first."); + PADDLE_ENFORCE_GT( + numel(), 0, + "When calling this method, the Tensor's numel must be larger than zero. " + "Please check Tensor::Resize has been called first."); int64_t size = numel() * SizeOfType(type); /* some versions of boost::variant don't have operator!= */ if (holder_ == nullptr || !(holder_->place() == place) || @@ -229,10 +230,12 @@ inline void Tensor::CopyFromVector(const std::vector& src, inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { check_memory_size(); - PADDLE_ENFORCE_GE(begin_idx, 0, "Slice begin index is less than zero."); - PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound."); - PADDLE_ENFORCE_LT(begin_idx, end_idx, - "Begin index must be less than end index."); + PADDLE_ENFORCE_GE(begin_idx, 0, + "The start row index must be greater than or equal to 0."); + PADDLE_ENFORCE_LE(end_idx, dims_[0], "The end row index is out of bound."); + PADDLE_ENFORCE_LT( + begin_idx, end_idx, + "The start row index must be less than the end row index."); if (dims_[0] == 1) { return *this; diff --git a/paddle/gserver/layers/CRFLayer.cpp b/paddle/gserver/layers/CRFLayer.cpp index 0b544420097e9150f8489731b6379dea633e992c..867303b4fa0d490297ab152fc2ad266e92e29baf 100644 --- a/paddle/gserver/layers/CRFLayer.cpp +++ b/paddle/gserver/layers/CRFLayer.cpp @@ -101,8 +101,10 @@ void CRFLayer::backward(const UpdateCallback& 
callback) { : real(1.0f); instanceWeight *= coeff_; - MatrixPtr grad = output.grad->subRowMatrix(starts[i], starts[i + 1]); - grad->add(*crfs_[i].getXGrad(), real(1.0f), instanceWeight); + if (output.grad) { + MatrixPtr grad = output.grad->subRowMatrix(starts[i], starts[i + 1]); + grad->add(*crfs_[i].getXGrad(), real(1.0f), instanceWeight); + } if (needWGrad) { weight_->getWGrad()->add( *crfs_[i].getWGrad(), real(1.0f), instanceWeight); diff --git a/paddle/gserver/layers/LinearChainCRF.cpp b/paddle/gserver/layers/LinearChainCRF.cpp index dc3dc156792bdf32c3b948a292597d0e9eca5d8b..abaa1802b763a49f748214dbd4dec1d2bac53b59 100644 --- a/paddle/gserver/layers/LinearChainCRF.cpp +++ b/paddle/gserver/layers/LinearChainCRF.cpp @@ -102,7 +102,6 @@ real LinearChainCRF::forward(real* x, int* s, int length) { } void LinearChainCRF::backward(real* x, int* s, int length, bool needWGrad) { - MatrixPtr matX = Matrix::create(x, length, numClasses_); Matrix::resizeOrCreate(matGrad_, length, numClasses_); Matrix::resizeOrCreate(beta_, length, numClasses_); real* b = b_->getData(); diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index d94b96200c2a5cd112b17e45aa6cd4a63bdd04d0..39df19da677a7dee7d0989d491f8d5511f73a9c7 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -28,8 +28,9 @@ class CrossEntropyOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("X"); auto label_dims = ctx->GetInputDim("Label"); - PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); - PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2."); + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, + "Input(Label)'s rank should be 2."); PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], "The 1st dimension of Input(X) and Input(Label) should " "be equal."); @@ -38,8 +39,8 @@ class CrossEntropyOp : public 
framework::OperatorWithKernel { "If Attr(soft_label) == true, the 2nd dimension of " "Input(X) and Input(Label) should be equal."); } else { - PADDLE_ENFORCE_EQ(label_dims[1], 1, - "If Attr(soft_label) == false, the 2nd dimension of " + PADDLE_ENFORCE_EQ(label_dims[1], 1UL, + "If Attr(soft_label) == false, the 2nd dimension of " "Input(Label) should be 1."); } @@ -48,7 +49,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel { } protected: - // CrossEntropy's data type just determined by "X" + // Explicitly set that data type of the output of the cross_entropy operator + // is determined by its input "X". framework::DataType IndicateDataType( const framework::ExecutionContext& ctx) const override { return framework::ToDataType(ctx.Input("X")->type()); diff --git a/paddle/operators/linear_chain_crf_op.cc b/paddle/operators/linear_chain_crf_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..605dbba5af1bb8b0d718833be6af45fdaeac70ac --- /dev/null +++ b/paddle/operators/linear_chain_crf_op.cc @@ -0,0 +1,261 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/linear_chain_crf_op.h" + +namespace paddle { +namespace operators { + +class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LinearChainCRFOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "Emission", + "(LoDTensor, default: LoDTensor). " + "The unscaled emission weight matrix for the linear chain CRF. " + "This input is a LoDTensor with shape [N x D] where N is the size of " + "the mini-batch and D is the total tag number."); + AddInput( + "Transition", + "(Tensor, default: Tensor). A Tensor with shape [(D + 2) x D]. " + "The learnable parameter for the linear_chain_crf operator. " + "See more details in the operator's comments."); + AddInput( + "Label", + "(LoDTensor, default: LoDTensor). The ground truth which is a 2-D " + "LoDTensor with shape [N x 1], where N is the total element number in " + "a mini-batch."); + AddOutput( + "Alpha", + "Tensor, default: Tensor. The forward vectors for the entire " + "batch. A two dimensional tensor with shape [N x D], " + "denoted as \\f$\\alpha\\f$. \\f$\\alpha\\f$ is a memo table used to " + "calculate the normalization factor in CRF. \\f$\\alpha[k, v]\\f$ stores " + "the unnormalized probabilities of all possible unfinished sequences of " + "tags that end at position \\f$k\\f$ with tag \\f$v\\f$. For each \\f$k\\f$, " + "\\f$\\alpha[k, v]\\f$ is a vector of length \\f$D\\f$ with a component for " + "each tag value \\f$v\\f$. This vector is called a forward vector and " + "will also be used in backward computations.") + .AsIntermediate(); + AddOutput("EmissionExps", + "The exponentials of Input(Emission). This is an intermediate " + "computational result in forward computation, and will be reused " + "in backward computation.") + .AsIntermediate(); + AddOutput("TransitionExps", + "The exponentials of Input(Transition). 
This is an intermediate " + "computational result in forward computation, and will be reused " + "in backward computation.") + .AsIntermediate(); + AddOutput( + "LogLikelihood", + "(Tensor, default: Tensor). The logarithm of the conditional " + "likelihood of each training sample in a mini-batch. This is a 2-D " + "tensor with shape [S x 1], where S is the sequence number in a " + "mini-batch. Note: S is equal to the sequence number in a mini-batch. " + "The output is no longer a LoDTensor."); + AddComment(R"DOC( +Conditional Random Field defines an undirected probabilistic graph with nodes +denoting random variables and edges denoting dependencies between these +variables. CRF learns the conditional probability \f$P(Y|X)\f$, where +\f$X = (x_1, x_2, ... , x_n)\f$ are structured inputs and +\f$Y = (y_1, y_2, ... , y_n)\f$ are labels for the inputs. + +Linear chain CRF is a special case of CRF that is useful for sequence labeling +task. Sequence labeling tasks do not assume a lot of conditional +independences among inputs. The only constraint they impose is that the input +and output must be linear sequences. Thus, the graph of such a CRF is a simple +chain or a line, which results in the linear chain CRF. + +This operator implements the Forward-Backward algorithm for the linear chain +CRF. Please see http://www.cs.columbia.edu/~mcollins/fb.pdf and +http://cseweb.ucsd.edu/~elkan/250Bwinter2012/loglinearCRFs.pdf for reference. + +Equation: + +- Denote Input(Emission) to this operator as \f$x\f$ here. +- The first D values of Input(Transition) to this operator are for starting +weights, denoted as \f$a\f$ here. +- The next D values of Input(Transition) of this operator are for ending +weights, denoted as \f$b\f$ here. +- The remaining values of Input(Transition) are for transition weights, +denoted as \f$w\f$ here. +- Denote Input(Label) as \f$s\f$ here. 
+ +The probability of a sequence \f$s\f$ of length \f$L\f$ is defined as: +\f$P(s) = (1/Z) exp(a_{s_1} + b_{s_L} + + \sum_{l=1}^L x_{s_l} + + \sum_{l=2}^L w_{s_{l-1},s_l})\f$ +where \f$Z\f$ is a normalization value so that the sum of \f$P(s)\f$ over +all possible sequences is \f$1\f$, and \f$x\f$ is the emission feature weight +to the linear chain CRF. + +Finally, the linear chain CRF operator outputs the logarithm of the conditional +likelihood of each training sample in a mini-batch. + +NOTE: +1. The feature function for a CRF is made up of the emission features and the +transition features. The emission feature weights are NOT computed in +this operator. They MUST be computed first before this operator is called. + +2. Because this operator performs global normalization over all possible +sequences internally, it expects UNSCALED emission feature weights. +Please do not call this op with the emission feature being output of any +nonlinear activation. + +3. The 2nd dimension of Input(Emission) MUST be equal to the tag number. 
+ +)DOC"); + } +}; + +class LinearChainCRFOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Emission"), + "Input(Emission) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Transition"), + "Input(Transition) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + + PADDLE_ENFORCE(ctx->HasOutput("Alpha"), + "Output(Alpha) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("EmissionExps"), + "Output(EmissionExps) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("TransitionExps"), + "Output(TransitionExps) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("LogLikelihood"), + "Output(LogLikelihood) should be not null."); + + auto emission_dims = ctx->GetInputDim("Emission"); + PADDLE_ENFORCE_EQ(emission_dims.size(), 2UL, + "The Input(Emission) should be a 2-D tensor."); + PADDLE_ENFORCE(emission_dims[0], "An empty mini-batch is not allowed."); + + auto transition_dims = ctx->GetInputDim("Transition"); + PADDLE_ENFORCE_EQ(transition_dims.size(), 2UL, + "The Input(Transition) should be a 2-D tensor."); + PADDLE_ENFORCE_EQ( + transition_dims[0] - 2, transition_dims[1], + "An invalid dimension for the Input(Transition), which should " + "be a 2-D tensor with shape [(D + 2) x D]."); + PADDLE_ENFORCE_EQ( + emission_dims[1], transition_dims[1], + "The 2nd dimension of the Input(Emission) and the Input(Transition) " + "should be equal to the tag number."); + + auto label_dims = ctx->GetInputDim("Label"); + PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL, + "The Input(Label) should be a 2-D tensor with the 2nd " + "dimensions fixed to 1."); + PADDLE_ENFORCE_EQ( + emission_dims[0], label_dims[0], + "The height of Input(Emission) and the height of Input(Label) " + "should be the same."); + + ctx->SetOutputDim("Alpha", emission_dims); 
+ ctx->SetOutputDim("EmissionExps", emission_dims); + ctx->SetOutputDim("TransitionExps", transition_dims); + // TODO(caoying) This is tricky. The 1st dimension of Output(LogLikelihood) + // is the sequence number in a mini-batch. The dimension set here should be + // resized to its correct size in the function Compute. Fix this once we can + // get LoD information in the InferShape interface. + ctx->SetOutputDim("LogLikelihood", {emission_dims[0], 1}); + } + + protected: + // Explicitly set that the data type of output of the linear_chain_crf + // operator is determined by its input "Emission". + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("Emission")->type()); + } +}; + +class LinearChainCRFGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("EmissionExps"), + "Input(EmissionExps) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("TransitionExps"), + "Input(TransitionExps) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("LogLikelihood")), + "Input(LogLikelihood@GRAD) should be not null."); + + auto emission_exps_dims = ctx->GetInputDim("EmissionExps"); + PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 2UL, + "The Input(EmissionExps) should be a 2-D tensor."); + PADDLE_ENFORCE(emission_exps_dims[0], + "An empty mini-batch is not allowed."); + + auto transition_exps_dims = ctx->GetInputDim("TransitionExps"); + PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2UL, + "The Input(TransitionExps) should be a 2-D tensor."); + PADDLE_ENFORCE_EQ( + transition_exps_dims[0] - 2, transition_exps_dims[1], + "An invalid dimension for the Input(TransitionExps), which should " + "be a 2-D tensor with shape [(D + 2) x D]."); + PADDLE_ENFORCE_EQ( + emission_exps_dims[1], transition_exps_dims[1], + 
"The 2nd dimension of the Input(EmissionExps) and the " + "Input(TransitionExps) should be equal to the tag number."); + + auto label_dims = ctx->GetInputDim("Label"); + PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL, + "The Input(Label) should be a 2-D tensor with the 2nd " + "dimensions fixed to 1."); + PADDLE_ENFORCE_EQ( + emission_exps_dims[0], label_dims[0], + "The height of Input(EmissionExps) and the height of Input(Label) " + "should be the same."); + + if (ctx->HasOutput(framework::GradVarName("Emission"))) { + ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims); + } + if (ctx->HasOutput(framework::GradVarName("Transition"))) { + ctx->SetOutputDim(framework::GradVarName("Transition"), + transition_exps_dims); + } + } + + protected: + // Explicitly set that the data type of output of the linear_chain_crf_grad + // operator is determined by its input: gradients of LogLikelihood. + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType( + ctx.Input(framework::GradVarName("LogLikelihood"))->type()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(linear_chain_crf, ops::LinearChainCRFOp, ops::LinearChainCRFOpMaker, + linear_chain_crf_grad, ops::LinearChainCRFGradOp); +REGISTER_OP_CPU_KERNEL( + linear_chain_crf, + ops::LinearChainCRFOpKernel, + ops::LinearChainCRFOpKernel); +REGISTER_OP_CPU_KERNEL( + linear_chain_crf_grad, + ops::LinearChainCRFGradOpKernel, + ops::LinearChainCRFGradOpKernel); diff --git a/paddle/operators/linear_chain_crf_op.cu b/paddle/operators/linear_chain_crf_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..6fc8995f4c2ce05f89ffb58129695113f89159fa --- /dev/null +++ b/paddle/operators/linear_chain_crf_op.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/linear_chain_crf_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + linear_chain_crf, + ops::LinearChainCRFOpKernel, + ops::LinearChainCRFOpKernel); +REGISTER_OP_GPU_KERNEL( + linear_chain_crf_grad, + ops::LinearChainCRFGradOpKernel, + ops::LinearChainCRFGradOpKernel); diff --git a/paddle/operators/linear_chain_crf_op.h b/paddle/operators/linear_chain_crf_op.h new file mode 100644 index 0000000000000000000000000000000000000000..56fb0c9102bee6e2fefd1180ef20237891573f70 --- /dev/null +++ b/paddle/operators/linear_chain_crf_op.h @@ -0,0 +1,543 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +static inline T NormalizeL1(T* x, size_t len) { + T sum = 0.; + for (size_t i = 0; i < len; ++i) sum += x[i]; + // (This comment is from the old LinearChainCRFLayer.) + // Right now, we just bet that sum won't be zero. If this really happens, we + // will figure out what should be done then. + PADDLE_ENFORCE(sum, + "The unnormalized probabilities of all possible unfinished " + "sequences must be greater than 0."); + T s = 1. / sum; + for (size_t i = 0; i < len; ++i) x[i] *= s; + return sum; +} + +template +struct ScalarMul { + explicit ScalarMul(const T& scalar) : scalar(scalar) {} + T operator()(const T& val) const { return val * scalar; } + + T scalar; +}; + +using framework::LoDTensor; +using framework::LoD; +using framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + +template +class LinearChainCRFOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // TODO(caoying) The checks related to LoD information should be + // moved into InferShape once the InferShape is refactored. + PADDLE_ENFORCE_EQ(ctx.Input("Emission")->NumLevels(), 1UL, + "The Input(Emission) should be a sequence."); + PADDLE_ENFORCE_EQ(ctx.Input("Label")->NumLevels(), 1UL, + "The Input(Label) should be a sequence."); + auto in_lod = ctx.Input("Label")->lod(); + PADDLE_ENFORCE(in_lod.size(), "Input(Label) must be a sequence."); + const size_t level = 0; + const size_t seq_num = in_lod[level].size() - 1; + + // These local variables hold the inputs and outputs, guaranteeing them on + // CPU memory, to provide a consistent reference. + // TODO(caoying) Fix this by moving all these local variables into the + // class's data members once we can profile the whole training process. 
+ LoDTensor* emission_weights = nullptr; + LoDTensor emission_weight_tensor; + Tensor* transition_weights = nullptr; + Tensor transition_weight_tensor; + LoDTensor* label = nullptr; + LoDTensor label_tensor; + + Tensor* emission_exps = nullptr; + Tensor emission_exps_tensor; + Tensor* transition_exps = nullptr; + Tensor transition_exps_tensor; + Tensor* alpha = nullptr; + Tensor alpha_tensor; + Tensor* ll = nullptr; + Tensor ll_tensor; + + if (platform::is_gpu_place(ctx.GetPlace())) { + emission_weights = &emission_weight_tensor; + transition_weights = &transition_weight_tensor; + label = &label_tensor; + + CopyInputsToCpuMemory( + ctx.device_context(), *ctx.Input("Emission"), + *ctx.Input("Transition"), *ctx.Input("Label"), + emission_weights, transition_weights, label); + + emission_exps = &emission_exps_tensor; + emission_exps->Resize(emission_weights->dims()); + + transition_exps = &transition_exps_tensor; + transition_exps->Resize(transition_weights->dims()); + + alpha = &alpha_tensor; + alpha->Resize(ctx.Output("Alpha")->dims()); + + ll = &ll_tensor; + } else { + emission_weights = + const_cast(ctx.Input("Emission")); + transition_weights = const_cast(ctx.Input("Transition")); + label = const_cast(ctx.Input("Label")); + + emission_exps = ctx.Output("EmissionExps"); + transition_exps = ctx.Output("TransitionExps"); + alpha = ctx.Output("Alpha"); + ll = ctx.Output("LogLikelihood"); + } + + // Because the computation codes only runs on CPU, here the memory for all + // the outputs is FIXED to be allocated on the CPU memory. + emission_exps->mutable_data(platform::CPUPlace()); + transition_exps->mutable_data(platform::CPUPlace()); + alpha->mutable_data(platform::CPUPlace()); + + // Resize the output tensor to its correct dimension. + ll->Resize({static_cast(seq_num), 1}); + ll->mutable_data(platform::CPUPlace()); + + // Now, all the inputs and outputs should be on the CPU memory. 
+ auto emission_dims = emission_weights->dims(); + const size_t batch_size = emission_dims[0]; + const size_t tag_num = emission_dims[1]; + + Tensor emission_row_max; + emission_row_max.mutable_data( + framework::make_ddim({static_cast(batch_size), 1}), + platform::CPUPlace()); + + auto place = ctx.GetEigenDevice(); + auto x = EigenMatrix::From(*emission_weights); + auto x_row_max = EigenMatrix::From(emission_row_max); + x_row_max.device(place) = + x.maximum(Eigen::DSizes(1)) + .reshape(Eigen::DSizes(int(batch_size), 1)); + + auto x_exps = EigenMatrix::From(*emission_exps); + x_exps.device(place) = + (x - x_row_max.broadcast(Eigen::DSizes(1, tag_num))).exp(); + + auto w = EigenMatrix::From(*transition_weights); + auto w_exps = EigenMatrix::From(*transition_exps); + w_exps.device(place) = w.exp(); + + T* log_likelihood = ll->data(); + for (size_t i = 0; i < seq_num; ++i) { + int start_pos = static_cast(in_lod[level][i]); + int end_pos = static_cast(in_lod[level][i + 1]); + if (end_pos == start_pos) { + // If an empty input sequence is given, pad 0 for its cost. 
+ log_likelihood[i] = 0.; + continue; + } + + const Tensor one_seq = emission_weights->Slice(start_pos, end_pos); + Tensor one_seq_row_max = emission_row_max.Slice(start_pos, end_pos); + Tensor one_seq_exps = emission_exps->Slice(start_pos, end_pos); + const Tensor one_seq_label = label->Slice(start_pos, end_pos); + Tensor one_seq_alpha = alpha->Slice(start_pos, end_pos); + + log_likelihood[i] = ForwardOneSequence( + one_seq, one_seq_row_max, one_seq_exps, *transition_weights, + *transition_exps, one_seq_label, &one_seq_alpha); + } + + if (platform::is_gpu_place(ctx.GetPlace())) { + CopyOutputsToGpuMemory( + ctx.device_context(), *emission_exps, *transition_exps, *alpha, *ll, + ctx.Output("EmissionExps"), + ctx.Output("TransitionExps"), ctx.Output("Alpha"), + ctx.Output("LogLikelihood")); + } + }; + + private: + void CopyInputsToCpuMemory(const platform::DeviceContext& ctx, + const LoDTensor& emission_weights_src, + const Tensor& transition_weights_src, + const LoDTensor& label_src, + LoDTensor* emission_weights_dst, + Tensor* transition_weights_dst, + LoDTensor* label_dst) const { + // Copy the inputs from GPU memory to CPU memory if this operators runs on + // GPU device. 
+ auto copyLoDTensor = [](const platform::DeviceContext& ctx, + const LoDTensor& src, LoDTensor* dst) { + dst->mutable_data(src.dims(), platform::CPUPlace()); + dst->CopyFrom(src, platform::CPUPlace(), ctx); + }; + + copyLoDTensor(ctx, emission_weights_src, emission_weights_dst); + copyLoDTensor(ctx, label_src, label_dst); + + transition_weights_dst->mutable_data(transition_weights_src.dims(), + platform::CPUPlace()); + transition_weights_dst->CopyFrom(transition_weights_src, + platform::CPUPlace(), ctx); + } + + void CopyOutputsToGpuMemory(const platform::DeviceContext& ctx, + const Tensor& emission_exps_src, + const Tensor& transition_exps_src, + const Tensor& alpha_src, const Tensor& ll_src, + Tensor* emission_exps_dst, + Tensor* transition_exps_dst, Tensor* alpha_dst, + Tensor* ll_dst) const { + // Copy the forward results from CPU memory to GPU memory if this + // operators runs on GPU device. + auto copyTensor = [](const platform::DeviceContext& ctx, const Tensor& src, + Tensor* dst) { + dst->mutable_data(platform::GPUPlace()); + dst->CopyFrom(src, platform::GPUPlace(), ctx); + }; + copyTensor(ctx, emission_exps_src, emission_exps_dst); + copyTensor(ctx, transition_exps_src, transition_exps_dst); + copyTensor(ctx, alpha_src, alpha_dst); + copyTensor(ctx, ll_src, ll_dst); + } + + T ForwardOneSequence(const Tensor& emission, const Tensor& emission_row_max, + const Tensor& emission_exps, const Tensor& trans_weights, + const Tensor& trans_weight_exps, const Tensor& label, + Tensor* alpha) const { + const T* x = emission.data(); + const T* x_row_max = emission_row_max.data(); + const T* x_exps = emission_exps.data(); + const T* w = trans_weights.data(); + const T* w_exps = trans_weight_exps.data(); + T* alpha_value = alpha->data(); + + auto x_dims = emission.dims(); + const size_t seq_length = x_dims[0]; + const size_t tag_num = x_dims[1]; + // The 1st row of w are transition weights for start mask. + // The 2nd row of w are transition weights for end mask. 
+ // Transition weights between other tags begin from the 3rd row of w. + const size_t state_trans_base_idx = 2; + + for (size_t i = 0; i < tag_num; ++i) { + alpha_value[i] = w_exps[i] * x_exps[i]; + } + T ll = -x_row_max[0] - std::log(NormalizeL1(alpha_value, tag_num)); + + for (size_t k = 1; k < seq_length; ++k) { + for (size_t i = 0; i < tag_num; ++i) { + T sum = 0.; + for (size_t j = 0; j < tag_num; ++j) { + sum += alpha_value[(k - 1) * tag_num + j] * // (*) + w_exps[(j + state_trans_base_idx) * tag_num + i]; + } + alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum; + } + // NormalizeL1 is to avoid underflow or overflow at (*). + ll -= x_row_max[k] + + std::log(NormalizeL1(alpha_value + k * tag_num, tag_num)); + } + T sum = 0.; + for (size_t i = 0; i < tag_num; ++i) { + sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i]; + } + ll -= std::log(sum); + // Now ll is equal to -log(Z). + + const int* lbl = label.data(); + PADDLE_ENFORCE_LT( + *std::max_element(lbl, lbl + seq_length), tag_num, + "An invalid tag label that exceeds the largest tag number."); + + // Calculate the numerator part, which depends on the label sequence. + ll += w[lbl[0]] /*start transition*/ + x[lbl[0]] + + w[tag_num + lbl[seq_length - 1]] /*end transition*/; + for (size_t k = 1; k < seq_length; ++k) { + ll += x[k * tag_num + lbl[k]] + + w[(lbl[k - 1] + state_trans_base_idx) * tag_num + lbl[k]]; + } + return -ll; + } +}; + +template +class LinearChainCRFGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const size_t level = 0; // currently, only support sequence. + auto lod = ctx.Input("Label")->lod(); + PADDLE_ENFORCE(lod.size(), "Input(Label) must be a sequence."); + + // These local variables hold the inputs and outputs, guaranteeing them on + // CPU memory, to provide a consistent reference. 
+ // TODO(caoying) Fix this by moving all these local variables into the + // class's data members once we can profile the training process, or + // implementing a real GPU kernel for CRF. + Tensor* label = nullptr; + Tensor label_tensor; + Tensor* emission_exps = nullptr; + Tensor emission_exps_tensor; + Tensor* transition_exps = nullptr; + Tensor transition_exps_tensor; + Tensor* alpha = nullptr; + Tensor alpha_tensor; + Tensor ll_grad_tensor; + T* ll_grad = nullptr; + + Tensor* emission_grad = nullptr; + Tensor emission_grad_tensor; + Tensor* transition_grad = nullptr; + Tensor transition_grad_tensor; + + if (platform::is_gpu_place(ctx.GetPlace())) { + label = &label_tensor; + emission_exps = &emission_exps_tensor; + transition_exps = &transition_exps_tensor; + alpha = &alpha_tensor; + CopyInputsToCpuMemory( + ctx.device_context(), *ctx.Input("Label"), + *ctx.Input("EmissionExps"), + *ctx.Input("TransitionExps"), *ctx.Input("Alpha"), + *ctx.Input(framework::GradVarName("LogLikelihood")), label, + emission_exps, transition_exps, alpha, &ll_grad_tensor); + ll_grad = ll_grad_tensor.data(); + + if (ctx.Output(framework::GradVarName("Emission"))) { + emission_grad = &emission_grad_tensor; + emission_grad->Resize(emission_exps->dims()); + } + + if (ctx.Output(framework::GradVarName("Transition"))) { + transition_grad = &transition_grad_tensor; + transition_grad->Resize(transition_exps->dims()); + } + } else { + label = const_cast(ctx.Input("Label")); + emission_exps = const_cast(ctx.Input("EmissionExps")); + transition_exps = + const_cast(ctx.Input("TransitionExps")); + alpha = const_cast(ctx.Input("Alpha")); + ll_grad = const_cast( + ctx.Input(framework::GradVarName("LogLikelihood"))) + ->data(); + + emission_grad = ctx.Output(framework::GradVarName("Emission")); + transition_grad = + ctx.Output(framework::GradVarName("Transition")); + } + + // TODO(caoying) Fix this constraint. When the Input(Emission) is from the + // data reader operator, it can have no gradients. 
+ PADDLE_ENFORCE(emission_grad, "Output(Emission@Grad) should not be null."); + emission_grad->mutable_data(platform::CPUPlace()); + if (transition_grad) { + transition_grad->mutable_data(platform::CPUPlace()); + math::SetConstant()(ctx.device_context(), + transition_grad, 0.); + } + // Now, all the inputs and outputs should be on the CPU memory. + + auto emission_dims = emission_exps->dims(); + // Beta is the memo table used in dynamic programming to calculate the + // backwark vectors. For a backward vector i (the i-th row of beta), it + // captures the unnormalized probabilities of partial sequences starting + // at position i. + Tensor beta; + beta.mutable_data(emission_dims, platform::CPUPlace()); + + for (size_t i = 0; i < lod[level].size() - 1; ++i) { + int start_pos = static_cast(lod[level][i]); + int end_pos = static_cast(lod[level][i + 1]); + if (end_pos == start_pos) continue; + + const Tensor one_seq_emission_exps = + emission_exps->Slice(start_pos, end_pos); + const Tensor one_seq_label = label->Slice(start_pos, end_pos); + const Tensor one_seq_alpha = alpha->Slice(start_pos, end_pos); + Tensor one_seq_beta = beta.Slice(start_pos, end_pos); + Tensor one_seq_emission_grad = emission_grad->Slice(start_pos, end_pos); + + BackwardOneSequence(ctx.device_context(), ll_grad[i], + one_seq_emission_exps, *transition_exps, + one_seq_alpha, one_seq_label, &one_seq_beta, + transition_grad, &one_seq_emission_grad); + } + + if (platform::is_gpu_place(ctx.GetPlace())) { + CopyOutputsToGpuMemory( + ctx.device_context(), emission_grad, transition_grad, + ctx.Output(framework::GradVarName("Emission")), + ctx.Output(framework::GradVarName("Transition"))); + } + }; + + private: + void CopyInputsToCpuMemory(const platform::DeviceContext& ctx, + const LoDTensor& label_src, + const Tensor& emission_exps_src, + const Tensor& transition_exps_src, + const Tensor& alpha_src, const Tensor& ll_grad_src, + Tensor* label_dst, Tensor* emission_exps_dst, + Tensor* transition_exps_dst, 
Tensor* alpha_dst, + Tensor* ll_grad_dst) const { + // Copy the inputs from GPU memory to CPU memory when this operator runs on + // a GPU device. + label_dst->mutable_data(label_src.dims(), platform::CPUPlace()); + label_dst->CopyFrom(label_src, platform::CPUPlace(), ctx); + + auto copyTensor = [](const platform::DeviceContext& ctx, const Tensor& src, + Tensor* dst) { + dst->mutable_data(src.dims(), platform::CPUPlace()); + dst->CopyFrom(src, platform::CPUPlace(), ctx); + }; + copyTensor(ctx, emission_exps_src, emission_exps_dst); + copyTensor(ctx, transition_exps_src, transition_exps_dst); + copyTensor(ctx, alpha_src, alpha_dst); + copyTensor(ctx, ll_grad_src, ll_grad_dst); + } + + void CopyOutputsToGpuMemory(const platform::DeviceContext& ctx, + const Tensor* emission_grad_src, + const Tensor* transition_grad_src, + Tensor* emission_grad_dst, + Tensor* transition_grad_dst) const { + // Copy the backward results from CPU memory to GPU + // memory if this operator runs on a GPU device. + auto copyTensor = [](const platform::DeviceContext& ctx, const Tensor* src, + Tensor* dst) { + if (src && dst) { + dst->mutable_data(platform::GPUPlace()); + dst->CopyFrom(*src, platform::GPUPlace(), ctx); + } + }; + copyTensor(ctx, emission_grad_src, emission_grad_dst); + copyTensor(ctx, transition_grad_src, transition_grad_dst); + } + + void BackwardOneSequence(const platform::DeviceContext& ctx, const T ll_grad, + const Tensor& emission_exps, + const Tensor& transition_exps, const Tensor& alpha, + const Tensor& label, Tensor* beta, + Tensor* transition_grad, + Tensor* emission_grad) const { + const T* w_exps = transition_exps.data(); + const T* x_exps = emission_exps.data(); + const int* label_value = label.data(); + T* beta_value = beta->data(); + + auto x_dims = emission_exps.dims(); + const size_t seq_length = x_dims[0]; + const size_t tag_num = x_dims[1]; + const size_t state_trans_base_idx = 2; + + // Calculate the backward vectors: beta. 
+ // First, calculate the initialization state. + for (size_t i = 0; i < tag_num; ++i) { + beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i]; + } + NormalizeL1(beta_value + (seq_length - 1) * tag_num, tag_num); + for (int k = static_cast(seq_length) - 2; k >= 0; --k) { + for (size_t i = 0; i < tag_num; ++i) { + T sum = 0.; + for (size_t j = 0; j < tag_num; ++j) { + sum += w_exps[(i + state_trans_base_idx) * tag_num + j] * // (**) + x_exps[(k + 1) * tag_num + j] * + beta_value[(k + 1) * tag_num + j]; + } + beta_value[k * tag_num + i] = sum; + } + // NormalizeL1 is to avoid underflow or overflow at (**). + NormalizeL1(beta_value + k * tag_num, tag_num); + } + + auto x_grad_mat = EigenMatrix::From(*emission_grad); + auto alpha_mat = EigenMatrix::From(alpha); + auto beta_mat = EigenMatrix::From(*beta); + + auto* place = ctx.GetEigenDevice(); + auto prob = alpha_mat * beta_mat; + auto row_sum = prob.sum(Eigen::DSizes(1)) + .reshape(Eigen::DSizes(seq_length, 1)) + .broadcast(Eigen::DSizes(1, tag_num)); + x_grad_mat.device(*place) = + (prob / row_sum).unaryExpr(ScalarMul(ll_grad)); + + for (size_t k = 0; k < seq_length; ++k) { + x_grad_mat(k, label_value[k]) -= static_cast(ll_grad); + } + + if (transition_grad) { + T* trans_grad = transition_grad->data(); + for (size_t k = 0; k < tag_num; ++k) { + // Do not multiply by the output gradient here, because x_grad_mat has + // already done this. + trans_grad[k] += x_grad_mat(/*from start state*/ 0, k); + trans_grad[tag_num + k] += + x_grad_mat(/*to end state*/ seq_length - 1, k); + } + + auto x_exps_mat = EigenMatrix::From(emission_exps); + + // TODO(caoying): Fix this to avoid using this local variable if we can + // profile the training process. 
+ Tensor tmp; + tmp.mutable_data(beta->dims(), platform::CPUPlace()); + auto tmp_mat = EigenMatrix::From(tmp); + auto prob = beta_mat * x_exps_mat; + auto row_sum = prob.sum(Eigen::DSizes(1)) + .reshape(Eigen::DSizes(seq_length, 1)) + .broadcast(Eigen::DSizes(1, tag_num)); + tmp_mat.device(*place) = prob / row_sum; + + for (size_t k = 1; k < seq_length; ++k) { + T sum = 0.; + for (size_t i = 0; i < tag_num; ++i) { + for (size_t j = 0; j < tag_num; ++j) { + sum += w_exps[(i + state_trans_base_idx) * tag_num + j] * // (**) + alpha_mat(k - 1, i) * tmp_mat(k, j); + } + } + sum = 1. / sum; + for (size_t i = 0; i < tag_num; ++i) { + for (size_t j = 0; j < tag_num; ++j) { + trans_grad[(i + state_trans_base_idx) * tag_num + j] += + sum * w_exps[(i + state_trans_base_idx) * tag_num + j] * + alpha_mat(k - 1, i) * tmp_mat(k, j) * ll_grad; + } + } + trans_grad[(label_value[k - 1] + state_trans_base_idx) * tag_num + + label_value[k]] -= static_cast(ll_grad); + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc index 942fbb42df8bb90b86bd097832a15b320a857750..50497da1b70d39d2638240dd91035c9181124af9 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/operators/softmax_with_cross_entropy_op.cc @@ -32,9 +32,9 @@ class SoftmaxWithCrossEntropyOpMaker AddInput("Label", "(Tensor, default: Tensor), The ground truth which is a 2-D " "tensor. " - "If softLable is set to 0, Label is a Tensor with shape [N x " - "1]. " - "If softLable is set to 1, Label is a Tensor " + "If softLabel is set to false, Label is a Tensor with shape " + "[N x 1]. " + "If softLabel is set to true, Label is a Tensor " "with shape [N x K]."); AddOutput( "Softmax", @@ -60,19 +60,23 @@ Because this operators performs a softmax on logits internally, it expects unscaled logits. 
Please do not call this op with the output of softmax operator, which will produce incorrect results. -This operators expects mutually exclusive hard labels, each sample in a batch -is in exactly one class with probabilities 1. Each sample in the batch with one -and only one label. +When the attribute softLabel is set to false, this operator expects mutually +exclusive hard labels: each sample in a batch is in exactly one class with +probability 1, i.e. each sample in the batch has one and only one label. Equation: 1) hard label (one-hot label) -Loss_j = -\text{Logit}_{Label_j} + \log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right), j = 1, ..., K +Loss_j = \f$ -\text{Logit}_{Label_j} + +\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right), +j = 1, ..., K \f$ 2) soft label (a distribution over all classes) -Loss_j = -\sum_{i=0}^{K}\text{Label}_i\left(\text{Logit}_i-\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right), j = 1,...,K +Loss_j = \f$ -\sum_{i=0}^{K}\text{Label}_i\left(\text{Logit}_i - +\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right), +j = 1,...,K \f$ )DOC"); } diff --git a/python/paddle/v2/framework/tests/test_linear_chain_crf_op.py b/python/paddle/v2/framework/tests/test_linear_chain_crf_op.py new file mode 100644 index 0000000000000000000000000000000000000000..6f06a66c825b37ee91214efc0a29a58f0b9057f9 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_linear_chain_crf_op.py @@ -0,0 +1,142 @@ +import unittest +import random +import numpy as np + +from op_test import OpTest + + +class LinearChainCrfForward(object): + def __init__(self, seq_start_positions, emission_weights, emission_row_max, + emission_exps, transition_weights, transition_exps, labels): + self.tag_num = emission_weights.shape[1] + self.seq_num = len(seq_start_positions) - 1 + + self.seq_start_positions = seq_start_positions + self.labels = labels + self.x = emission_weights + + self.x_row_max = emission_row_max + self.x_exps = emission_exps + + # unnormalized logits of the 
transition weights for the start mark. + self.a = transition_weights[0, :] + self.a_exps = transition_exps[0, :] + # unnormalized logits of the transition weights for the end mark. + self.b = transition_weights[1, :] + self.b_exps = transition_exps[1, :] + # unnormalized logits of the transition weights for all the other tags. + self.w = transition_weights[2:, :] + self.w_exps = transition_exps[2:, :] + + # The output of linear chain crf operator. + # alpha is a memo table in dynamic programming to calculate + # the normalization factor. + self.alpha = np.zeros( + (seq_start_positions[-1], self.tag_num), dtype="float64") + self.log_likelihood = np.zeros((self.seq_num, 1)) + + def _l1_norm(self, x): + s = np.sum(x) + x /= s + return s + + def _forward_a_sequence(self, x, x_row_max, x_exps, label, alpha): + seq_len = x_row_max.shape[0] + log_likelihood = 0. + + for i in range(self.tag_num): + alpha[0, i] = self.a_exps[i] * x_exps[0, i] + log_likelihood = -x_row_max[0] - np.log(self._l1_norm(alpha[0, :])) + + # calculate the unnormalized logits of the normalization factor. + for k in range(1, seq_len): + for i in range(self.tag_num): + s = 0. + for j in range(self.tag_num): + s += alpha[k - 1, j] * self.w_exps[j, i] + alpha[k, i] = x_exps[k, i] * s + log_likelihood -= x_row_max[k] + np.log(self._l1_norm(alpha[k, :])) + s = 0. + for i in range(self.tag_num): + s += alpha[-1, i] * self.b_exps[i] + log_likelihood -= np.log(s) + + # calculate the numerator part. 
+ log_likelihood += ( + self.a[label[0]] + x[0, label[0]] + self.b[label[-1]]) + + for k in range(1, seq_len): + log_likelihood += (x[k, label[k]] + self.w[label[k - 1], label[k]]) + return -log_likelihood + + def crf_forward_compute(self): + for i in range(self.seq_num): + start = self.seq_start_positions[i] + end = self.seq_start_positions[i + 1] + + self.log_likelihood[i] = self._forward_a_sequence( + self.x[start:end, :], self.x_row_max[start:end, :], + self.x_exps[start:end, :], self.labels[start:end, :], + self.alpha[start:end, :]) + return self.alpha, self.log_likelihood + + +class TestLinearChainCrfOp(OpTest): + def set_test_data(self): + # TODO(caoying) Fix the unittest by: add the boundary cases when + # sequence lengths are 1, 2, and 3. + + SEQ_NUM = 3 + TAG_NUM = 17 + MAX_SEQ_LEN = 5 + + # the linear_chain_crf operator only supports sequence (LoD level = 1) + lod = [[0]] + for i in range(SEQ_NUM): + lod[-1].append(lod[-1][-1] + random.randint(1, MAX_SEQ_LEN)) + emission = np.random.uniform(-1, 1, + [lod[-1][-1], TAG_NUM]).astype("float64") + emission_row_max = np.amax(emission, axis=1, keepdims=True) + emission_exps = np.exp(emission - emission_row_max) + + transition = np.random.uniform(-0.5, 0.5, + [TAG_NUM + 2, TAG_NUM]).astype("float64") + transition_exps = np.exp(transition) + + labels = np.random.randint( + low=0, high=TAG_NUM, size=(lod[-1][-1], 1), dtype="int32") + + self.inputs = { + "Emission": (emission, lod), + "Transition": transition, + "Label": (labels, lod) + } + crf = LinearChainCrfForward(lod[0], emission, emission_row_max, + emission_exps, transition, transition_exps, + labels) + alpha, log_likelihood = crf.crf_forward_compute() + + self.outputs = { + "Alpha": alpha, + "EmissionExps": emission_exps, + "TransitionExps": transition_exps, + "LogLikelihood": log_likelihood + } + + def setUp(self): + self.op_type = "linear_chain_crf" + self.set_test_data() + + def test_check_output(self): + self.check_output() + + def 
test_check_grad(self): + self.check_grad(["Emission", "Transition"], "LogLikelihood") + + def test_check_grad_ignore_transition(self): + self.check_grad( + ["Emission"], "LogLikelihood", no_grad_set=set("Transition")) + + +if __name__ == "__main__": + unittest.main()