提交 bce4f7d6 编写于 作者: C caoying03

follow comments.

上级 4c630869
...@@ -228,8 +228,9 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { ...@@ -228,8 +228,9 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
PADDLE_ENFORCE_GE(begin_idx, 0, PADDLE_ENFORCE_GE(begin_idx, 0,
"The start row index must be greater than 0."); "The start row index must be greater than 0.");
PADDLE_ENFORCE_LE(end_idx, dims_[0], "The end row index is out of bound."); PADDLE_ENFORCE_LE(end_idx, dims_[0], "The end row index is out of bound.");
PADDLE_ENFORCE_LT(begin_idx, end_idx, PADDLE_ENFORCE_LT(
"The start row index must be less than the end row index."); begin_idx, end_idx,
"The start row index must be smaller than the end row index.");
if (dims_[0] == 1) { if (dims_[0] == 1) {
return *this; return *this;
......
...@@ -26,9 +26,10 @@ T NormalizeL1(T* x, size_t len) { ...@@ -26,9 +26,10 @@ T NormalizeL1(T* x, size_t len) {
// Right now, we just bet that sum won't be zero. If this really happens, we // Right now, we just bet that sum won't be zero. If this really happens, we
// will figure out what should be done then. // will figure out what should be done then.
PADDLE_ENFORCE(sum, PADDLE_ENFORCE(sum,
"The unnormalized probabilites of all possible unfinished " "The unnormalized probabilities of all possible unfinished "
"sequences must be greater than 0."); "sequences must be greater than 0.");
for (size_t i = 0; i < len; ++i) x[i] /= sum; T s = 1. / sum;
for (size_t i = 0; i < len; ++i) x[i] *= s;
return sum; return sum;
} }
} // namespace } // namespace
...@@ -36,9 +37,9 @@ T NormalizeL1(T* x, size_t len) { ...@@ -36,9 +37,9 @@ T NormalizeL1(T* x, size_t len) {
using framework::LoDTensor; using framework::LoDTensor;
using framework::LoD; using framework::LoD;
class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker { class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LinearChainCrfOpMaker(framework::OpProto* proto, LinearChainCRFOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
...@@ -51,11 +52,11 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -51,11 +52,11 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput( AddInput(
"Transition", "Transition",
"(Tensor, default: Tensor<float>). A Tensor with shape [(D + 2) x D]. " "(Tensor, default: Tensor<float>). A Tensor with shape [(D + 2) x D]. "
"The learnable parameter for linear_chain_crf operator. " "The learnable parameter for the linear_chain_crf operator. "
"See more details in the operator's comments."); "See more details in the operator's comments.");
AddInput( AddInput(
"Label", "Label",
"(LoDTensor, default: LoDTensor<int>). The ground truth which is a 2-D " "(LoDTensor, default: LoDTensor<int>). The groundtruth which is a 2-D "
"LoDTensor with shape [N x 1], where N is the total element number in " "LoDTensor with shape [N x 1], where N is the total element number in "
"a mini-batch."); "a mini-batch.");
AddOutput( AddOutput(
...@@ -82,14 +83,11 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -82,14 +83,11 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
.AsIntermediate(); .AsIntermediate();
AddOutput( AddOutput(
"LogLikelihood", "LogLikelihood",
"(Tensor, default: Tensor<float>). The logarithm of the " "(Tensor, default: Tensor<float>). The logarithm of the conditional "
"conditional "
"likelihood of each training sample in a mini-batch. This is a 2-D " "likelihood of each training sample in a mini-batch. This is a 2-D "
"tensor with shape [S x 1], where S is the sequence number in a " "tensor with shape [S x 1], where S is the sequence number in a "
"mini-batch. " "mini-batch. Note: S is equal to the sequence number in a mini-batch. "
"Note: S is equal to the sequence number in a mini-batch. The " "The output is no longer a LoDTensor.");
"output "
"is no longer a LoDTensor.");
AddComment(R"DOC( AddComment(R"DOC(
Conditional Random Field defines an undirected probabilistic graph with nodes Conditional Random Field defines an undirected probabilistic graph with nodes
denoting random variables and edges denoting dependencies between these denoting random variables and edges denoting dependencies between these
...@@ -100,11 +98,11 @@ variables. CRF learns the conditional probability \f$P(Y|X)\f$, where ...@@ -100,11 +98,11 @@ variables. CRF learns the conditional probability \f$P(Y|X)\f$, where
Linear chain CRF is a special case of CRF that is useful for sequence labeling Linear chain CRF is a special case of CRF that is useful for sequence labeling
task. Sequence labeling tasks do not assume a lot of conditional task. Sequence labeling tasks do not assume a lot of conditional
independences among inputs. They only concern about the input and the output independences among inputs. They only concern about the input and the output
being linear sequences. Thus, the graph model of CRF is a simple chain or being linear sequences. Thus, the graph model of such a CRF is a simple chain
a line, which results in a linear chain CRF. or a line, which results in the linear chain CRF.
This operator implements the Forward-Backward algorithm for linear chain CRF. This operator implements the Forward-Backward algorithm for the linear chain
Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference. CRF. Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
Equation: Equation:
...@@ -144,7 +142,7 @@ nonlinear activation. ...@@ -144,7 +142,7 @@ nonlinear activation.
} }
}; };
class LinearChainCrfOp : public framework::OperatorWithKernel { class LinearChainCRFOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -211,7 +209,7 @@ class LinearChainCrfOp : public framework::OperatorWithKernel { ...@@ -211,7 +209,7 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
}; };
template <typename T> template <typename T>
class LinearChainCrfOpKernel<platform::CPUPlace, T> class LinearChainCRFOpKernel<platform::CPUPlace, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -262,11 +260,11 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T> ...@@ -262,11 +260,11 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
w_exps.device(place) = w.exp(); w_exps.device(place) = w.exp();
auto* alpha = ctx.Output<LoDTensor>("Alpha"); auto* alpha = ctx.Output<LoDTensor>("Alpha");
alpha->mutable_data<T>(ctx.GetPlace()); alpha->mutable_data<T>(platform::CPUPlace());
auto* ll = ctx.Output<LoDTensor>("LogLikelihood"); auto* ll = ctx.Output<LoDTensor>("LogLikelihood");
// resize the output tensor to the correct dimension. // resize the output tensor to the correct dimension.
ll->Resize({static_cast<int>(seq_num), 1}); ll->Resize({static_cast<int>(seq_num), 1});
T* log_likelihood = ll->mutable_data<T>(ctx.GetPlace()); T* log_likelihood = ll->mutable_data<T>(platform::CPUPlace());
for (size_t i = 0; i < seq_num; ++i) { for (size_t i = 0; i < seq_num; ++i) {
int start_pos = static_cast<int>(in_lod[level][i]); int start_pos = static_cast<int>(in_lod[level][i]);
int end_pos = static_cast<int>(in_lod[level][i + 1]); int end_pos = static_cast<int>(in_lod[level][i + 1]);
...@@ -322,6 +320,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T> ...@@ -322,6 +320,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
} }
alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum; alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum;
} }
// NormalizeL1 is to avoid underflow or overflow at (*).
ll -= x_row_max[k] + ll -= x_row_max[k] +
std::log(NormalizeL1<T>(alpha_value + k * tag_num, tag_num)); std::log(NormalizeL1<T>(alpha_value + k * tag_num, tag_num));
} }
...@@ -330,6 +329,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T> ...@@ -330,6 +329,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i]; sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i];
} }
ll -= std::log(sum); ll -= std::log(sum);
// Now ll is equal to -log(Z).
const int* lbl = label->data<int>(); const int* lbl = label->data<int>();
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
...@@ -347,7 +347,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T> ...@@ -347,7 +347,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
} }
}; };
class LinearChainCrfGradOp : public framework::OperatorWithKernel { class LinearChainCRFGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -407,11 +407,11 @@ class LinearChainCrfGradOp : public framework::OperatorWithKernel { ...@@ -407,11 +407,11 @@ class LinearChainCrfGradOp : public framework::OperatorWithKernel {
}; };
template <typename T> template <typename T>
class LinearChainCrfGradOpKernel<platform::CPUPlace, T> class LinearChainCRFGradOpKernel<platform::CPUPlace, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_cpu_place(platform::CPUPlace()),
"This kernel only runs on CPU."); "This kernel only runs on CPU.");
auto* label = ctx.Input<LoDTensor>("Label"); auto* label = ctx.Input<LoDTensor>("Label");
auto* emission_exps = ctx.Input<LoDTensor>("EmissionExps"); auto* emission_exps = ctx.Input<LoDTensor>("EmissionExps");
...@@ -493,6 +493,7 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T> ...@@ -493,6 +493,7 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
} }
beta_value[k * tag_num + i] = sum; beta_value[k * tag_num + i] = sum;
} }
// NormalizeL1 is to avoid underflow or overflow at (**).
NormalizeL1<T>(beta_value + k * tag_num, tag_num); NormalizeL1<T>(beta_value + k * tag_num, tag_num);
} }
...@@ -534,7 +535,7 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T> ...@@ -534,7 +535,7 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
T sum = 0.; T sum = 0.;
for (size_t i = 0; i < tag_num; ++i) { for (size_t i = 0; i < tag_num; ++i) {
for (size_t j = 0; j < tag_num; ++j) { for (size_t j = 0; j < tag_num; ++j) {
sum += w_exps[(i + state_trans_base_idx) * tag_num + j] * sum += w_exps[(i + state_trans_base_idx) * tag_num + j] * // (**)
alpha_mat(k - 1, i) * tmp_mat(k, j); alpha_mat(k - 1, i) * tmp_mat(k, j);
} }
} }
...@@ -557,11 +558,11 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T> ...@@ -557,11 +558,11 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(linear_chain_crf, ops::LinearChainCrfOp, ops::LinearChainCrfOpMaker, REGISTER_OP(linear_chain_crf, ops::LinearChainCRFOp, ops::LinearChainCRFOpMaker,
linear_chain_crf_grad, ops::LinearChainCrfGradOp); linear_chain_crf_grad, ops::LinearChainCRFGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
linear_chain_crf, linear_chain_crf,
ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, float>); ops::LinearChainCRFOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
linear_chain_crf_grad, linear_chain_crf_grad,
ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, float>); ops::LinearChainCRFGradOpKernel<paddle::platform::CPUPlace, float>);
...@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class LinearChainCrfOpKernel : public framework::OpKernel<T> { class LinearChainCRFOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override; void Compute(const framework::ExecutionContext& ctx) const override;
...@@ -37,7 +37,7 @@ class LinearChainCrfOpKernel : public framework::OpKernel<T> { ...@@ -37,7 +37,7 @@ class LinearChainCrfOpKernel : public framework::OpKernel<T> {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class LinearChainCrfGradOpKernel : public framework::OpKernel<T> { class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override; void Compute(const framework::ExecutionContext& ctx) const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册