From bce4f7d6eba070e4465ad52d65524e57d3745bae Mon Sep 17 00:00:00 2001 From: caoying03 Date: Thu, 26 Oct 2017 17:41:01 +0800 Subject: [PATCH] follow comments. --- paddle/framework/tensor_impl.h | 5 ++- paddle/operators/linear_chain_crf_op.cc | 57 +++++++++++++------------ paddle/operators/linear_chain_crf_op.h | 4 +- 3 files changed, 34 insertions(+), 32 deletions(-) diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 9090ff9532e..4097f92e021 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -228,8 +228,9 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { PADDLE_ENFORCE_GE(begin_idx, 0, "The start row index must be greater than 0."); PADDLE_ENFORCE_LE(end_idx, dims_[0], "The end row index is out of bound."); - PADDLE_ENFORCE_LT(begin_idx, end_idx, - "The start row index must be less than the end row index."); + PADDLE_ENFORCE_LT( + begin_idx, end_idx, + "The start row index must be smaller than the end row index."); if (dims_[0] == 1) { return *this; diff --git a/paddle/operators/linear_chain_crf_op.cc b/paddle/operators/linear_chain_crf_op.cc index d13d4829d91..0f21ee7264b 100644 --- a/paddle/operators/linear_chain_crf_op.cc +++ b/paddle/operators/linear_chain_crf_op.cc @@ -26,9 +26,10 @@ T NormalizeL1(T* x, size_t len) { // Right now, we just bet that sum won't be zero. If this really happens, we // will figure out what should be done then. PADDLE_ENFORCE(sum, - "The unnormalized probabilites of all possible unfinished " + "The unnormalized probabilities of all possible unfinished " "sequences must be greater than 0."); - for (size_t i = 0; i < len; ++i) x[i] /= sum; + T s = 1. / sum; + for (size_t i = 0; i < len; ++i) x[i] *= s; return sum; } } // namespace @@ -36,9 +37,9 @@ T NormalizeL1(T* x, size_t len) { using framework::LoDTensor; using framework::LoD; -class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker { +class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker { public: - LinearChainCrfOpMaker(framework::OpProto* proto, + LinearChainCRFOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput( @@ -51,11 +52,11 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker { AddInput( "Transition", "(Tensor, default: Tensor). A Tensor with shape [(D + 2) x D]. " - "The learnable parameter for linear_chain_crf operator. " + "The learnable parameter for the linear_chain_crf operator. " "See more details in the operator's comments."); AddInput( "Label", - "(LoDTensor, default: LoDTensor). The ground truth which is a 2-D " + "(LoDTensor, default: LoDTensor). The groundtruth which is a 2-D " "LoDTensor with shape [N x 1], where N is the total element number in " "a mini-batch."); AddOutput( @@ -82,14 +83,11 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker { .AsIntermediate(); AddOutput( "LogLikelihood", - "(Tensor, default: Tensor). The logarithm of the " - "conditional " + "(Tensor, default: Tensor). The logarithm of the conditional " "likelihood of each training sample in a mini-batch. This is a 2-D " "tensor with shape [S x 1], where S is the sequence number in a " - "mini-batch. " - "Note: S is equal to the sequence number in a mini-batch. The " - "output " - "is no longer a LoDTensor."); + "mini-batch. Note: S is equal to the sequence number in a mini-batch. " + "The output is no longer a LoDTensor."); AddComment(R"DOC( Conditional Random Field defines an undirected probabilistic graph with nodes denoting random variables and edges denoting dependencies between these @@ -100,11 +98,11 @@ variables. CRF learns the conditional probability \f$P(Y|X)\f$, where Linear chain CRF is a special case of CRF that is useful for sequence labeling task. Sequence labeling tasks do not assume a lot of conditional independences among inputs. They only concern about the input and the output -being linear sequences. Thus, the graph model of CRF is a simple chain or -a line, which results in a linear chain CRF. +being linear sequences. Thus, the graph model of such a CRF is a simple chain +or a line, which results in the linear chain CRF. -This operator implements the Forward-Backward algorithm for linear chain CRF. -Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference. +This operator implements the Forward-Backward algorithm for the linear chain +CRF. Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference. Equation: @@ -144,7 +142,7 @@ nonlinear activation. } }; -class LinearChainCrfOp : public framework::OperatorWithKernel { +class LinearChainCRFOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -211,7 +209,7 @@ class LinearChainCrfOp : public framework::OperatorWithKernel { }; template -class LinearChainCrfOpKernel +class LinearChainCRFOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -262,11 +260,11 @@ class LinearChainCrfOpKernel w_exps.device(place) = w.exp(); auto* alpha = ctx.Output("Alpha"); - alpha->mutable_data(ctx.GetPlace()); + alpha->mutable_data(platform::CPUPlace()); auto* ll = ctx.Output("LogLikelihood"); // resize the output tensor to the correct dimension. ll->Resize({static_cast(seq_num), 1}); - T* log_likelihood = ll->mutable_data(ctx.GetPlace()); + T* log_likelihood = ll->mutable_data(platform::CPUPlace()); for (size_t i = 0; i < seq_num; ++i) { int start_pos = static_cast(in_lod[level][i]); int end_pos = static_cast(in_lod[level][i + 1]); @@ -322,6 +320,7 @@ class LinearChainCrfOpKernel } alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum; } + // NormalizeL1 is to avoid underflow or overflow at (*). ll -= x_row_max[k] + std::log(NormalizeL1(alpha_value + k * tag_num, tag_num)); } @@ -330,6 +329,7 @@ class LinearChainCrfOpKernel sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i]; } ll -= std::log(sum); + // Now ll is equal to -log(Z). const int* lbl = label->data(); PADDLE_ENFORCE_LT( @@ -347,7 +347,7 @@ class LinearChainCrfOpKernel } }; -class LinearChainCrfGradOp : public framework::OperatorWithKernel { +class LinearChainCRFGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -407,11 +407,11 @@ class LinearChainCrfGradOp : public framework::OperatorWithKernel { }; template -class LinearChainCrfGradOpKernel +class LinearChainCRFGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + PADDLE_ENFORCE(platform::is_cpu_place(platform::CPUPlace()), "This kernel only runs on CPU."); auto* label = ctx.Input("Label"); auto* emission_exps = ctx.Input("EmissionExps"); @@ -493,6 +493,7 @@ class LinearChainCrfGradOpKernel } beta_value[k * tag_num + i] = sum; } + // NormalizeL1 is to avoid underflow or overflow at (**). NormalizeL1(beta_value + k * tag_num, tag_num); } @@ -534,7 +535,7 @@ class LinearChainCrfGradOpKernel T sum = 0.; for (size_t i = 0; i < tag_num; ++i) { for (size_t j = 0; j < tag_num; ++j) { - sum += w_exps[(i + state_trans_base_idx) * tag_num + j] * + sum += w_exps[(i + state_trans_base_idx) * tag_num + j] * // (**) alpha_mat(k - 1, i) * tmp_mat(k, j); } } @@ -557,11 +558,11 @@ class LinearChainCrfGradOpKernel } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(linear_chain_crf, ops::LinearChainCrfOp, ops::LinearChainCrfOpMaker, - linear_chain_crf_grad, ops::LinearChainCrfGradOp); +REGISTER_OP(linear_chain_crf, ops::LinearChainCRFOp, ops::LinearChainCRFOpMaker, + linear_chain_crf_grad, ops::LinearChainCRFGradOp); REGISTER_OP_CPU_KERNEL( linear_chain_crf, - ops::LinearChainCrfOpKernel); + ops::LinearChainCRFOpKernel); REGISTER_OP_CPU_KERNEL( linear_chain_crf_grad, - ops::LinearChainCrfGradOpKernel); + ops::LinearChainCRFGradOpKernel); diff --git a/paddle/operators/linear_chain_crf_op.h b/paddle/operators/linear_chain_crf_op.h index f65d268bb62..3175252c660 100644 --- a/paddle/operators/linear_chain_crf_op.h +++ b/paddle/operators/linear_chain_crf_op.h @@ -25,7 +25,7 @@ template ; template -class LinearChainCrfOpKernel : public framework::OpKernel { +class LinearChainCRFOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override; @@ -37,7 +37,7 @@ class LinearChainCrfOpKernel : public framework::OpKernel { }; template -class LinearChainCrfGradOpKernel : public framework::OpKernel { +class LinearChainCRFGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override; -- GitLab